inference.inference
torch_dtype_from_trt
def torch_dtype_from_trt(dtype) -> torch.dtype
Convert TensorRT data types to PyTorch data types.
Arguments:
dtype
trt.DataType - TensorRT data type to convert
Returns:
torch.dtype
- Corresponding PyTorch data type
Raises:
TypeError
- If dtype is not supported by PyTorch
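A minimal sketch of what this conversion can look like; the mapping table below is an assumption and may not match the module's actual supported set.
```python
import tensorrt as trt
import torch

# Hypothetical mapping table; the module's actual supported set may differ.
_TRT_TO_TORCH_DTYPE = {
    trt.DataType.BOOL: torch.bool,
    trt.DataType.INT8: torch.int8,
    trt.DataType.INT32: torch.int32,
    trt.DataType.HALF: torch.float16,
    trt.DataType.FLOAT: torch.float32,
}

def torch_dtype_from_trt(dtype) -> torch.dtype:
    # Raise TypeError for any TensorRT dtype without a PyTorch equivalent.
    try:
        return _TRT_TO_TORCH_DTYPE[dtype]
    except KeyError:
        raise TypeError(f"{dtype} is not supported by PyTorch")
```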
torch_device_from_trt
def torch_device_from_trt(device) -> torch.device
Convert TensorRT device locations to PyTorch device objects.
Arguments:
device
trt.TensorLocation - TensorRT device location to convert
Returns:
torch.device
- Corresponding PyTorch device (cuda or cpu)
Raises:
TypeError
- If device location is not supported
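A minimal sketch of the location mapping, assuming the standard two-member trt.TensorLocation enum:
```python
import tensorrt as trt
import torch

def torch_device_from_trt(device) -> torch.device:
    # trt.TensorLocation has exactly two members: DEVICE (GPU) and HOST (CPU).
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    if device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    raise TypeError(f"{device} is not supported by PyTorch")
```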
TRTInferenceEngine Objects
class TRTInferenceEngine(nn.Module)
TensorRT inference engine wrapper for PyTorch modules.
Handles dynamic input/output bindings and device-specific execution for TensorRT-optimized models.
Attributes:
context
- TensorRT execution context
input_names
- List of input tensor names
output_names
- List of output tensor names
output_bindings
- Tensor bindings for outputs
variadic_output_binding_idx
- Dictionary of variadic output bindings
variadic_input_binding_idx
- List of variadic input binding indices
class_name
- Optional custom class name for identification
tensorrt_version_tuple
- TensorRT version as a tuple
__init__
def __init__(engine_path: str, class_name: Union[str, None] = None)
Initialize TensorRT inference engine.
Arguments:
engine_path
str - Path to TensorRT engine file
class_name
Union[str, None], optional - Custom class name for identification
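A hypothetical usage example; the engine file name, class name, and input shape below are placeholders, not values from this project:
```python
import torch

# Placeholders: "model.engine" and the 1x3x224x224 input are illustrative only.
engine = TRTInferenceEngine("model.engine", class_name="MyDetector")
x = torch.randn(1, 3, 224, 224, device="cuda")
with torch.no_grad():
    y = engine(x)  # forward() runs the TensorRT engine
```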
load_engine
@staticmethod
def load_engine(engine_path: str)
Load TensorRT engine from file and extract binding information.
Arguments:
engine_path
str - Path to TensorRT engine file
Returns:
tuple
- (execution context, input names list, output names list)
Raises:
AssertionError
- If engine file fails to deserialize
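A sketch of the deserialization path this method describes, using the TensorRT 10 name-based I/O enumeration; the logger severity and name-discovery loop are assumptions:
```python
import tensorrt as trt

def load_engine(engine_path: str):
    # Assumed logger severity; the module may configure this differently.
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    with open(engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    assert engine is not None, f"Failed to deserialize {engine_path}"
    context = engine.create_execution_context()
    # TensorRT 10 enumerates I/O tensors by name and reports their mode.
    input_names, output_names = [], []
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_names.append(name)
        else:
            output_names.append(name)
    return context, input_names, output_names
```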
forward_deprecate
def forward_deprecate(*args)
Execute inference with TensorRT engine using execute_async_v2.
This method is for compatibility with TensorRT versions before 10.0.
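A sketch of the pre-10.0 binding-index path, assuming the wrapper keeps the deserialized engine as self.engine and that positional arguments match input binding order:
```python
import torch

def forward_deprecate(self, *args):
    # Pre-10.0 API: a flat list of device pointers ordered by binding index.
    bindings = [None] * self.engine.num_bindings
    for name, x in zip(self.input_names, args):
        bindings[self.engine.get_binding_index(name)] = x.data_ptr()
    outputs = []
    for name in self.output_names:
        idx = self.engine.get_binding_index(name)
        shape = tuple(self.context.get_binding_shape(idx))
        dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
        out = torch.empty(shape, dtype=dtype, device="cuda")
        bindings[idx] = out.data_ptr()
        outputs.append(out)
    self.context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
    return outputs[0] if len(outputs) == 1 else outputs
```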
forward
def forward(*args, **kwargs)
Execute inference with TensorRT engine.
Handles variadic input/output bindings and executes the engine
asynchronously. Uses execute_async_v3 for TensorRT 10.0+ and falls
back to execute_async_v2 for older versions.
Returns:
torch.Tensor or list
- Inference output(s) as PyTorch tensor(s)
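A sketch of the 10.0+ name-addressed path; as above, self.engine and the argument-to-input pairing are assumptions:
```python
import torch

def forward(self, *args):
    # TensorRT 10+ API: tensors are addressed by name, not binding index.
    for name, x in zip(self.input_names, args):
        self.context.set_input_shape(name, tuple(x.shape))
        self.context.set_tensor_address(name, x.data_ptr())
    outputs = []
    for name in self.output_names:
        shape = tuple(self.context.get_tensor_shape(name))
        dtype = torch_dtype_from_trt(self.engine.get_tensor_dtype(name))
        out = torch.empty(shape, dtype=dtype, device="cuda")
        self.context.set_tensor_address(name, out.data_ptr())
        outputs.append(out)
    self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
    return outputs[0] if len(outputs) == 1 else outputs
```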
load_runtime_modules
def load_runtime_modules(model: nn.Module, rt_config: dict, engine_path: str)
Recursively load RunTime modules into PyTorch model.
Supports hierarchical model structures by recursively applying RunTime module loading to submodules.
Arguments:
model
nn.Module - Base PyTorch model to modify
rt_config
dict - RunTime configuration
engine_path
str - Base path for engine files
Returns:
nn.Module
- Modified model with RunTime modules loaded
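A minimal recursive sketch; the rt_config schema (nested dicts keyed by submodule name, with string leaves giving the RunTime mode) and the path-joining convention are assumptions, not the project's actual format:
```python
import os
import torch.nn as nn

def load_runtime_modules(model: nn.Module, rt_config: dict, engine_path: str):
    for name, entry in rt_config.items():
        child = getattr(model, name)
        sub_path = os.path.join(engine_path, name)  # assumed path layout
        if isinstance(entry, dict):
            # Nested config: recurse into the submodule.
            load_runtime_modules(child, entry, sub_path)
        else:
            # Leaf config: replace the submodule with a RunTime engine.
            setattr(model, name, load_runtime_module(child, entry, sub_path))
    return model
```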
load_runtime_module
def load_runtime_module(model: nn.Module, rt_mode: str, engine_path: str)
Load RunTime engine into PyTorch module.
Arguments:
model
nn.Module - Target PyTorch module
rt_mode
str - RunTime mode ("onnx" or "torch2trt")
engine_path
str - Path to engine file
Returns:
nn.Module
- Modified module with RunTime engine loaded
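A sketch of the mode dispatch; reusing TRTInferenceEngine for the "onnx" branch and torch2trt's TRTModule for the "torch2trt" branch are assumptions:
```python
import torch
import torch.nn as nn

def load_runtime_module(model: nn.Module, rt_mode: str, engine_path: str):
    if rt_mode == "onnx":
        # Assumed: ONNX-exported engines are wrapped by TRTInferenceEngine.
        return TRTInferenceEngine(engine_path)
    if rt_mode == "torch2trt":
        # torch2trt serializes engines as state dicts loadable by TRTModule.
        from torch2trt import TRTModule
        module = TRTModule()
        module.load_state_dict(torch.load(engine_path))
        return module
    raise ValueError(f"Unsupported RunTime mode: {rt_mode}")
```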