From 48af4fd4b44e8c188da6439b5b9969a8250bb9bc Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 8 Jul 2025 12:07:43 -0700 Subject: [PATCH 1/2] fix torch compile gpt2 error --- core/runtime/execute_engine.cpp | 19 ++++++++++++++++++- .../runtime/_PythonTorchTensorRTModule.py | 2 +- .../dynamo/runtime/_TorchTensorRTModule.py | 6 ++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..11547a2312 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -121,7 +121,24 @@ void setup_input_tensors( // Shape tensor inputs are casted to int64 explicitly. // Refer to // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 - auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64); + at::Tensor cloned_input; + + // Check if it's a scalar tensor (0-dimensional) + if (inputs[i].dim() == 0 && inputs[i].numel() == 1) { + // It's a scalar tensor, create a proper tensor from the scalar value + int64_t scalar_value = inputs[i].item<int64_t>(); + LOG_DEBUG("Input " << i << " is a scalar tensor with value: " << scalar_value); + cloned_input = torch::tensor({scalar_value}, torch::kInt64); + LOG_DEBUG("cloned_input dim: " << cloned_input.dim() << " ; numel: " << cloned_input.numel()); + } else { + // It's a regular tensor + LOG_DEBUG( + "Input " << i << " is a regular tensor" + << " inputs[i]: " << inputs[i]); + cloned_input = inputs[i].clone(); + } + auto input_cpu = cloned_input.contiguous().cpu().to(torch::kInt64); + + std::vector<int64_t> inputs_cpu_vec( input_cpu.data_ptr<int64_t>(), input_cpu.data_ptr<int64_t>() + input_cpu.numel()); inputShapeTensorValues.emplace_back(inputs_cpu_vec); diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index fc76b20141..dda5929d4b 100644 --- 
a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -430,7 +430,7 @@ def create_output_allocator(self) -> None: def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: - shape_changed = self.validate_input_shapes(inputs) + shape_changed = self.validate_input_shapes(contiguous_inputs) ( need_cudagraphs_record, can_use_pre_allocated_outputs, diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 95f1581881..1ea787df41 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -325,6 +325,12 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: for i in inputs ] + for input_tensor in input_tensors: + if not isinstance(input_tensor, torch.Tensor): + raise ValueError( + f"lan added Unsupported input type: {type(input_tensor)}" + ) + outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine( list(input_tensors), self.engine ) From a2e51fff23cfd0d4c1dd1176f2d6e5bb631c1432 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 8 Jul 2025 12:09:29 -0700 Subject: [PATCH 2/2] test --- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 1ea787df41..95f1581881 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -325,12 +325,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: for i in inputs ] - for input_tensor in input_tensors: - if not isinstance(input_tensor, torch.Tensor): - raise ValueError( - f"lan added 
Unsupported input type: {type(input_tensor)}" - ) - outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine( list(input_tensors), self.engine )