intelligence.acuirt.inference.inference
torch_dtype_from_trt
def torch_dtype_from_trt(dtype) -> torch.dtype
Convert TensorRT data types to PyTorch data types.
Arguments:
dtype (trt.DataType) - TensorRT data type to convert
Returns:
torch.dtype - Corresponding PyTorch data type
Raises:
TypeError - If dtype is not supported by PyTorch
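A conversion helper of this shape is typically a small lookup table. A minimal sketch, assuming the common TensorRT dtype aliases (the module's actual mapping may cover more or fewer types):

```python
import tensorrt as trt
import torch

def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype:
    # Lookup table from TensorRT dtypes to their PyTorch equivalents.
    mapping = {
        trt.bool: torch.bool,
        trt.int8: torch.int8,
        trt.int32: torch.int32,
        trt.float16: torch.float16,
        trt.float32: torch.float32,
    }
    if dtype not in mapping:
        raise TypeError(f"{dtype} is not supported by PyTorch")
    return mapping[dtype]
```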
torch_device_from_trt
def torch_device_from_trt(device) -> torch.device
Convert TensorRT device locations to PyTorch device objects.
Arguments:
device (trt.TensorLocation) - TensorRT device location to convert
Returns:
torch.device - Corresponding PyTorch device (cuda or cpu)
Raises:
TypeError - If device location is not supported
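The device helper only has two cases to cover, since trt.TensorLocation distinguishes device (GPU) memory from host memory. A minimal sketch:

```python
import tensorrt as trt
import torch

def torch_device_from_trt(device: trt.TensorLocation) -> torch.device:
    # TensorRT tensors live either in GPU memory or in host memory.
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    if device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    raise TypeError(f"{device} is not supported by PyTorch")
```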
TRTInferenceEngine Objects
class TRTInferenceEngine(nn.Module)
TensorRT inference engine wrapper for PyTorch modules.
Handles dynamic input/output bindings and device-specific execution for TensorRT-optimized models.
Attributes:
context - TensorRT execution context
input_names - List of input tensor names
output_names - List of output tensor names
output_bindings - Tensor bindings for outputs
variadic_output_binding_idx - Dictionary of variadic output bindings
variadic_input_binding_idx - List of variadic input binding indices
class_name - Optional custom class name for identification
tensorrt_version_tuple - TensorRT version as a tuple
__init__
def __init__(engine_path: str,
class_name: Union[str, None] = None,
*,
logger: AcuiRTDefaultLogger)
Initialize TensorRT inference engine.
Arguments:
engine_path (str) - Path to TensorRT engine file
class_name (Union[str, None], optional) - Custom class name for identification
logger (AcuiRTDefaultLogger) - Logger object for tracking conversion progress
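Because the wrapper is an nn.Module, it can be called like any other PyTorch module once constructed. Illustrative usage, with a hypothetical engine path and input shape, and assuming AcuiRTDefaultLogger can be constructed without arguments:

```python
import torch

# "models/detector.engine" and the input shape are placeholders.
engine = TRTInferenceEngine(
    "models/detector.engine",
    class_name="Detector",
    logger=AcuiRTDefaultLogger(),
)

x = torch.randn(1, 3, 224, 224, device="cuda")
with torch.no_grad():
    output = engine(x)  # dispatches to forward()
```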
load_engine
@staticmethod
def load_engine(engine_path: str, logger: AcuiRTDefaultLogger)
Load TensorRT engine from file and extract binding information.
Arguments:
engine_path (str) - Path to TensorRT engine file
logger (AcuiRTDefaultLogger) - Logger object for tracking conversion progress
Returns:
tuple - (execution context, input names list, output names list)
Raises:
AssertionError - If engine file fails to deserialize
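A plausible sketch of this load path using TensorRT's name-based I/O API (TensorRT 8.5+); the actual implementation also routes messages through the AcuiRT logger and handles older versions:

```python
import tensorrt as trt

def load_engine(engine_path: str, logger) -> tuple:
    # Deserialize the engine; real code should keep the runtime alive
    # for as long as the engine is in use.
    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    with open(engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    assert engine is not None, f"Failed to deserialize {engine_path}"
    context = engine.create_execution_context()

    # Partition I/O tensor names into inputs and outputs.
    input_names, output_names = [], []
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_names.append(name)
        else:
            output_names.append(name)
    return context, input_names, output_names
```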
forward_deprecate
def forward_deprecate(*args)
Execute inference with TensorRT engine using execute_async_v2.
This method is for compatibility with TensorRT versions before 10.0.
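In pre-10.0 TensorRT, execution goes through a flat list of device pointers ordered by binding index. A simplified sketch of that path, assuming the wrapper keeps a self.engine handle (not among the documented attributes) and ignoring the variadic-binding bookkeeping:

```python
import torch

def forward_deprecate(self, *args):
    # One slot per binding, indexed by the engine's binding order.
    bindings = [None] * self.engine.num_bindings
    for name, tensor in zip(self.input_names, args):
        idx = self.engine.get_binding_index(name)
        bindings[idx] = tensor.contiguous().data_ptr()

    # Allocate output tensors and register their pointers.
    outputs = []
    for name in self.output_names:
        idx = self.engine.get_binding_index(name)
        shape = tuple(self.context.get_binding_shape(idx))
        dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
        out = torch.empty(shape, dtype=dtype, device="cuda")
        bindings[idx] = out.data_ptr()
        outputs.append(out)

    self.context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
    return outputs[0] if len(outputs) == 1 else outputs
```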
forward
def forward(*args, **kwargs)
Execute inference with TensorRT engine.
Handles variadic input/output bindings and executes the engine
asynchronously. Uses execute_async_v3 for TensorRT 10.0+ and falls
back to execute_async_v2 for older versions.
Returns:
torch.Tensor or list - Inference output(s) as PyTorch tensor(s)
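In TensorRT 10.0+, tensors are bound by name via set_tensor_address and the engine runs with execute_async_v3. A simplified sketch under the same assumptions as above (a self.engine handle, variadic bookkeeping omitted):

```python
import torch

def forward(self, *args):
    # Bind inputs by tensor name rather than binding index.
    for name, tensor in zip(self.input_names, args):
        self.context.set_tensor_address(name, tensor.contiguous().data_ptr())

    # Allocate outputs and bind their addresses the same way.
    outputs = []
    for name in self.output_names:
        shape = tuple(self.context.get_tensor_shape(name))
        dtype = torch_dtype_from_trt(self.engine.get_tensor_dtype(name))
        out = torch.empty(shape, dtype=dtype, device="cuda")
        self.context.set_tensor_address(name, out.data_ptr())
        outputs.append(out)

    self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
    return outputs[0] if len(outputs) == 1 else outputs
```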
load_runtime_modules
def load_runtime_modules(model: nn.Module,
rt_config: dict,
engine_path: str,
*,
logger: Optional[AcuiRTDefaultLogger] = None)
Recursively load RunTime modules into a PyTorch model.
Supports hierarchical model structures by recursively applying RunTime module loading to submodules.
Arguments:
model (nn.Module) - Base PyTorch model to modify
rt_config (dict) - RunTime configuration
engine_path (str) - Base path for engine files
logger (Optional[AcuiRTDefaultLogger]) - Logger object for tracking conversion progress. Defaults to None.
Returns:
nn.Module - Modified model with RunTime modules loaded
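Illustrative usage; the schema of rt_config is not documented here, so the dictionary below is purely hypothetical:

```python
import torch.nn as nn

# A tiny placeholder model standing in for the real network.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

# Hypothetical config; the real schema is defined by the AcuiRT tooling.
rt_config = {"0": {"engine": "conv.engine"}}

model = load_runtime_modules(model, rt_config, "engines/")
model.eval()
```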