
GCN wrong code for SparseTensor #10047

@Qin87

Description


🐛 Describe the bug

The official GCNConv forward accepts edge_index as either a Tensor or a SparseTensor, but its SparseTensor branch is broken.

Here is the wrong code:
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:

    if isinstance(x, (tuple, list)):
        raise ValueError(f"'{self.__class__.__name__}' received a tuple "
                         f"of node features as input while this layer "
                         f"does not support bipartite message passing. "
                         f"Please try other layers such as 'SAGEConv' or "
                         f"'GraphConv' instead")

    if self.normalize:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    x = self.lin(x)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight)

    if self.bias is not None:
        out = out + self.bias

    return out

In the SparseTensor branch, the normalized SparseTensor returned by gcn_norm is never unwrapped back into a Tensor edge_index plus edge_weight. As a result, the call
out = self.propagate(edge_index, x=x, edge_weight=edge_weight) receives a SparseTensor as edge_index while edge_weight stays None, so the experimental results come out wrong.
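To make the unwrap concrete, here is a torch-free sketch of the conversion the fix performs. The row/col/value lists are made-up toy data standing in for the (row, col, value) triple that SparseTensor.coo() returns; in the real layer, torch.stack([row, col], dim=0) builds the two-row edge_index shown below.

```python
# Toy stand-in for a SparseTensor in COO form (hypothetical data,
# mimicking the (row, col, value) triple from SparseTensor.coo()).
row = [0, 1, 1, 2]
col = [1, 0, 2, 1]
value = [0.5, 0.5, 0.7, 0.7]

# The unwrap: edge_index pairs row on top of col
# (torch.stack([row, col], dim=0) in the real code), and the
# COO values become the per-edge weights.
edge_index = [row, col]   # shape [2, num_edges]
edge_weight = value       # shape [num_edges]

print(edge_index)   # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_weight)  # [0.5, 0.5, 0.7, 0.7]
```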

Here is my revised version for the SparseTensor case:

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                sparse_tensor = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)

                # Unwrap the normalized SparseTensor back into a Tensor
                # edge_index and edge_weight (requires `import torch`)
                row, col, edge_weight = sparse_tensor.coo()
                edge_index = torch.stack([row, col], dim=0)
                if self.cached:
                    self._cached_adj_t = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache
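Why the unwrap matters: propagate ultimately computes a weighted neighbor sum, out[i] = sum_j w_ij * x[j], and that only works if edge_index and edge_weight actually describe the edges. The pure-Python sketch below (toy graph, no torch or torch_geometric dependency) checks that aggregating from COO edge lists matches the dense adjacency product, which is the invariant the fix restores.

```python
# Dense adjacency for a 3-node toy graph (hypothetical weights).
A = [[0.0, 0.5, 0.0],
     [0.5, 0.0, 0.25],
     [0.0, 0.25, 0.0]]
x = [1.0, 2.0, 3.0]  # one scalar feature per node

# Dense reference: out[i] = sum_j A[i][j] * x[j]
dense_out = [sum(A[i][j] * x[j] for j in range(3)) for i in range(3)]

# The same graph in COO form, as edge_index / edge_weight carry it.
row = [0, 1, 1, 2]
col = [1, 0, 2, 1]
w = [0.5, 0.5, 0.25, 0.25]

# COO aggregation: scatter each weighted source feature onto its target.
coo_out = [0.0, 0.0, 0.0]
for i, j, wij in zip(row, col, w):
    coo_out[i] += wij * x[j]

print(dense_out)  # [1.0, 1.25, 0.5]
print(coo_out)    # matches dense_out
```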

Versions

All released versions contain this SparseTensor branch, so they are all affected.
