diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 5ba84b09b0..2f2814c5f5 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -36,6 +36,8 @@ def constant_fold(
     # The constants are created on CPU to save GPU memory for TensorRT compilation.
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
+        if node.target == torch.ops.aten.embedding.default:
+            continue
         replace_node_with_constant(
             gm, node, torch.nn.Parameter(constant, requires_grad=False)
         )
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 1d619b6ce3..777bb32a2d 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -2,7 +2,6 @@
 
 import logging
 from contextlib import nullcontext
-from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import tensorrt as trt
@@ -539,7 +538,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
                     with tempfile.TemporaryDirectory() as tmpdir:
                         self.cudagraph.debug_dump(
-                            f"{tempdir}/{self.name}_cudagraph.dot"
+                            f"{tmpdir}/{self.name}_cudagraph.dot"
                         )
 
             self.cudagraph.replay()  # type: ignore
diff --git a/tools/perf/README.md b/tools/perf/README.md
index 4d4579efb4..36c85386f7 100644
--- a/tools/perf/README.md
+++ b/tools/perf/README.md
@@ -9,8 +9,6 @@ This is a comprehensive Python benchmark suite to run perf runs using different
 
 5. TensorRT
 
-Note: Please note that for ONNX models, user can convert the ONNX model to TensorRT serialized engine and then use this package.
-
 ## Prerequisite
 
 Benchmark scripts depends on following Python packages in addition to requirements.txt packages
@@ -47,6 +45,7 @@ Here are the list of `CompileSpec` options that can be provided directly to comp
 * `--backends` : Comma separated string of backends. Eg: torch, torch_compile, dynamo, tensorrt
 * `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module).
 * `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend)
+* `--onnx` : ONNX model file which bypasses the step of exporting ONNX from `model_torch`. If provided, the ONNX model is directly converted to a TensorRT engine
 * `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
 * `--batch_size` : Batch size
 * `--precision` : Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16
@@ -54,6 +53,7 @@ Here are the list of `CompileSpec` options that can be provided directly to comp
 * `--truncate` : Truncate long and double weights in the network in Torch-TensorRT
 * `--is_trt_engine` : Boolean flag to be enabled if the model file provided is a TensorRT engine.
 * `--report` : Path of the output file where performance summary is written.
+* `--optimization_level` : Builder optimization level for TensorRT (from 1 to 5, where 5 is the highest).
 
 Eg:
 
diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py
index ca37316ea8..f7bc94d27d 100644
--- a/tools/perf/perf_run.py
+++ b/tools/perf/perf_run.py
@@ -174,8 +174,7 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size):
     compile_settings = {
         "inputs": input_tensors,
         "enabled_precisions": {precision_to_dtype(precision)},
-        "truncate_long_and_double": params.get("truncate", False),
-        "use_python_runtime": params.get("use_python_runtime", False),
+        "truncate_double": params.get("truncate", False),
     }
 
     if precision == "int8":
@@ -274,8 +273,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         ir="dynamo",
         enabled_precisions={precision_to_dtype(precision)},
         min_block_size=params.get("min_block_size", 1),
-        debug=False,
-        truncate_long_and_double=params.get("truncate", False),
+        truncate_double=params.get("truncate", False),
         immutable_weights=params.get("immutable_weights", True),
         strip_engine_weights=params.get("strip_engine_weights", False),
         refit_identical_engine_weights=params.get(
@@ -284,6 +282,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         cache_built_engines=params.get("cache_built_engines", False),
         reuse_cached_engines=params.get("reuse_cached_engines", False),
         use_python_runtime=params.get("use_python_runtime", False),
+        optimization_level=params.get("optimization_level", 3),
     )
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
@@ -437,25 +436,33 @@ def run_tensorrt(
     precision,
     batch_size=1,
 ):
-    # Export an ONNX model and convert to TRT
-    torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
     logger = trt.Logger(trt.Logger.WARNING)
-    builder = trt.Builder(logger)
-    network = builder.create_network(
-        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-    )
-    parser = trt.OnnxParser(network, logger)
-    success = parser.parse_from_file("./tmp.onnx")
-    if not success:
-        raise ValueError("ONNX conversion failed")
-
-    config = builder.create_builder_config()
-    if precision == "fp16":
-        config.set_flag(trt.BuilderFlag.FP16)
-    start_compile = timeit.default_timer()
-    serialized_engine = builder.build_serialized_network(network, config)
-    end_compile = timeit.default_timer()
-    compile_time_s = end_compile - start_compile
+    compile_time_s = 0
+    if params["is_trt_engine"]:
+        serialized_engine = model
+    else:
+        if params["onnx"]:
+            onnx_path = params["onnx"]
+        else:
+            onnx_path = "./onnx-trt.onnx"
+            torch.onnx.export(model, tuple(input_tensors), onnx_path, dynamo=True)
+        builder = trt.Builder(logger)
+        network = builder.create_network(
+            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+        )
+        parser = trt.OnnxParser(network, logger)
+        success = parser.parse_from_file(onnx_path)
+        if not success:
+            raise ValueError("ONNX conversion failed")
+
+        config = builder.create_builder_config()
+        if precision == "fp16":
+            config.set_flag(trt.BuilderFlag.FP16)
+        config.builder_optimization_level = params.get("optimization_level", 3)
+        start_compile = timeit.default_timer()
+        serialized_engine = builder.build_serialized_network(network, config)
+        end_compile = timeit.default_timer()
+        compile_time_s = end_compile - start_compile
     # Deserialize the TensorRT engine
     with trt.Runtime(logger) as runtime:
         engine = runtime.deserialize_cuda_engine(serialized_engine)
@@ -463,35 +470,72 @@ def run_tensorrt(
     print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)
 
-    # Compiling the bindings
-    bindings = engine.num_bindings * [None]
-    k = 0
-    for idx, _ in enumerate(bindings):
-        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
-        shape = tuple(engine.get_binding_shape(idx))
-        device = torch_device_from_trt(engine.get_location(idx))
-        if not engine.binding_is_input(idx):
-            # Output bindings
-            output = torch.empty(size=shape, dtype=dtype, device=device)
-            bindings[idx] = output.data_ptr()
-        else:
-            # Input bindings
-            bindings[idx] = input_tensors[k].data_ptr()
-            k += 1
+    start_time = timeit.default_timer()
+    # Get I/O tensor information using the TensorRT 10 API
+    input_names = []
+    output_names = []
+    output_dtypes = []
+    output_shapes = []
+
+    for i in range(engine.num_io_tensors):
+        tensor_name = engine.get_tensor_name(i)
+        tensor_mode = engine.get_tensor_mode(tensor_name)
+        tensor_dtype = engine.get_tensor_dtype(tensor_name)
+        tensor_shape = engine.get_tensor_shape(tensor_name)
+
+        if tensor_mode == trt.TensorIOMode.INPUT:
+            input_names.append(tensor_name)
+        else:  # trt.TensorIOMode.OUTPUT
+            output_names.append(tensor_name)
+            output_dtypes.append(torch_dtype_from_trt(tensor_dtype))
+            output_shapes.append(tuple(tensor_shape))
+
+    # Create output tensors
+    output_tensors = []
+    for i, (shape, dtype) in enumerate(zip(output_shapes, output_dtypes)):
+        output = torch.empty(size=shape, dtype=dtype, device="cuda")
+        output_tensors.append(output)
 
     timings = []
     with engine.create_execution_context() as context:
+        # Set input tensor addresses
+        for i, (input_name, input_tensor) in enumerate(zip(input_names, input_tensors)):
+            context.set_tensor_address(input_name, input_tensor.data_ptr())
+
+        # Set output tensor addresses
+        for output_name, output_tensor in zip(output_names, output_tensors):
+            context.set_tensor_address(output_name, output_tensor.data_ptr())
+
+        # Create a dedicated stream for TensorRT execution
+        dedicated_stream = torch.cuda.Stream()
+        current_stream = torch.cuda.current_stream()
+
+        setup_time = timeit.default_timer()
+
+        # Warm up
         for i in range(WARMUP_ITER):
-            context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
+            # Wait for current stream to finish
+            dedicated_stream.wait_stream(current_stream)
+            context.execute_async_v3(dedicated_stream.cuda_stream)
+            # Wait for TensorRT stream to finish
+            current_stream.wait_stream(dedicated_stream)
         torch.cuda.synchronize()
 
+        infer_start_time = timeit.default_timer()
+        # Performance measurement
         for i in range(iters):
-            start_time = timeit.default_timer()
-            context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
+            # Wait for current stream to finish
+            dedicated_stream.wait_stream(current_stream)
+            context.execute_async_v3(dedicated_stream.cuda_stream)
+            # Wait for TensorRT stream to finish
+            current_stream.wait_stream(dedicated_stream)
             torch.cuda.synchronize()
-            end_time = timeit.default_timer()
-            meas_time = end_time - start_time
-            timings.append(meas_time)
+
+        end_time = timeit.default_timer()
+
+        # Include setup time to compare apples to apples against the Torch-TRT dynamo path
+        infer_time = (end_time - infer_start_time + setup_time - start_time) / iters
+        timings.append(infer_time)
 
     recordStats("TensorRT", timings, precision, batch_size, compile_time_s)
 
@@ -504,7 +548,6 @@
     params,
     precision,
     batch_size=1,
-    is_trt_engine=False,
     model_torch=None,
 ):
     for backend in backends:
@@ -523,7 +566,7 @@
                 print("int8 precision expects calibration cache file for inference")
                 return False
-        if (model is None) and (backend in ("tensorrt", "ts_trt", "all")):
+        if (model is None) and (backend in ("ts_trt", "all")):
             warnings.warn(
                 f"Requested backend {backend} without specifying a TorchScript Model, "
                 + "skipping this backend"
             )
@@ -547,11 +590,10 @@ def run(
                 batch_size,
             )
             run_tensorrt(
-                model,
+                model_torch,
                 input_tensors,
                 params,
                 precision,
-                is_trt_engine,
                 batch_size,
             )
             run_dynamo(model_torch, input_tensors, params, precision, batch_size)
@@ -604,6 +646,12 @@
         default="",
         help="Name of torch model file",
     )
+    arg_parser.add_argument(
+        "--onnx",
+        type=str,
+        default="",
+        help="ONNX model file which bypasses the step of exporting ONNX from the PyTorch model. If provided, the ONNX model is directly converted to a TensorRT engine",
+    )
     arg_parser.add_argument(
         "--inputs",
         type=str,
@@ -643,6 +691,12 @@
         action="store_true",
         help="Truncate long and double weights in the network in Torch-TensorRT",
    )
+    arg_parser.add_argument(
+        "--optimization_level",
+        type=int,
+        default=3,
+        help="Builder optimization level for TensorRT",
+    )
     arg_parser.add_argument(
         "--is_trt_engine",
         action="store_true",
@@ -702,8 +756,13 @@
 
     # Load TorchScript model, if provided
     if os.path.exists(model_name):
-        print("Loading user provided torchscript model: ", model_name)
-        model = torch.jit.load(model_name).cuda().eval()
+        if params["is_trt_engine"]:
+            with open(model_name, "rb") as f:
+                model = f.read()
+            print("Loading user provided trt engine: ", model_name)
+        else:
+            print("Loading user provided torchscript model: ", model_name)
+            model = torch.jit.load(model_name).cuda().eval()
 
     # Load PyTorch Model, if provided
     if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
@@ -719,7 +778,9 @@
         )
 
     backends = parse_backends(params["backends"])
-    if ("dynamo" in backends or "torch_compile" in backends) and (model_torch is None):
+    if any(
+        backend in ["dynamo", "torch_compile", "tensorrt"] for backend in backends
+    ) and (model_torch is None):
         raise ValueError(
             "No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument"
         )
@@ -746,7 +807,6 @@
             params,
             precision,
             batch_size,
-            is_trt_engine,
             model_torch=model_torch,
         )
 
diff --git a/tools/perf/requirements.txt b/tools/perf/requirements.txt
index fcfb0b3d53..efc11a05b5 100644
--- a/tools/perf/requirements.txt
+++ b/tools/perf/requirements.txt
@@ -4,6 +4,5 @@ pyyaml
 onnx
 pandas
 transformers
-diffusers==0.21.4
+diffusers
 timm==0.9.8
-
diff --git a/tools/perf/utils.py b/tools/perf/utils.py
index 5dae807892..0fd38e6447 100644
--- a/tools/perf/utils.py
+++ b/tools/perf/utils.py
@@ -176,6 +176,8 @@ def torch_dtype_from_trt(dtype):
         return torch.bool
     elif dtype == trt.int32:
         return torch.int32
+    elif dtype == trt.int64:
+        return torch.int64
     elif dtype == trt.float16:
         return torch.float16
     elif dtype == trt.float32: