
Commit ecc6bc2
init update
1 parent b2c1678 commit ecc6bc2

12 files changed: +1603 −1 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
.DS_Store

README.md

Lines changed: 39 additions & 1 deletion
@@ -1 +1,39 @@
# GraphATA
Aggregate to Adapt: Node-Centric Aggregation for Multi-Source-Free Graph Domain Adaptation (WWW-2025).

![](https://github.com/cszhangzhen/GraphATA/blob/main/fig/model.png)

This is a PyTorch implementation of the GraphATA algorithm, which addresses the multi-source domain adaptation problem without access to the labeled source graphs. Unlike previous multi-source domain adaptation approaches that aggregate predictions at the model level, GraphATA conducts adaptation at node granularity. Specifically, it parameterizes each node with its own graph convolutional matrix by automatically aggregating weight matrices from multiple source models according to the node's local context, thus realizing dynamic adaptation over graph-structured data. We also demonstrate that GraphATA generalizes to both model-centric and layer-centric methods.
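
To make the node-centric idea concrete, here is a minimal sketch of aggregating source weight matrices per node (illustrative only, not the repository's implementation; the function name `node_centric_aggregate` and the similarity-based scoring are assumptions):

```
import torch
import torch.nn.functional as F

def node_centric_aggregate(h, source_weights):
    # h: [num_nodes, d] node context embeddings.
    # source_weights: list of K [d, d_out] weight matrices from source models.
    W = torch.stack(source_weights)                  # [K, d, d_out]
    summaries = W.mean(dim=-1)                       # [K, d] crude per-model summary
    alpha = F.softmax(h @ summaries.t(), dim=-1)     # [num_nodes, K] per-node scores
    W_node = torch.einsum('nk,kio->nio', alpha, W)   # per-node convolution matrix
    return torch.einsum('ni,nio->no', h, W_node)     # node-specific transformation
```

Because the aggregation weights depend on each node's local context, neighboring nodes can adapt to the source models differently.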

## Requirements
* python3.8
* pytorch==1.13.1
* torch-scatter==2.1.0
* torch-sparse==0.6.15
* torch-cluster==1.6.0
* torch-geometric==2.4.0
* numpy==1.23.4
* scipy==1.9.3

## Datasets
Datasets used in the paper are all publicly available.

## Quick Start for Node Classification
Run the following command to pre-train the source models:
```
python train_source_node.py
```
Then run the following command for adaptation:
```
python train_target_node.py
```

## Quick Start for Graph Classification
Run the following command to pre-train the source models:
```
python train_source_graph.py
```
Then run the following command for adaptation:
```
python train_target_graph.py
```

data.zip

5.43 MB
Binary file not shown.

datasets.py

Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
import os.path as osp
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.io import read_txt_array
import torch.nn.functional as F
import random

import scipy.sparse
import pickle as pkl
from sklearn.preprocessing import label_binarize
import csv
import json

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)

class CitationDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["docs.txt", "edgelist.txt", "labels.txt"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        # Edge list stored as comma-separated node pairs, one edge per line.
        edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
        edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

        # Node features: one comma-separated row per node.
        docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
        content_list = []
        with open(docs_path, 'rb') as f:
            for line in f.readlines():
                line = str(line, encoding="utf-8")
                content_list.append(line.split(","))
        x = np.array(content_list, dtype=float)
        x = torch.from_numpy(x).to(torch.float)

        # Node labels: one integer label per line.
        label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
        content_list = []
        with open(label_path, 'rb') as f:
            for line in f.readlines():
                line = str(line, encoding="utf-8")
                line = line.replace("\r", "").replace("\n", "")
                content_list.append(line)
        y = np.array(content_list, dtype=int)
        y = torch.from_numpy(y).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)

        # Random 80/10/10 train/val/test split over nodes.
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])

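# Usage sketch (illustrative, not part of the original file; the root path
# and the domain name 'acm' are assumptions):
#   dataset = CitationDataset(root='data/acm', name='acm')
#   data = dataset[0]  # one graph with train/val/test masks attached
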
class TwitchDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["edges.csv", "features.json", "target.csv"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def load_twitch(self, lang):
        assert lang in ('DE', 'EN', 'ES', 'FR', 'PTBR', 'RU'), 'Invalid dataset'
        filepath = self.raw_dir
        label = []
        node_ids = []
        src = []
        targ = []
        uniq_ids = set()
        with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                node_id = int(row[5])
                # handle FR case of non-unique rows
                if node_id not in uniq_ids:
                    uniq_ids.add(node_id)
                    label.append(int(row[2] == "True"))
                    node_ids.append(int(row[5]))

        node_ids = np.array(node_ids, dtype=int)
        with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                src.append(int(row[0]))
                targ.append(int(row[1]))
        with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
            j = json.load(f)
        src = np.array(src)
        targ = np.array(targ)
        label = np.array(label)
        # Map raw node ids to row positions; assumes ids cover 0..n-1,
        # as in the MUSAE Twitch files.
        inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)}
        reorder_node_ids = np.zeros_like(node_ids)
        for i in range(label.shape[0]):
            reorder_node_ids[i] = inv_node_ids[i]

        n = label.shape[0]
        A = scipy.sparse.csr_matrix((np.ones(len(src)), (np.array(src), np.array(targ))), shape=(n, n))
        # Multi-hot feature matrix; 3170 is the fixed feature vocabulary size
        # of these files.
        features = np.zeros((n, 3170))
        for node, feats in j.items():
            if int(node) >= n:
                continue
            features[int(node), np.array(feats, dtype=int)] = 1
        # features = features[:, np.sum(features, axis=0) != 0]  # remove zero cols; not needed for cross-graph task
        label = label[reorder_node_ids]

        return A, label, features

    def process(self):
        A, label, features = self.load_twitch(self.name)
        # Symmetrize the adjacency matrix to obtain an undirected graph.
        A = A.todense() + A.todense().T
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        features = np.array(features)
        x = torch.from_numpy(features).to(torch.float)
        y = torch.from_numpy(label).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)

        # Random 80/10/10 train/val/test split over nodes.
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])

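# Usage sketch (illustrative; the root path is an assumption, language codes
# follow the MUSAE Twitch naming):
#   dataset = TwitchDataset(root='data/twitch/DE', name='DE')
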
class CSBMDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(CSBMDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [".pkl"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        # A single pickled Data object per domain: <raw_dir>/<name>.pkl.
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        with open(path, 'rb') as f:
            data = pkl.load(f)

        # Random 80/10/10 train/val/test split over nodes.
        random_node_indices = np.random.permutation(data.y.size(0))
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])

class GraphTUDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(GraphTUDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [".pkl"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        # A pickled list of Data objects (one graph each): <raw_dir>/<name>.pkl.
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        with open(path, 'rb') as f:
            data_list = pkl.load(f)
        random.shuffle(data_list)

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        self.data, self.slices = self.collate(data_list)

        torch.save((self.data, self.slices), self.processed_paths[0])
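
The two pickle-backed datasets follow the same pattern; a minimal usage sketch (the roots and names below are assumptions, not taken from the repository):

```
from datasets import CSBMDataset, GraphTUDataset

# Hypothetical roots/names; each class expects <root>/raw/<name>.pkl.
node_dataset = CSBMDataset(root='data/csbm', name='csbm0')
graph_dataset = GraphTUDataset(root='data/mutag', name='MUTAG')
print(len(graph_dataset), graph_dataset[0])
```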

fig/model.png

155 KB
