import os.path as osp
import random

import numpy as np
import torch

from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.io import read_txt_array

import scipy.sparse
import pickle as pkl
import csv
import json

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
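# The dataset wrappers below follow the same PyG `InMemoryDataset` recipe:
# read a raw graph from `raw_dir`, attach a random 80/10/10 train/val/test
# node split (GraphTUDataset instead shuffles a list of whole graphs),
# collate the result, and cache it at `processed_paths[0]` so later runs
# just load the processed file.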


class CitationDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["docs.txt", "edgelist.txt", "labels.txt"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        # Edges are stored as one "src,dst" pair per line.
        edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
        edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

        # Node features: one comma-separated vector per line.
        docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
        content_list = []
        with open(docs_path, 'rb') as f:
            for line in f.readlines():
                line = str(line, encoding="utf-8")
                content_list.append(line.split(","))
        x = np.array(content_list, dtype=float)
        x = torch.from_numpy(x).to(torch.float)

        # Node labels: one integer label per line.
        label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
        content_list = []
        with open(label_path, 'rb') as f:
            for line in f.readlines():
                line = str(line, encoding="utf-8")
                line = line.replace("\r", "").replace("\n", "")
                content_list.append(line)
        y = np.array(content_list, dtype=int)
        y = torch.from_numpy(y).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)

        # Random 80/10/10 train/val/test node split.
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])
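# Minimal usage sketch for CitationDataset (the root path and the name "acm"
# are illustrative assumptions, not fixed by this file); the raw directory
# must contain <name>_docs.txt, <name>_edgelist.txt and <name>_labels.txt:
#
#   dataset = CitationDataset(root='./data/acm', name='acm')
#   data = dataset[0]  # one Data object with train/val/test masks attached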


class TwitchDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["edges.csv", "features.json", "target.csv"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def load_twitch(self, lang):
        assert lang in ('DE', 'EN', 'ES', 'FR', 'PTBR', 'RU'), 'Invalid dataset'
        filepath = self.raw_dir
        label = []
        node_ids = []
        src = []
        targ = []
        uniq_ids = set()
        with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            for row in reader:
                node_id = int(row[5])
                # handle FR case of non-unique rows
                if node_id not in uniq_ids:
                    uniq_ids.add(node_id)
                    label.append(int(row[2] == "True"))
                    node_ids.append(int(row[5]))

        node_ids = np.array(node_ids, dtype=np.int64)
        with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            for row in reader:
                src.append(int(row[0]))
                targ.append(int(row[1]))
        with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
            j = json.load(f)
        src = np.array(src)
        targ = np.array(targ)
        label = np.array(label)
        # Re-order labels so that position i holds the label of node id i
        # (assumes the node ids are exactly 0..n-1).
        inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)}
        reorder_node_ids = np.zeros_like(node_ids)
        for i in range(label.shape[0]):
            reorder_node_ids[i] = inv_node_ids[i]

        n = label.shape[0]
        A = scipy.sparse.csr_matrix((np.ones(len(src)), (src, targ)), shape=(n, n))
        # Multi-hot encoding of each node's feature indices (fixed width 3170).
        features = np.zeros((n, 3170))
        for node, feats in j.items():
            if int(node) >= n:
                continue
            features[int(node), np.array(feats, dtype=int)] = 1
        # features = features[:, np.sum(features, axis=0) != 0]  # remove zero
        # columns; not needed for the cross-graph task
        label = label[reorder_node_ids]

        return A, label, features

    def process(self):
        A, label, features = self.load_twitch(self.name)
        # Symmetrize the adjacency matrix so the graph is undirected.
        A = A.todense() + A.todense().T
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        features = np.array(features)
        x = torch.from_numpy(features).to(torch.float)
        y = torch.from_numpy(label).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)

        # Random 80/10/10 train/val/test node split.
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])
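# Usage sketch for TwitchDataset (the root path is an illustrative
# assumption); the raw directory must contain the musae files for the chosen
# language, e.g. musae_DE_edges.csv, musae_DE_features.json and
# musae_DE_target.csv:
#
#   dataset = TwitchDataset(root='./data/twitch/DE', name='DE')
#   data = dataset[0]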


class CSBMDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(CSBMDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['{}.pkl'.format(self.name)]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        # The raw file holds a single pickled PyG `Data` object.
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        with open(path, 'rb') as f:
            data = pkl.load(f)

        # Random 80/10/10 train/val/test node split.
        random_node_indices = np.random.permutation(data.y.size(0))
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])


class GraphTUDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(GraphTUDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['{}.pkl'.format(self.name)]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        # The raw file holds a pickled list of PyG `Data` graphs.
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        with open(path, 'rb') as f:
            data_list = pkl.load(f)
        random.shuffle(data_list)

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        torch.save((data, slices), self.processed_paths[0])
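# Usage sketch for the pickle-backed datasets (roots and names below are
# illustrative assumptions): CSBMDataset expects raw/<name>.pkl holding a
# single pickled Data object; GraphTUDataset expects raw/<name>.pkl holding
# a pickled list of Data graphs.
#
#   node_dataset = CSBMDataset(root='./data/csbm', name='csbm_0')
#   graph_dataset = GraphTUDataset(root='./data/tu', name='PROTEINS')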