Conversation

DhyeyMavani2003

Description

Implements to_batch_edge_index as the inverse operation of unbatch_edge_index. This function merges a list of edge_index tensors into a single batched edge_index tensor and returns the corresponding batch vector.

Closes #6099

Motivation

Currently, PyG provides unbatch_edge_index to split a batched edge_index into individual graphs, but lacks the inverse operation to merge multiple edge_index tensors into a batch. This function completes the API by providing the batching counterpart, enabling users to:

  • Manually construct batched graphs from individual edge_index tensors
  • Implement custom batching logic
  • Work with dynamic graph batching scenarios

Changes

Core Implementation

  • Added to_batch_edge_index() function in torch_geometric/utils/_unbatch.py
    • Takes a list of edge_index tensors as input
    • Returns a tuple of (batched_edge_index, batch_vector)
    • Properly offsets node indices for each graph
    • Handles edge cases: empty lists, empty graphs, and mixed scenarios (see the sketch below)
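
For reference, here is a minimal sketch of how such a function could look. It is an illustration under the assumption that each graph's node count is inferred from its largest referenced node index; it is not necessarily the exact code merged into _unbatch.py.

import torch
from typing import List, Tuple

from torch import Tensor


def to_batch_edge_index(edge_index_list: List[Tensor]) -> Tuple[Tensor, Tensor]:
    # Sketch only: offset each graph's node indices by the number of nodes
    # seen so far and record the graph id of every node in the batch vector.
    if len(edge_index_list) == 0:
        return (torch.empty(2, 0, dtype=torch.long),
                torch.empty(0, dtype=torch.long))

    device = edge_index_list[0].device
    edge_indices, batch_parts = [], []
    offset = 0
    for graph_id, edge_index in enumerate(edge_index_list):
        # Infer the node count from the largest node index (0 for empty graphs).
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
        edge_indices.append(edge_index + offset)
        batch_parts.append(
            torch.full((num_nodes, ), graph_id, dtype=torch.long, device=device))
        offset += num_nodes

    return torch.cat(edge_indices, dim=1), torch.cat(batch_parts)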

API Updates

  • Exported to_batch_edge_index in torch_geometric/utils/__init__.py

Testing

  • Added 6 comprehensive test cases in test/utils/test_unbatch.py:
    • Basic functionality
    • Empty list handling
    • Single graph handling
    • Mixed empty/non-empty graphs
    • Roundtrip verification, confirming it is the inverse of unbatch_edge_index (see the sketch after this list)
    • Different sized graphs
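
As an illustration, the roundtrip check might look roughly like the following pytest-style test. This is a sketch, not necessarily the exact test added in test/utils/test_unbatch.py.

import torch

from torch_geometric.utils import to_batch_edge_index, unbatch_edge_index


def test_to_batch_edge_index_roundtrip():
    # Two small graphs with 4 and 3 nodes, respectively.
    edge_index_list = [
        torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]),
        torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    ]

    edge_index, batch = to_batch_edge_index(edge_index_list)

    # Unbatching the result should recover the original per-graph tensors.
    out = unbatch_edge_index(edge_index, batch)
    assert len(out) == len(edge_index_list)
    for expected, recovered in zip(edge_index_list, out):
        assert torch.equal(expected, recovered)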

Usage Example

import torch
from torch_geometric.utils import to_batch_edge_index, unbatch_edge_index

# Create individual edge_index tensors
edge_index_list = [
    torch.tensor([[0, 1, 1, 2, 2, 3],
                  [1, 0, 2, 1, 3, 2]]),
    torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 1]]),
]

# Batch them together
edge_index, batch = to_batch_edge_index(edge_index_list)

print(edge_index)
# tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6],
#         [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]])

print(batch)
# tensor([0, 0, 0, 0, 1, 1, 1])

# Verify roundtrip
unbatched = unbatch_edge_index(edge_index, batch)
assert all(torch.equal(a, b) for a, b in zip(edge_index_list, unbatched))

Test Results

All tests pass:

pytest test/utils/test_unbatch.py -v
# 8 passed in 0.04s

Pre-commit hooks pass:

pre-commit run --all-files
# All checks passed

Checklist

  • Implementation follows PyG conventions
  • Comprehensive test coverage added
  • Documentation with examples included
  • Type hints provided
  • Pre-commit hooks pass (yapf, flake8, ruff)
  • Roundtrip test verifies correctness
  • Edge cases handled (empty lists, empty graphs)

Related Issues

Closes #6099
