
inference.inference

torch_dtype_from_trt

def torch_dtype_from_trt(dtype) -> torch.dtype

Convert TensorRT data types to PyTorch data types.

Arguments:

  • dtype trt.DataType - TensorRT data type to convert

Returns:

  • torch.dtype - Corresponding PyTorch data type

Raises:

  • TypeError - If dtype is not supported by PyTorch
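
A minimal sketch of the kind of mapping this helper performs; the library's actual dtype table may differ:

import tensorrt as trt
import torch

def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype:
    # Illustrative mapping only; the supported set may differ.
    mapping = {
        trt.DataType.FLOAT: torch.float32,
        trt.DataType.HALF: torch.float16,
        trt.DataType.INT8: torch.int8,
        trt.DataType.INT32: torch.int32,
        trt.DataType.BOOL: torch.bool,
    }
    if dtype not in mapping:
        raise TypeError(f"{dtype} is not supported by PyTorch")
    return mapping[dtype]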

torch_device_from_trt

def torch_device_from_trt(device) -> torch.device

Convert TensorRT device locations to PyTorch device objects.

Arguments:

  • device trt.TensorLocation - TensorRT device location to convert

Returns:

  • torch.device - Corresponding PyTorch device (cuda or cpu)

Raises:

  • TypeError - If device location is not supported
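
A sketch of the equivalent check, assuming the standard trt.TensorLocation enum:

import tensorrt as trt
import torch

def torch_device_from_trt(device: trt.TensorLocation) -> torch.device:
    # DEVICE tensors live on the GPU, HOST tensors on the CPU.
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    if device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    raise TypeError(f"{device} is not supported by PyTorch")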

TRTInferenceEngine Objects

class TRTInferenceEngine(nn.Module)

TensorRT inference engine wrapper for PyTorch modules.

Handles dynamic input/output bindings and device-specific execution for TensorRT-optimized models.

Attributes:

  • context - TensorRT execution context
  • input_names - List of input tensor names
  • output_names - List of output tensor names
  • output_bindings - Tensor bindings for outputs
  • variadic_output_binding_idx - Dictionary of variadic output bindings
  • variadic_input_binding_idx - List of variadic input binding indices
  • class_name - Optional custom class name for identification
  • tensorrt_version_tuple - TensorRT version as a tuple

__init__

def __init__(engine_path: str, class_name: Union[str, None] = None)

Initialize TensorRT inference engine.

Arguments:

  • engine_path str - Path to TensorRT engine file
  • class_name Union[str, None], optional - Custom class name for identification
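
A usage sketch; the engine path, class name, and input shape below are placeholders:

import torch

# Hypothetical file and shape for illustration.
engine = TRTInferenceEngine("model.engine", class_name="Backbone")
x = torch.randn(1, 3, 224, 224, device="cuda")
with torch.no_grad():
    y = engine(x)  # nn.Module.__call__ dispatches to forward()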

load_engine

@staticmethod
def load_engine(engine_path: str)

Load TensorRT engine from file and extract binding information.

Arguments:

  • engine_path str - Path to TensorRT engine file

Returns:

  • tuple - (execution context, input names list, output names list)

Raises:

  • AssertionError - If engine file fails to deserialize
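
A sketch of what deserialization and binding extraction typically look like, assuming TensorRT's name-based I/O API; the library's actual implementation may differ:

import tensorrt as trt

def load_engine(engine_path: str):
    logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, "rb") as f:
        engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
    assert engine is not None, f"Failed to deserialize {engine_path}"
    context = engine.create_execution_context()
    # Partition I/O tensor names by mode.
    names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    input_names = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
    output_names = [n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
    return context, input_names, output_names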

forward_deprecate

def forward_deprecate(*args)

Execute inference with TensorRT engine using execute_async_v2.

This method is for compatibility with TensorRT versions before 10.0.
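
A sketch of the pre-10.0 flat-bindings path, using the binding-index API that TensorRT 10 removed; it reuses the conversion helpers documented above, and variadic-binding handling is omitted:

import torch

def forward_deprecate(self, *args):
    engine = self.context.engine
    inputs = [t.contiguous() for t in args]  # keep references alive
    bindings = [0] * engine.num_bindings
    for name, tensor in zip(self.input_names, inputs):
        bindings[engine.get_binding_index(name)] = tensor.data_ptr()
    outputs = []
    for name in self.output_names:
        idx = engine.get_binding_index(name)
        out = torch.empty(
            tuple(self.context.get_binding_shape(idx)),
            dtype=torch_dtype_from_trt(engine.get_binding_dtype(idx)),
            device=torch_device_from_trt(engine.get_binding_location(idx)),
        )
        bindings[idx] = out.data_ptr()
        outputs.append(out)
    self.context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
    return outputs[0] if len(outputs) == 1 else outputs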

forward

def forward(*args, **kwargs)

Execute inference with TensorRT engine.

Handles variadic input/output bindings and executes the engine asynchronously. Uses execute_async_v3 for TensorRT 10.0+ and falls back to execute_async_v2 for older versions.

Returns:

  • torch.Tensor or list - Inference output(s) as PyTorch tensor(s)
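
A sketch of the 10.0+ path, assuming the name-based set_tensor_address/execute_async_v3 API; kwargs and variadic-binding handling are omitted for brevity:

import torch

def forward(self, *args):
    engine = self.context.engine
    inputs = [t.contiguous() for t in args]  # keep references alive
    for name, tensor in zip(self.input_names, inputs):
        self.context.set_input_shape(name, tuple(tensor.shape))
        self.context.set_tensor_address(name, tensor.data_ptr())
    outputs = []
    for name in self.output_names:
        out = torch.empty(
            tuple(self.context.get_tensor_shape(name)),
            dtype=torch_dtype_from_trt(engine.get_tensor_dtype(name)),
            device=torch_device_from_trt(engine.get_tensor_location(name)),
        )
        self.context.set_tensor_address(name, out.data_ptr())
        outputs.append(out)
    self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
    return outputs[0] if len(outputs) == 1 else outputs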

load_runtime_modules

def load_runtime_modules(model: nn.Module, rt_config: dict, engine_path: str)

Recursively load RunTime modules into PyTorch model.

Supports hierarchical model structures by recursively applying RunTime module loading to submodules.

Arguments:

  • model nn.Module - Base PyTorch model to modify
  • rt_config dict - RunTime configuration
  • engine_path str - Base path for engine files

Returns:

  • nn.Module - Modified model with RunTime modules loaded
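
A sketch of the recursion, assuming rt_config maps submodule names either to a nested dict (recurse into that child) or to an rt_mode string (replace that child with an engine); the config schema and path layout are assumptions:

import os
import torch.nn as nn

def load_runtime_modules(model: nn.Module, rt_config: dict, engine_path: str) -> nn.Module:
    for name, cfg in rt_config.items():
        child = getattr(model, name)
        child_path = os.path.join(engine_path, name)
        if isinstance(cfg, dict):
            # Nested config: descend into the submodule.
            setattr(model, name, load_runtime_modules(child, cfg, child_path))
        else:
            # Leaf config: swap the submodule for its engine.
            setattr(model, name, load_runtime_module(child, cfg, child_path + ".engine"))
    return model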

load_runtime_module

def load_runtime_module(model: nn.Module, rt_mode: str, engine_path: str)

Load RunTime engine into PyTorch module.

Arguments:

  • model nn.Module - Target PyTorch module
  • rt_mode str - RunTime mode ("onnx" or "torch2trt")
  • engine_path str - Path to engine file

Returns:

  • nn.Module - Modified module with RunTime engine loaded
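
A usage sketch under the same assumed config schema; the module names and paths are hypothetical:

# Replace model.backbone with a torch2trt engine and model.head.decoder
# with an ONNX-exported engine (hypothetical layout).
rt_config = {
    "backbone": "torch2trt",
    "head": {"decoder": "onnx"},
}
model = load_runtime_modules(model, rt_config, engine_path="engines/")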