Casting Semantics #412
cc: @peri044
Some additional detail:

Explicit Precision Control

Case 1: Layer Casting

PyTorch graph / TorchScript graph: the example graphs from the original post are not reproduced here; a sketch of the kind of network being discussed follows.
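As a minimal sketch (the module is illustrative, not the original example), a conv-plus-max-pool network with the conv weights explicitly cast to FP16, and the TorchScript graph that TRTorch would receive:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3).half()  # layer casting: conv1's weights are now FP16
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x):
        return self.max_pool(self.conv1(x))

# Scripting shows the TorchScript graph; note there is no explicit cast node in it,
# the only signal of the intended precision is the dtype of conv1's weights.
print(torch.jit.script(Net()).graph)
```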
Current limitations:
a) In the above graph, in TRT, if we set the builder precision to FP32 but provide an FP16 input, TensorRT inserts reformat layers: x (fp16) -> Reformat (fp32) -> Conv1 -> Reformat (fp16) -> output (fp16). This doesn't satisfy our requirement that Conv1 actually run in FP16.
b) If we set the layer precision to FP16 (conv1->setPrecision(fp16)) but leave the builder precision at FP32, the compilation fails with an error.
c) If we set the layer precision to FP16 and the builder precision to FP16, then all the other layers in the network run in FP16 as well where possible. This is overkill. For example, with the input graph x (fp16) -> Conv1 -> max_pool -> output, both Conv1 and max_pool would run in FP16. One solution is to set the builder precision to FP16, conv1's precision to FP16 and max_pool's precision to FP32. We also need strict_types to be true here (a TensorRT sketch of this workaround appears below).

TRTorch handling:
enabled_precision flag: this can be a set of precisions that are enabled during inference. If strict_types is false, TensorRT can still fall back to FP32 kernels; if strict_types is true, only the enabled precisions may be selected. The builder precision should take on values exactly from this set. For example, if enabled_precision = {FP32, FP16}, both FP32 and FP16 kernels may be chosen during optimization.
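Below is a minimal TensorRT Python API sketch of the per-layer workaround described in (c). This is plain TensorRT, not TRTorch; the network, tensor names, and weight values are made up, and the flag names correspond to the TensorRT versions contemporary with this discussion:

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# x (fp16) -> Conv1 -> max_pool -> output
x = network.add_input("x", trt.float16, (1, 3, 32, 32))
kernel = np.random.rand(8, 3, 3, 3).astype(np.float32)
bias = np.zeros(8, dtype=np.float32)
conv1 = network.add_convolution_nd(x, 8, (3, 3), kernel, bias)
max_pool = network.add_pooling_nd(conv1.get_output(0), trt.PoolingType.MAX, (2, 2))
network.mark_output(max_pool.get_output(0))

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)          # enable FP16 kernels globally
config.set_flag(trt.BuilderFlag.STRICT_TYPES)  # make the per-layer precisions binding
conv1.precision = trt.float16                  # run the conv in FP16
max_pool.precision = trt.float32               # keep the pooling layer in FP32

engine = builder.build_engine(network, config)
```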
Here are the possible cases and our solutions (the individual cases a)-e) from the original post are not reproduced here).

When should we enable setPrecision for any layer? When a user has explicitly cast the layer's weights, the weight dtype is the signal. So in all converters we can add a macro which does the above functionality: inspect the weight dtype and pin the layer precision accordingly (a sketch follows).
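A minimal sketch of that idea using the TensorRT Python API; the real TRTorch converters are C++ and the helper name is hypothetical:

```python
import tensorrt as trt
import torch

_TORCH_TO_TRT = {
    torch.float32: trt.float32,
    torch.float16: trt.float16,
    torch.int8: trt.int8,
}

def respect_weight_precision(layer, weight: torch.Tensor) -> None:
    """If the user explicitly cast the weights (e.g. module.half()), pin the layer precision."""
    trt_dtype = _TORCH_TO_TRT.get(weight.dtype)
    if trt_dtype is not None and trt_dtype != trt.float32:
        layer.precision = trt_dtype  # the Python equivalent of layer->setPrecision(...)
```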
Layers without weights are still an open question. In PyTorch, x (fp16) -> max_pool runs in FP16 simply because its input is FP16; in TensorRT there is no explicit signal for such a layer. The candidate options a)-c) (one of which was to tell users to use a particular workaround) and the chosen default behavior are not reproduced here.
Case 2: Tensor Casting

Tensor casting is expressed in the TorchScript representation using aten::to. The primary use case for supporting it is graphs that explicitly cast intermediate tensors to a lower precision.

PyTorch graph / TorchScript representation: the original example is not reproduced here; a sketch follows.
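As a minimal sketch (the function is illustrative, not the original example), a graph with an explicit tensor cast and its TorchScript form:

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    y = torch.relu(x)
    y = y.to(torch.float16)          # tensor casting of an intermediate value
    return torch.max_pool2d(y, 2)

# The scripted graph contains an aten::to node feeding aten::max_pool2d.
print(torch.jit.script(f).graph)
```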
Since the output type of the aten::to node tells us the precision the user intends for that tensor, we can propagate it to the corresponding TensorRT layer. In TRT source code, it checks the following (the snippet from the original post is not reproduced here).
Solution: whenever we find a layer output type to be FP16/INT8 because of an explicit cast, we map the cast onto an identity layer and set that layer's output type accordingly (sketched below). The possible cases a)-e) for the above network, and how each is handled, are not reproduced here.
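A minimal sketch of that mapping in the TensorRT Python API; the converter function is hypothetical, not the actual TRTorch implementation:

```python
import tensorrt as trt

def convert_aten_to(network, input_tensor, target_dtype):
    """Map an aten::to(dtype) node onto an identity layer with a forced output type."""
    identity = network.add_identity(input_tensor)
    identity.set_output_type(0, target_dtype)  # e.g. trt.float16 for a cast to half
    return identity.get_output(0)
```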
Case 3: Setting input datatypes in TRTorch

Currently in TRTorch, the input datatype is set based on the builder precision. This is not sufficient when the network has multiple inputs and one of the datatypes is either INT (maybe a shape input) or FP16.

ONNX handling: in ONNX graphs, the input datatypes are inferred from the nodes in the graph. The problem with TorchScript graphs is that the graph inputs do not carry that datatype information.

Solution: the solution would be like how TRT used to manually configure inputs in UFF (deprecated). We would have a way for users to specify the datatype (and format) of each input in the compile spec; see the Input API brainstorm in the next comment.
Brainstorming the new input spec API:

```python
import torch
import trtorch
...

# Options considered but not selected (kept commented out):
'''
trt_mod = trtorch.compile(ts_mod, {
    "input_shapes": [[1,2,2,2], {
        "min": (1,2,2,2),
        "opt": (3,2,2,2),
        "max": (121, 20, 20, 20)
    }]
})

trt_mod = trtorch.compile(ts_mod, {
    "input_shapes": [[1,2,2,2], {
        "min": (1,2,2,2),
        "opt": (3,2,2,2),
        "max": (121, 20, 20, 20)
    }],
    "input_dtypes": [torch.int32, torch.float32],
    "input_tensor_formats": [torch.contiguous_format, torch.channels_last]
})

trt_mod = trtorch.compile(ts_mod, {
    "inputs": [trtorch.Input((1,2,2,2), dtype=torch.int32, format=torch.contiguous_format),
               trtorch.Input({
                   "min": (1,2,2,2),
                   "opt": (3,2,2,2),
                   "max": (121, 20, 20, 20)
               })],
})
'''

# We selected this option
trt_mod = trtorch.compile(ts_mod, {
    "inputs": [trtorch.Input(shape=(1,2,2,2), dtype=torch.int32, format=torch.contiguous_format),
               trtorch.Input(min_shape=(1,2,2,2), opt_shape=(3,2,2,2), max_shape=(121, 20, 20, 20),
                             format=torch.channels_last)],
})

trt_mod = trtorch.compile(ts_mod, {"inputs": [trtorch.Input(shape=(1,2,2,2)), trtorch.Input(shape=(3,2,2,2))]})

class Input:
    # Earlier idea: wrap an InputRange(Tuple() or Dict())
    # def __init__(self, shape, dtype=torch.float32, format=torch.contiguous_format):

    # We selected this option:
    def __init__(self, shape=None, min_shape=None, opt_shape=None, max_shape=None,
                 dtype=torch.float32, format=torch.contiguous_format):
        if not shape and (not min_shape and not opt_shape and not max_shape):
            raise ValueError("Either shape or min_shape/opt_shape/max_shape must be provided")

# Example on how to use tensors as example input for shape, type and format inference
trt_mod = trtorch.compile(ts_mod, {
    "inputs": [torch.empty((1,2,2,2), dtype=torch.int32, memory_format=torch.contiguous_format)],
})
```

The same brainstorm for the C++ API:

```cpp
struct Input {
  Input(trtorch::InputRange shape);
  Input(std::vector<uint64_t> shape);
  Input(std::vector<uint64_t> min_shape, std::vector<uint64_t> opt_shape, std::vector<uint64_t> max_shape);
  Input(trtorch::InputRange shape, trtorch::DataType dtype = trtorch::DataType::kFloat32,
        trtorch::Format format = trtorch::Format::kNCHW);
  Input(std::vector<uint64_t> min_shape, std::vector<uint64_t> opt_shape, std::vector<uint64_t> max_shape,
        trtorch::DataType dtype = trtorch::DataType::kFloat32, trtorch::Format format = trtorch::Format::kNCHW);
};

auto in1 = trtorch::Input({1,2,2,2});
auto spec = CompileSpec({in1});
trt_mod = trtorch::compile(ts_mod, spec);

// Alternative ways to construct an Input:
auto in_shape1 = trtorch::Input(trtorch::InputRange({1,2,2,2}, {1,2,2,2}, {2,2,2,2}));
auto in_shape1 = trtorch::Input({1,2,2,2});
auto in_shape1 = trtorch::Input(/*min_shape=*/{1,2,2,2}, /*opt_shape=*/{1,2,2,2}, /*max_shape=*/{2,2,2,2});
auto in_shape1 = trtorch::Input({1,2,2,2}, torch::kChar);
auto in_shape1 = trtorch::Input(/*min_shape=*/{1,2,2,2}, /*opt_shape=*/{1,2,2,2}, /*max_shape=*/{2,2,2,2}, torch::kChar);

auto spec = trtorch::CompileSpec({in_shape1});
spec.enabled_types.push(trtorch::DataType::kFloat16);

auto spec = trtorch::compile(ts_mod, {{trtorch::Input()}});
auto spec = trtorch::compile(ts_mod, {{{1,1,1,1}}});
```
Casting
There are some cases where users may want to control the specific precision that layers and tensors exist at. PyTorch has APIs to do this by casting tensors and the weights of modules. We need to determine a way to map PyTorch casting semantics to TensorRT in a way that is understandable.
Tensor Casting
Tensor casting means casting a tensor that is an input or a product of an operation.
PyTorch API
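As a minimal illustration (the original example code is not captured above), the tensor-level casting API looks like:

```python
import torch

x = torch.randn(1, 3, 32, 32)
x_fp16 = x.half()            # or x.to(torch.float16)
x_int8 = x.to(torch.int8)
```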
TorchScript Representation
Casting operations like this in the course of a TorchScript graph are represented by aten::to.
Handling Casting in TRTorch
For tensor casting TensorRT provides the IdentityLayer. Presumably all we need to do is have a converter that takes aten::to and maps it to this layer, then applies the right setPrecision/setOutputType on this layer.

Layer Casting
Layer casting means casting the weight tensors of a module.
PyTorch API
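As a minimal illustration (the original example code is not captured above), the module-level casting API looks like:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.MaxPool2d(2))
model.half()                  # casts all owned weight tensors to FP16
model[0].to(torch.float32)    # individual submodules can be cast back
```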
TorchScript Representation
Casting operations on modules directly affect the owned tensors of the module and nothing else. When you run model.half() you are really just casting the weight tensors to FP16 in the same way as you do above manually. Therefore there is not any TorchScript representation other than the fact that the weights are in FP16. (Need to fully verify this.)

Handling Casting in TRTorch
Since the only signal we have about explicit layer precisions in TRTorch is the type of the weights, we need converters to inspect and respect the precision of the weights that they convert. This would augment the converter contract to add this responsibility. Ideally we could find some way to automate it or make it a one-liner. I think we have Torch to TensorRT datatype conversion supported already, so something like layer->setPrecision(weight_tensor.dtype); might be sufficient.
Necessary Precision API changes
I think the current op precision API is misleading and confusing. It makes people believe that networks will run in one type, instead of what it is really doing, which is enabling new precisions to be selected. We should change the API to be something closer to a set of additional enabled data types, while still clearly conveying to people that FP32 will always be an option unless they specify strict_types.

With this new API we would need to add a check in the ConversionCtx (sketched below), and we would also need to change the behavior that adds FP16 whenever people enable INT8.
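A minimal sketch of the kind of check described above, written against the TensorRT Python API for illustration; the real ConversionCtx is C++ and the function name is hypothetical:

```python
import tensorrt as trt

def check_layer_precision(layer, enabled_precisions):
    """Reject layers pinned to a precision the user never enabled."""
    if layer.precision_is_set and layer.precision not in enabled_precisions:
        raise RuntimeError(
            "Layer {} requests precision {}, which is not in enabled_precisions {}".format(
                layer.name, layer.precision, enabled_precisions))
```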
We should also take this chance to add fuller coverage for TensorRT APIs, including adding tensor and layer options and allowing users to specify input data types.