Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions test/utils/test_num_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,31 @@ def test_maybe_num_nodes_dict():
'1': 3,
'2': 6,
}


def test_maybe_num_nodes_dict_empty_edge_index():
# Test with empty edge indices (regression test for bug fix)
edge_index_dict = {
('user', 'rates', 'movie'):
torch.tensor([[], []], dtype=torch.long),
('user', 'follows', 'user'):
torch.tensor([[0, 1], [1, 2]], dtype=torch.long),
}

result = maybe_num_nodes_dict(edge_index_dict)
assert result == {'user': 3, 'movie': 0}

# Test with all empty edge indices
edge_index_dict_all_empty = {
('user', 'rates', 'movie'): torch.tensor([[], []], dtype=torch.long),
('movie', 'in', 'genre'): torch.tensor([[], []], dtype=torch.long),
}

result_all_empty = maybe_num_nodes_dict(edge_index_dict_all_empty)
assert result_all_empty == {'user': 0, 'movie': 0, 'genre': 0}

# Test with provided num_nodes_dict and empty edges
num_nodes_dict = {'movie': 10}
result_with_provided = maybe_num_nodes_dict(edge_index_dict,
num_nodes_dict)
assert result_with_provided == {'user': 3, 'movie': 10}
6 changes: 4 additions & 2 deletions torch_geometric/utils/num_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ def maybe_num_nodes_dict(

key = keys[0]
if key not in found_types:
N = int(edge_index[0].max() + 1)
N = int(edge_index[0].max() +
1) if edge_index[0].numel() > 0 else 0
num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

key = keys[-1]
if key not in found_types:
N = int(edge_index[1].max() + 1)
N = int(edge_index[1].max() +
1) if edge_index[1].numel() > 0 else 0
num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

return num_nodes_dict