GNN critic error during training #219

@ImNotRog

I'm using a GNN as both actor and critic with VMAS, and I keep running into a runtime error during training. The error occurs only when the GNN is used as the critic; a GNN actor with an MLP critic works fine.

The error comes from this line in models/gnn.py:

graphs.edge_index = torch_geometric.nn.pool.radius_graph(
    graphs.pos, batch=graphs.batch, r=edge_radius, loop=self_loops
)
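
For readers unfamiliar with the call, here is a dependency-free sketch of what a radius graph computes: an edge between every pair of nodes that belong to the same graph and lie within distance r of each other. The function name `radius_graph_reference` is hypothetical, and this is not torch_cluster's implementation. In particular it ignores the `max_num_neighbors` cap and makes no promise about edge ordering, and treating a distance of exactly r as inside the radius is an assumption.

```python
import math


def radius_graph_reference(pos, batch, r, loop=False):
    """Return [sources, targets] for all same-graph node pairs within distance r.

    Illustrative O(n^2) reference only; boundary handling (<= r) is an assumption.
    """
    sources, targets = [], []
    for i, (p_i, b_i) in enumerate(zip(pos, batch)):
        for j, (p_j, b_j) in enumerate(zip(pos, batch)):
            if b_i != b_j:
                continue  # nodes in different graphs are never connected
            if i == j and not loop:
                continue  # skip self-loops unless requested
            if math.dist(p_i, p_j) <= r:
                sources.append(i)
                targets.append(j)
    return [sources, targets]


# Four nodes: three in graph 0, one in graph 1.
pos = [(0.0, 0.0), (0.5, 0.0), (3.0, 0.0), (0.0, 0.0)]
batch = [0, 0, 0, 1]
edge_index = radius_graph_reference(pos, batch, r=1.0)
print(edge_index)
```

The point to notice is that the number of edges depends on the positions themselves, so the output shape is only known after the data has been inspected.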

The full stack trace is below. Both graphs.pos and graphs.batch look reasonable: they have shapes [660000, 2] and [660000], respectively, and contain the expected values. The error occurs specifically when graphs.pos is a BatchedTensor (an internal class that appears because of a vmap call reached through ClipPPOLoss's value estimator; see the stack trace). BatchedTensor behaves unexpectedly with .view and .reshape, and torch_geometric.nn.pool.radius_graph does not seem to handle it correctly.
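
As a reminder of the layout behind those two shapes, the sketch below assumes the usual PyTorch Geometric convention: pos holds one row per node across all graphs, and batch[i] is the index of the graph that node i belongs to. The helper `flatten_dense_graphs` is hypothetical and uses plain lists; it is not the code in `_batch_from_dense_to_ptg`.

```python
def flatten_dense_graphs(dense_pos):
    """Flatten a list of graphs (each a list of (x, y) nodes) into pos and batch."""
    pos, batch = [], []
    for graph_id, nodes in enumerate(dense_pos):
        for xy in nodes:
            pos.append(xy)          # one row per node, all graphs concatenated
            batch.append(graph_id)  # which graph this node came from
    return pos, batch


dense_pos = [
    [(0.0, 0.0), (1.0, 0.0)],  # graph 0: two agents
    [(5.0, 5.0), (5.0, 6.0)],  # graph 1: two agents
]
pos, batch = flatten_dense_graphs(dense_pos)
print(len(pos), batch)
```

With this layout a single radius-graph call covers every graph at once, and the batch vector is what keeps edges from crossing between graphs.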

Is there a fix for this problem? Has anyone gotten a GNN critic to work? Thanks!

Full stack trace:

  File "/home/gridsan/rfan/VMAS-Navigation/vmas_test.py", line 94, in <module>
    experiment.run()
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/experiment/experiment.py", line 642, in run
    self._collection_loop()
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/experiment/experiment.py", line 720, in _collection_loop
    group_batch = self.algorithm.process_batch(group, group_batch)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/algorithms/mappo.py", line 263, in process_batch
    loss.value_estimator(
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/value/advantages.py", line 79, in new_func
    return fun(self, *args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/value/advantages.py", line 68, in new_fun
    return fun(self, *args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 328, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/value/advantages.py", line 1468, in forward
    value, next_value = self._call_value_nets(
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/value/advantages.py", line 527, in _call_value_nets
    data_out = _vmap_func(

^ vmap call from loss.value_estimator.

  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torchrl/objectives/utils.py", line 539, in decorated_module
    return module(*module_args)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 328, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/utils.py", line 373, in wrapper
    result = func(_self, tensordict, *args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/sequence.py", line 624, in forward
    tensordict_exec = self._run_module(
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/sequence.py", line 570, in _run_module
    tensordict = module(tensordict, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/common.py", line 160, in forward
    tensordict = self._forward(tensordict)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/common.py", line 234, in _forward
    return self.models(tensordict)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/common.py", line 328, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/utils.py", line 373, in wrapper
    result = func(_self, tensordict, *args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/sequence.py", line 624, in forward
    tensordict_exec = self._run_module(
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/tensordict/nn/sequence.py", line 570, in _run_module
    tensordict = module(tensordict, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/common.py", line 160, in forward
    tensordict = self._forward(tensordict)
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/gnn.py", line 366, in _forward
    graph = _batch_from_dense_to_ptg(
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/benchmarl/models/gnn.py", line 517, in _batch_from_dense_to_ptg
    graphs.edge_index = torch_geometric.nn.pool.radius_graph(

^ Line in gnn.py causing the error.

  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torch_geometric/nn/pool/__init__.py", line 295, in radius_graph
    return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors,
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torch_cluster/radius.py", line 135, in radius_graph
    edge_index = radius(x, x, r, batch, batch,
  File "/home/gridsan/rfan/.local/lib/python3.10/site-packages/torch_cluster/radius.py", line 82, in radius
    return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
  File "/state/partition1/llgrid/pkg/anaconda/python-ML-2024b/lib/python3.10/site-packages/torch/_ops.py", line 854, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: stack expects each tensor to be equal size, but got [2, 4303192] at entry 0 and [2, 4302544] at entry 1

^ The actual error.

There is also a warning printed every time:

[W BatchedFallback.cpp:81] Warning: There is a performance drop because we have not yet implemented the batching rule for torch_cluster::radius. Please file us an issue on GitHub so that we can prioritize its implementation. (function warnFallback)

I take this warning to mean that torch_cluster isn't meant to work with the BatchedTensor class.
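
One plausible reading of the warning and the error together: when an op has no batching rule, vmap falls back to running it on each slice separately and then stacking the per-slice results. That only works if every slice yields the same output shape, and a radius graph returns a [2, num_edges] tensor whose num_edges depends on the data (4303192 versus 4302544 in the trace above). The stdlib-only sketch below mimics that situation with plain lists. All names (`edges_within_radius`, `stack`, `node_degrees`) are hypothetical stand-ins, not PyTorch or torch_cluster APIs.

```python
import math


def edges_within_radius(points, r):
    """Edge list [sources, targets] for one slice; its length depends on the data."""
    sources, targets = [], []
    for i, p in enumerate(points):
        for j, q in enumerate(points):
            if i != j and math.dist(p, q) <= r:
                sources.append(i)
                targets.append(j)
    return [sources, targets]


def stack(outputs):
    """Stand-in for a stack: refuses outputs whose shapes differ."""
    shapes = [(len(out), len(out[0])) for out in outputs]
    if len(set(shapes)) != 1:
        raise RuntimeError(f"stack expects equal sizes, got {shapes}")
    return list(outputs)


def node_degrees(points, r):
    """Fixed-size [1, num_nodes] output derived from the ragged edge list."""
    sources, _ = edges_within_radius(points, r)
    return [[sources.count(i) for i in range(len(points))]]


# Two slices with the same number of nodes but different neighbourhoods.
slices = [
    [(0.0, 0.0), (0.5, 0.0), (3.0, 0.0)],
    [(0.0, 0.0), (0.5, 0.0), (1.0, 0.0)],
]

# Stacking the intermediate edge lists fails: the slices have 2 and 4 edges.
try:
    stack([edges_within_radius(s, 0.75) for s in slices])
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)
print(error_message)

# Running the whole per-slice computation in a plain loop and stacking only
# the fixed-size final outputs succeeds.
stacked_degrees = stack([node_degrees(s, 0.75) for s in slices])
print(stacked_degrees)
```

If this reading is right, the ragged intermediate result is the problem rather than the critic's final output, which has a fixed shape. Whether the value estimator can be made to loop over slices instead of using vmap is a separate question that this sketch does not answer.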
