diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8c0706539c..c4e18d9119 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -392,7 +392,9 @@ def index_dtype_validator( @dynamo_tensorrt_converter( - torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator + torch.ops.aten.index.Tensor, + capability_validator=index_dtype_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 8e21f945dc..fc4a70b1ff 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -168,7 +168,31 @@ def forward(self, input): dtype=torch.float32, ), ] - self.run_test_with_dynamic_shape(TestModule(), input_specs) + self.run_test_with_dynamic_shape( + TestModule(), input_specs, use_dynamo_tracer=True + ) + + +class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase): + def test_index_input_non_dynamic_index_dynamic(self): + class TestIndexWithRuntimeIndex(torch.nn.Module): + def forward(self, x): + mask = x > 0 + idx = torch.nonzero(mask, as_tuple=True) + return torch.ops.aten.index.Tensor(x, idx) + + input_specs = [ + Input( + min_shape=(2, 2), + opt_shape=(2, 2), + max_shape=(8, 8), + dtype=torch.float32, + ), + ] + # In this case the index args[1] gets itself converted to a List of TRTTensors with use_dynamo_tracer=True + self.run_test_with_dynamic_shape( + TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True + ) if __name__ == "__main__":