Description
I'm using a GNN as both actor and critic with VMAS, and I keep running into a RuntimeError during training. Notably, the error only occurs when the GNN is used as the critic; a GNN actor with an MLP critic trains just fine.
In particular, the error comes from this line in models/gnn.py:
graphs.edge_index = torch_geometric.nn.pool.radius_graph(
    graphs.pos, batch=graphs.batch, r=edge_radius, loop=self_loops
)
The full stack trace is below. Both graphs.pos and graphs.batch look reasonable: they have shapes [660000, 2] and [660000], respectively, and the expected values. The error occurs specifically when graphs.pos is a BatchedTensor (an internal PyTorch class created by the vmap call in ClipPPOLoss's value estimator; see the stack trace). BatchedTensor behaves unexpectedly with .view and .reshape, and torch_geometric.nn.pool.radius_graph somehow doesn't play well with it.
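A plausible reading of the failure, offered as an assumption rather than something the trace proves: the vmap in the value estimator has a leading dimension of size 2 (presumably the current and next observations), and since torch_cluster::radius has no batching rule (per the BatchedFallback warning quoted at the end), PyTorch's fallback runs the op once per slice and then stacks the results. Radius-graph edge counts depend on node positions, so the two slices can yield edge_index tensors of different lengths and the stack fails. The sketch below emulates that with a hypothetical O(N^2) helper, naive_radius_graph, standing in for radius_graph; the positions are toy data chosen so the two slices clearly disagree.

```python
import torch


def naive_radius_graph(pos: torch.Tensor, batch: torch.Tensor, r: float, loop: bool = False) -> torch.Tensor:
    """Hypothetical O(N^2) stand-in for radius_graph; returns an edge_index of shape [2, E]."""
    dist = torch.cdist(pos, pos)                           # pairwise distances, [N, N]
    same_graph = batch.unsqueeze(0) == batch.unsqueeze(1)  # only connect nodes of the same graph
    adj = (dist <= r) & same_graph
    if not loop:
        eye = torch.eye(pos.shape[0], dtype=torch.bool, device=pos.device)
        adj = adj & ~eye                                   # drop self-loops
    return adj.nonzero().t()                               # E depends on the positions, not just N


num_graphs, nodes_per_graph = 4, 5
n = num_graphs * nodes_per_graph
batch = torch.arange(num_graphs).repeat_interleave(nodes_per_graph)

# Two toy "vmap slices": all nodes clustered at the origin vs. nodes spread 1.0 apart on a line.
clustered = torch.zeros(n, 2)
spread = torch.stack([torch.arange(n, dtype=torch.float32), torch.zeros(n)], dim=1)
pos_stacked = torch.stack([clustered, spread])  # [2, N, 2]

# Emulate a loop-and-stack fallback: run the op on each slice, then stack the outputs.
per_slice = [naive_radius_graph(p, batch, r=0.4) for p in pos_stacked]
print([tuple(e.shape) for e in per_slice])  # [(2, 80), (2, 0)]

stack_error = None
try:
    torch.stack(per_slice)
except RuntimeError as err:
    stack_error = str(err)
print(stack_error)  # stack expects each tensor to be equal size, ...
```

The clustered slice connects every pair of nodes inside each 5-node graph (4 x 5 x 4 = 80 edges), while the spread slice has none, which mirrors the [2, 4303192] vs. [2, 4302544] mismatch in the trace.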
Is there a fix for this problem? Has anyone gotten a GNN critic to work? Thanks!
Full stack trace:
File "/home/gridsan/rfan/VMAS-Navigation/vmas_test.py", line 94, in <module>
experiment.run()
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/experiment/experiment.py", line 642, in run
self._collection_loop()
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/experiment/experiment.py", line 720, in _collection_loop
group_batch = self.algorithm.process_batch(group, group_batch)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/algorithms/mappo.py", line 263, in process_batch
loss.value_estimator(
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/value/advantages.py", line 79, in new_func
return fun(self, *args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/value/advantages.py", line 68, in new_fun
return fun(self, *args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 328, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/value/advantages.py", line 1468, in forward
value, next_value = self._call_value_nets(
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/value/advantages.py", line 527, in _call_value_nets
data_out = _vmap_func(
^ vmap call from loss.value_estimator.
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
return _flat_vmap(
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 47, in fn
return f(*args, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/utils.py", line 539, in decorated_module
return module(*module_args)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 328, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/utils.py", line 373, in wrapper
result = func(_self, tensordict, *args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/sequence.py", line 624, in forward
tensordict_exec = self._run_module(
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/sequence.py", line 570, in _run_module
tensordict = module(tensordict, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/common.py", line 160, in forward
tensordict = self._forward(tensordict)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/common.py", line 234, in _forward
return self.models(tensordict)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 328, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/utils.py", line 373, in wrapper
result = func(_self, tensordict, *args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/sequence.py", line 624, in forward
tensordict_exec = self._run_module(
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/sequence.py", line 570, in _run_module
tensordict = module(tensordict, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/common.py", line 160, in forward
tensordict = self._forward(tensordict)
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/gnn.py", line 366, in _forward
graph = _batch_from_dense_to_ptg(
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/gnn.py", line 517, in _batch_from_dense_to_ptg
graphs.edge_index = torch_geometric.nn.pool.radius_graph(
^ Line in gnn.py causing the error.
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torch_geometric/nn/pool/__init__.py", line 295, in radius_graph
return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors,
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torch_cluster/radius.py", line 135, in radius_graph
edge_index = radius(x, x, r, batch, batch,
File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torch_cluster/radius.py", line 82, in radius
return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_ops.py", line 854, in __call__
return self_._op(*args, **(kwargs or {}))
RuntimeError: stack expects each tensor to be equal size, but got [2, 4303192] at entry 0 and [2, 4302544] at entry 1
^ The actual error.
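The two shapes in this message differ only in the edge count (4303192 vs. 4302544), which depends on the node positions. That makes radius a data-dependent-shape operator, a class of op that vmap generally cannot batch. A minimal illustration with stock PyTorch, using torch.nonzero purely as a stand-in for torch_cluster::radius (the function name positive_indices and the input values are illustrative):

```python
import torch
from torch.func import vmap


def positive_indices(x: torch.Tensor) -> torch.Tensor:
    # Output length depends on the values in x, like an edge list from radius().
    return torch.nonzero(x > 0.5)


x = torch.tensor([[0.9, 0.1, 0.8],
                  [0.2, 0.1, 0.7]])

# A plain Python loop is fine: each row gets its own, differently sized result.
looped = [positive_indices(row) for row in x]
print([tuple(t.shape) for t in looped])  # [(2, 1), (1, 1)]

# Under vmap the per-row results cannot be combined into one batched tensor.
vmap_error = None
try:
    vmap(positive_indices)(x)
except RuntimeError as err:
    vmap_error = err
print(type(vmap_error).__name__)  # RuntimeError
```

For torch.nonzero, PyTorch rejects the call outright; for an op without a batching rule, the generic fallback appears to run it per slice and only fails at the stacking step, which matches the error above.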
There is also a warning printed every time:
[W BatchedFallback.cpp:81] Warning: There is a performance drop because we have not yet implemented the batching rule for torch_cluster::radius. Please file us an issue on GitHub so that we can prioritize its implementation. (function warnFallback)
I take this to mean that torch_cluster::radius has no vmap batching rule, so torch_cluster isn't meant to work with the BatchedTensor class.
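If the critic can be called outside vmap, one direction worth trying (a sketch, not a confirmed BenchMARL or TorchRL fix) is to fold the leading dimension into the graph batch vector, offsetting each slice's graph ids so nodes from different slices never connect. A single radius-graph call then returns one edge_index whose length is the sum over slices, so nothing has to be stacked. merged_radius_graph and naive_radius_graph are hypothetical helpers; in real code the final call would go to torch_geometric.nn.pool.radius_graph(flat_pos, r, batch=flat_batch, loop=loop) instead of the naive stand-in.

```python
import torch


def naive_radius_graph(pos, batch, r, loop=False):
    """Hypothetical O(N^2) stand-in for radius_graph; returns edge_index [2, E]."""
    adj = (torch.cdist(pos, pos) <= r) & (batch.unsqueeze(0) == batch.unsqueeze(1))
    if not loop:
        adj = adj & ~torch.eye(pos.shape[0], dtype=torch.bool, device=pos.device)
    return adj.nonzero().t()


def merged_radius_graph(pos_stacked, batch, r, loop=False):
    """Hypothetical workaround: fold a leading [K] dim into the graph batch vector."""
    k, n, d = pos_stacked.shape
    num_graphs = int(batch.max()) + 1
    flat_pos = pos_stacked.reshape(k * n, d)
    # Slice i's graphs get ids in [i * num_graphs, (i + 1) * num_graphs).
    offsets = torch.arange(k, device=batch.device).repeat_interleave(n) * num_graphs
    flat_batch = batch.repeat(k) + offsets
    return naive_radius_graph(flat_pos, flat_batch, r, loop), flat_batch


num_graphs, nodes_per_graph = 4, 5
n = num_graphs * nodes_per_graph
batch = torch.arange(num_graphs).repeat_interleave(nodes_per_graph)
clustered = torch.zeros(n, 2)
spread = torch.stack([torch.arange(n, dtype=torch.float32), torch.zeros(n)], dim=1)
pos_stacked = torch.stack([clustered, spread])  # toy data: [K=2, N, 2]

edge_index, flat_batch = merged_radius_graph(pos_stacked, batch, r=0.4)
print(edge_index.shape)  # torch.Size([2, 80]): 80 edges from slice 0, none from slice 1
# Every edge stays inside one graph of one slice, so no cross-slice edges appear.
print(bool((flat_batch[edge_index[0]] == flat_batch[edge_index[1]]).all()))  # True
```

Note that slice 1's first node sits at the origin, right on top of slice 0's nodes, yet no edge joins them because the offset gives it a different graph id. Whether the value estimator can be made to route both inputs through the critic this way is not something the trace shows.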