Skip to content

Commit ee0caea

Browse files
authored
Merge pull request #118 from Yelrose/master
fixed subgraph edge_feat inheritance
2 parents 0da137f + 1577fb2 commit ee0caea

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

pgl/graph.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -589,26 +589,26 @@ def subgraph(self,
589589
if eid is None and edges is None:
590590
raise ValueError("Eid and edges can't be None at the same time.")
591591

592+
sub_edge_feat = {}
592593
if edges is None:
593594
edges = self._edges[eid]
594595
else:
595596
edges = np.array(edges, dtype="int64")
597+
598+
if with_edge_feat:
599+
for key, value in self._edge_feat.items():
600+
if eid is None:
601+
raise ValueError(
602+
"Eid can not be None with edge features.")
603+
sub_edge_feat[key] = value[eid]
604+
605+
if edge_feats is not None:
606+
sub_edge_feat.update(edge_feats)
596607

597608
sub_edges = graph_kernel.map_edges(
598609
np.arange(
599610
len(edges), dtype="int64"), edges, reindex)
600611

601-
sub_edge_feat = {}
602-
if edges is None:
603-
if with_edge_feat:
604-
for key, value in self._edge_feat.items():
605-
if eid is None:
606-
raise ValueError(
607-
"Eid can not be None with edge features.")
608-
sub_edge_feat[key] = value[eid]
609-
else:
610-
sub_edge_feat = edge_feats
611-
612612
sub_node_feat = {}
613613
if with_node_feat:
614614
for key, value in self._node_feat.items():

pgl/sample.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,6 @@ def graphsage_sample(graph, nodes, samples, ignore_edges=[]):
109109
start = time.time()
110110
# Find new nodes
111111

112-
feed_dict = {}
113-
114112
subgraphs = []
115113
for i in range(num_layers):
116114
subgraphs.append(
@@ -471,7 +469,8 @@ def pinsage_sample(graph,
471469
graph.subgraph(
472470
nodes=layer_nodes[0],
473471
edges=layer_edges[i],
474-
edge_feats=edge_feat_dict))
472+
edge_feats=edge_feat_dict,
473+
with_edge_feat=False))
475474
subgraphs[i].node_feat["index"] = np.array(
476475
layer_nodes[0], dtype="int64")
477476

0 commit comments

Comments
 (0)