convert.converter.convert_torch2trt
convert_with_torch2trt
@register_conversion("torch2trt")
def convert_with_torch2trt(model: nn.Module,
                           input_args,
                           export_path: str,
                           int8: bool = False,
                           fp16: bool = False,
                           use_dla: bool = False,
                           **kwargs)
Convert a PyTorch model to a TensorRT engine using torch2trt.
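The `@register_conversion("torch2trt")` decorator suggests that converters are registered and later looked up by backend name. The source does not show `register_conversion` itself, so the following is only a minimal sketch of how such a registry could work, assuming a plain module-level dict. `CONVERTERS`, `run_conversion`, `fake_torch2trt`, and `recorded` are hypothetical names, and the stand-in converter records its arguments instead of calling torch2trt.

```python
from typing import Any, Callable, Dict, List

# Hypothetical registry: backend name -> converter function.
CONVERTERS: Dict[str, Callable[..., None]] = {}


def register_conversion(name: str) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Return a decorator that stores a converter under `name` and returns it unchanged."""
    def decorator(func: Callable[..., None]) -> Callable[..., None]:
        if name in CONVERTERS:
            raise ValueError(f"a converter named {name!r} is already registered")
        CONVERTERS[name] = func
        return func
    return decorator


def run_conversion(backend: str, *args: Any, **kwargs: Any) -> None:
    """Dispatch to the converter registered under `backend` (hypothetical helper)."""
    if backend not in CONVERTERS:
        raise KeyError(f"no converter registered for {backend!r}")
    CONVERTERS[backend](*args, **kwargs)


# Stand-in with the same parameters as convert_with_torch2trt; it records its
# arguments instead of building a TensorRT engine.
recorded: List[Dict[str, Any]] = []


@register_conversion("torch2trt")
def fake_torch2trt(model, input_args, export_path: str, int8: bool = False,
                   fp16: bool = False, use_dla: bool = False, **kwargs: Any) -> None:
    recorded.append({"export_path": export_path, "int8": int8, "fp16": fp16,
                     "use_dla": use_dla, "extra": kwargs})


# Usage: the backend is chosen by name at call time.
run_conversion("torch2trt", None, None, "model.engine", fp16=True)
```

Because the decorator returns the function unchanged, the decorated converter can still be called directly; the registry only adds a name-based lookup on top.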
Arguments:
- model (nn.Module): The PyTorch model to be converted.
- input_args: Input arguments used for the conversion.
- export_path (str): Path where the TensorRT engine is saved.
- int8 (bool, optional): Whether to use INT8 precision. Defaults to False.
- fp16 (bool, optional): Whether to use FP16 precision. Defaults to False.
- use_dla (bool, optional): Whether to run the engine on an NVIDIA Deep Learning Accelerator (DLA). Defaults to False.
- **kwargs: Additional keyword arguments for the torch2trt conversion.
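How the boolean flags reach torch2trt is not documented here. One plausible mapping, offered only as an assumption, is onto torch2trt's `fp16_mode` and `int8_mode` keyword arguments, with DLA selected through a `default_device_type` argument (present in recent torch2trt versions; check the installed one). The sketch below builds such a keyword dict without importing TensorRT: `build_torch2trt_kwargs` and its `dla_device_type` parameter are hypothetical, and in a real call the caller would pass `tensorrt.DeviceType.DLA`.

```python
from typing import Any, Dict, Optional


def build_torch2trt_kwargs(int8: bool = False,
                           fp16: bool = False,
                           use_dla: bool = False,
                           dla_device_type: Optional[Any] = None,
                           **kwargs: Any) -> Dict[str, Any]:
    """Translate the converter's flags into torch2trt-style keyword arguments.

    Hypothetical helper: the key names are assumed to match the installed
    torch2trt version.
    """
    options: Dict[str, Any] = {"fp16_mode": fp16, "int8_mode": int8}
    if use_dla:
        if dla_device_type is None:
            # In a real call this would be tensorrt.DeviceType.DLA.
            raise ValueError("use_dla=True requires a DLA device type")
        options["default_device_type"] = dla_device_type
    # Extra **kwargs are forwarded unchanged and override the derived options.
    options.update(kwargs)
    return options
```

The resulting dict would then be unpacked into the conversion call, e.g. `torch2trt(model, inputs, **options)`. Letting explicit `**kwargs` win over the derived values is a design choice: it gives callers an escape hatch when the flag mapping does not match their torch2trt version.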
Returns:
None. The TensorRT engine is written to export_path.
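The docstring states only that the engine is saved to `export_path`, not in which form. The torch2trt README saves a converted `TRTModule` with `torch.save(model_trt.state_dict(), path)`; another common form is the raw bytes produced by TensorRT's engine `serialize()` method. Assuming the raw-bytes form, a minimal sketch of writing and reading such a file might look like this. `save_engine` and `load_engine_bytes` are hypothetical helpers, and placeholder bytes stand in for a real engine.

```python
import tempfile
from pathlib import Path
from typing import Union

# Any buffer-protocol object can be written; TensorRT's IHostMemory (returned by
# engine.serialize()) is assumed to qualify.
EngineBuffer = Union[bytes, bytearray, memoryview]


def save_engine(serialized_engine: EngineBuffer, export_path: Union[str, Path]) -> Path:
    """Write serialized engine bytes to export_path, creating parent directories."""
    path = Path(export_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(serialized_engine)
    return path


def load_engine_bytes(export_path: Union[str, Path]) -> bytes:
    """Read the bytes back, e.g. for tensorrt.Runtime(logger).deserialize_cuda_engine()."""
    return Path(export_path).read_bytes()


# Usage with placeholder bytes instead of a real engine.
with tempfile.TemporaryDirectory() as tmp:
    written = save_engine(b"placeholder-engine", Path(tmp) / "engines" / "model.engine")
    round_trip = load_engine_bytes(written)
```

A serialized engine is generally tied to the TensorRT version and GPU (or DLA) it was built for, so the file is best treated as a deployment artifact for that target rather than a portable model format.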