Proposed implementation phases:
- Phase 1: Use `convert_module_to_engine` to produce standalone TRT engines from PyTorch modules, then, outside the library, figure out how to populate refit settings programmatically and refit the engine.
- Phase 2 [MVP]: Support the same workflow inside the Torch-TRT context, and start to prototype the UX for end users (only needs to support fully supported models, but should work end to end).
- Phase 3: Support advanced features of Torch-TensorRT, e.g. dynamic shape and fallback to PyTorch, and support use with libraries like Hugging Face and `torch.compile`.
- Extensions: Demos of interesting workflows (LoRAs, offsite compilation, caching, etc.)
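The out-of-library refit flow in Phase 1 might look roughly like the sketch below. It assumes an engine serialized from `convert_module_to_engine` with the refit flag enabled, and that a mapping from state_dict keys to TRT weight names already exists; the TensorRT calls shown (`trt.Refitter`, `get_all_weights`, `set_named_weights`, `refit_cuda_engine`) are from the TRT Python API, while the function name and file handling are purely illustrative.

```python
import numpy as np

def refit_standalone_engine(engine_path, named_weights):
    """Deserialize a refittable TRT engine and swap in new weights.

    `named_weights` maps TRT weight names to numpy arrays; producing that
    mapping from a PyTorch state_dict is the part Phase 1 would figure out.
    """
    import tensorrt as trt  # deferred; requires a TensorRT install and a GPU

    logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, "rb") as f:
        engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())

    refitter = trt.Refitter(engine, logger)
    # Only weights the engine actually exposes can be refit
    for name in refitter.get_all_weights():
        if name in named_weights:
            refitter.set_named_weights(name, np.ascontiguousarray(named_weights[name]))
    # Returns False if any required weights were left unset
    return refitter.refit_cuda_engine()
```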
Model Weight Refit
TL;DR
TensorRT supports updating engine weights after compilation via the `nvinfer1::IRefitter` class, available in both the C++ and Python APIs. This could be a beneficial feature to bring into Torch-TensorRT, specifically the FX path, since models which are pre-compiled and saved can easily be refit with new training weights, so long as the model architecture is unchanged.
Goals and Use Cases
Model weight refit will assist greatly in reducing time spent compiling models with Torch-TensorRT, since models would only need to be compiled once per architecture, and subsequent weight updates can be propagated into the compiled model post-compilation, without the overhead of recompiling. This also enables pre-compilation of a model architecture prior to training, which could allow for inference timing estimates ahead of sometimes-lengthy model training.
Proposed APIs / UX
Users of this feature (in FX) would compile their model via the FX API, for example with `torch_tensorrt.fx.compile(...)`, then save their model. Later, when loading the compiled model, if the weights have been updated from their original values, the user could call the Model Weight Refit function to refit the stored weights.
Example Workflow
See `TensorRT/examples/fx/fx2trt_example.py`, lines 23 to 146 at commit `deda87b`, for a sample workflow of compiling and saving a model via FX2TRT.
The additional step, as per the proposed API, would be to call a refit function on the compiled module with the new weights, e.g. `refit_weights(split_mod, weights)` (the name and signature are illustrative).
This function would parse the input weights dictionary, determine which of those to assign to which submodule from the splitter (assuming the model was not fully compiled in TRT), and assign the weights accordingly, using the TRT Python API for TRT-accelerated modules, and the Torch API for non-accelerated modules.
The `weights` argument could potentially be the output of `model.state_dict()` after training, or a different format.
Internal Implementation
Design
A function would be added to FX2TRT, for example, as:
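One possible sketch of such a function is below, assuming the FX splitter's `_run_on_acc_*` naming convention for TRT-accelerated submodules and hypothetical `engine`/`logger` attributes on the TRT submodule wrapper. The TensorRT calls used (`trt.Refitter`, `get_all_weights`, `set_named_weights`, `refit_cuda_engine`) exist in the TRT Python API; everything else is illustrative, not the final design.

```python
def refit_trt_engine(engine, new_weights, logger):
    """Refit a built TensorRT engine in place via the TRT Python API.

    `new_weights` maps TRT weight names to arrays. Assumes the engine was
    built with the REFIT flag; mapping state_dict keys to TRT weight names
    is glossed over here and would need a helper of its own.
    """
    import tensorrt as trt  # deferred so the dispatch logic below stays importable

    refitter = trt.Refitter(engine, logger)
    for name in refitter.get_all_weights():
        if name in new_weights:
            refitter.set_named_weights(name, new_weights[name])
    assert refitter.refit_cuda_engine(), "refit failed; were any weights left unset?"


def refit_weights(split_model, weights):
    """Route entries of a flat `weights` dict (e.g. a state_dict) to each
    submodule of a split model: TensorRT for accelerated submodules, the
    Torch API for the rest."""
    for name, submodule in split_model.named_children():
        # Keep only this submodule's entries, with the submodule prefix stripped
        sub_weights = {k[len(name) + 1:]: v for k, v in weights.items()
                       if k.startswith(name + ".")}
        if not sub_weights:
            continue
        if name.startswith("_run_on_acc"):
            # Hypothetical attributes on the TRT submodule wrapper
            refit_trt_engine(submodule.engine, sub_weights, submodule.logger)
        else:
            submodule.load_state_dict(sub_weights)
```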
As mentioned above, `refit_weights` would parse the input weights, determine which of those to assign to which submodule in the compiled model, and assign the weights accordingly, respecting the module boundaries of TRT and Torch. It would likely need a few helper functions.
Extensions Required to Core API implementations
The existing library should not require many changes, as this add-on would simply add functionality while preserving existing core APIs.
Details specific for TorchScript Support
TorchScript Python API support is more challenging in this particular case, since the weight Tensor objects would need to be transferred from Python to C++. A similar design would function well, for example `ts_model.refit_weights(weights)`; however, the FX path would make a better MVP, since the implementation could stay strictly in Python via the tensorrt Python API.
Details specific for FX support
See above
Implementation Phases
- Prototype - Small/Medium: Verify that `model.state_dict()` for a newly-trained model contains sufficient information to update the weights of an existing model
- MVP (1.4.0) - Medium: Implement the `refit_weights` function, including refitting weights with multiple TRT-accelerated submodules and multiple Torch/FX non-accelerated submodules
- Extension Phase 1 [Potential] - Medium: Implement `refit_weights(...)` for the TorchScript API, if transferring the `state_dict` from Python would be feasible
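For the prototype phase, the `state_dict` sufficiency check could start as a pure-Python comparison of parameter names and shapes between the originally compiled model and the newly trained one. This is a sketch under the assumption that matching names and shapes is a necessary precondition for refit; the function name is hypothetical.

```python
def check_refit_compatible(old_state_dict, new_state_dict):
    """Report mismatches that would prevent refitting an engine compiled
    from `old_state_dict` with the weights in `new_state_dict`.

    Returns an empty list when the two state_dicts have identical
    parameter names and shapes (i.e. the architecture is unchanged).
    """
    problems = []
    old_keys, new_keys = set(old_state_dict), set(new_state_dict)
    for key in sorted(old_keys - new_keys):
        problems.append(f"missing parameter: {key}")
    for key in sorted(new_keys - old_keys):
        problems.append(f"unexpected parameter: {key}")
    for key in sorted(old_keys & new_keys):
        old_shape = getattr(old_state_dict[key], "shape", None)
        new_shape = getattr(new_state_dict[key], "shape", None)
        if old_shape != new_shape:
            problems.append(f"shape mismatch for {key}: {old_shape} vs {new_shape}")
    return problems
```

A real implementation would operate on `torch.Tensor` values, but only the `.shape` attribute is consulted, so the check is cheap and framework-agnostic.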