
Conversation

apbose (Collaborator) commented Oct 16, 2025

Fixes #3865

@meta-cla bot added the cla signed label (Oct 16, 2025)
@github-actions bot added labels: component: tests, component: conversion, component: converters, component: api [Python], component: dynamo (Oct 16, 2025)
@github-actions bot requested a review from narendasan (Oct 16, 2025)
@apbose force-pushed the abose/torchTRT_accelerate_bug_fix branch from bee0f1d to 88659a1 (Oct 16, 2025)
@apbose force-pushed the abose/torchTRT_accelerate_bug_fix branch from 6e386c1 to e6fc22b (Oct 17, 2025)
for i, each_input in enumerate(input):
    if isinstance(each_input, torch.Tensor) and each_input.numel() == 0:
        logger.warning(
            f"Warning: empty tensor in cat input {i}, replacing with zeros"
        )
Collaborator commented:
Can you make this warning much more specific? Print information like the current node and, if you can, where in the graph it comes from, because users will not understand what this message means. Also, where is the "replacing with zeros"?
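For instance, something along these lines (a hypothetical sketch as a drop-in for the quoted logger.warning; `name` and `target` refer to the converter's standard arguments, and the exact wording is just a suggestion):

    logger.warning(
        f"Found empty tensor (0 elements) at input {i} of cat node '{name}' "
        f"(target: {target}); this input is skipped during concatenation, "
        "since TensorRT does not accept zero-volume inputs"
    )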

Collaborator commented:

Also, if this is caught by the validator, should this be an error? Will conversion fail, or can we just ignore it?

apbose (Author) commented Oct 20, 2025:

Thanks for pointing out the error. I was initially replacing the empty tensor with zeros, but later changed it to a continue (skipping the input), since replacing with zeros is not required. I will update the warning message.

The difference between this and the validator is that if the empty tensor is a torch.Tensor, we can handle it in the converter.

Whereas if the empty tensor comes in as an ITensor input to the converter, TensorRT complains. (I tried implementing that case earlier by replacing the empty tensor with zeros, but it still leads to the error [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.) To illustrate the difference:

This will pass:

def test_cat_with_empty_tensor(self, _, dim):
    # Handle empty tensor in concat
    class Cat(nn.Module):
        def forward(self, x):
            y = torch.empty(0, 2, 3, device="cuda")
            return torch.ops.aten.cat.default((x, y), dim)

    inputs = [
        torch.randn(1, 2, 3, device="cuda"),
    ]
    self.run_test(Cat(), inputs)

This will fail:

def test_cat_with_empty_tensor(self, _, dim):
    # Handle empty tensor in concat
    class Cat(nn.Module):
        def forward(self, x, y):
            return torch.ops.aten.cat.default((x, y), dim)

    inputs = [
        torch.randn(1, 2, 3, device="cuda"),
        torch.empty(0, 2, 3, device="cuda"),
    ]
    self.run_test(Cat(), inputs)
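For reference, the converter-side handling now just skips statically-known empty inputs, roughly like this (a simplified sketch of the approach, not the exact diff; it assumes `input` and `logger` from the quoted hunk):

    filtered_inputs = []
    for i, each_input in enumerate(input):
        if isinstance(each_input, torch.Tensor) and each_input.numel() == 0:
            # Concatenating a zero-volume tensor is a no-op, so the input
            # can simply be dropped instead of being replaced with zeros.
            logger.warning(f"Skipping empty tensor at cat input {i}")
            continue
        filtered_inputs.append(each_input)
    input = filtered_inputs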

    return input_tensors, dim


def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
Collaborator commented:

I don't really understand this condition. So if we have a TRT ITensor that has a 0 in any dimension, then we should break the graph? I don't think any of these ITensors will be available at validation time, since validation is run prior to partitioning.

Collaborator commented:

Should we be checking for empty PyTorch tensors?

apbose (Author) commented Oct 20, 2025:

Yes, ideally. Right now the validation would be based on the ITensor shape; it should use the metadata instead.
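Something like the following, as a sketch (assuming the FakeTensor recorded during tracing is available under node.meta["val"] for each input node; this is not the implementation in the PR):

    def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
        # Reject cat nodes with a statically-known empty input, based on
        # tracing metadata rather than runtime ITensor shapes.
        for inp in node.args[0]:
            val = inp.meta.get("val") if isinstance(inp, Node) else None
            if val is not None and 0 in val.shape:
                return False
        return True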

apbose (Author) commented:

But then this won't distinguish between the ITensor and the torch.Tensor case.




Development

Successfully merging this pull request may close this issue:

🐛 [Bug] Torch distributed data parallel accelerate GPT2 example failing
