🐛 Describe the bug
The official GCNConv implementation accepts edge_index as either a Tensor or a SparseTensor, but its handling of the SparseTensor case is wrong.
Here is the code in question:
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:

    if isinstance(x, (tuple, list)):
        raise ValueError(f"'{self.__class__.__name__}' received a tuple "
                         f"of node features as input while this layer "
                         f"does not support bipartite message passing. "
                         f"Please try other layers such as 'SAGEConv' or "
                         f"'GraphConv' instead")

    if self.normalize:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    x = self.lin(x)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight)

    if self.bias is not None:
        out = out + self.bias

    return out
In the SparseTensor branch, the normalized SparseTensor is never unpacked back into a Tensor edge_index and edge_weight, so the following call receives the wrong inputs:
out = self.propagate(edge_index, x=x, edge_weight=edge_weight)
Consequently, the experimental results will be wrong.
Here is my revised version for the SparseTensor case:
        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                sparse_tensor = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                # Extract edge_index and edge_weight from SparseTensor
                row, col, edge_weight = sparse_tensor.coo()
                edge_index = torch.stack([row, col], dim=0)
                if self.cached:
                    self._cached_adj_t = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache
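To make the unpacking step concrete, here is a minimal plain-Python sketch that does not depend on PyG. It only illustrates turning COO triples (row, col, value) into an edge_index / edge_weight pair and then aggregating over that edge list. The names `unpack_coo`, `aggregate`, and the `transposed` flag are hypothetical helpers made up for this example, not part of any library.

```python
from typing import List, Sequence, Tuple


def unpack_coo(
    rows: Sequence[int],
    cols: Sequence[int],
    values: Sequence[float],
    transposed: bool = False,
) -> Tuple[List[List[int]], List[float]]:
    """Unpack COO triples into ([sources, targets], weights).

    `transposed` is an assumption of this sketch: set it when the sparse
    matrix stores the transposed adjacency, i.e. row = target node and
    col = source node.
    """
    if not (len(rows) == len(cols) == len(values)):
        raise ValueError("rows, cols and values must have the same length")
    sources, targets = (cols, rows) if transposed else (rows, cols)
    return [list(sources), list(targets)], list(values)


def aggregate(
    edge_index: List[List[int]],
    edge_weight: Sequence[float],
    x: Sequence[float],
    num_nodes: int,
) -> List[float]:
    """Sum weight * x[source] into each target node (scalar features)."""
    out = [0.0] * num_nodes
    sources, targets = edge_index
    for s, t, w in zip(sources, targets, edge_weight):
        out[t] += w * x[s]
    return out


# Transposed adjacency of the directed path 0 -> 1 -> 2 in COO form:
# the entry (row=1, col=0) means "node 1 receives from node 0".
rows, cols, values = [1, 2], [0, 1], [0.5, 2.0]

edge_index, edge_weight = unpack_coo(rows, cols, values, transposed=True)
out = aggregate(edge_index, edge_weight, x=[1.0, 10.0, 100.0], num_nodes=3)
```

One thing this sketch highlights: the order in which row and col are stacked decides the edge direction. As far as I understand, PyG layers are usually given the transposed adjacency (adj_t), so whether [row, col] or [col, row] is correct for directed graphs should be checked against the library documentation.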
Versions
All versions are affected by this issue.