Skip to content

Commit c9538c0

Browse files
authored
Merge pull request #29 from PaddlePaddle/pretrain_gnn
add unittest
2 parents dc63c94 + 272df79 commit c9538c0

File tree

6 files changed

+550
-305
lines changed

6 files changed

+550
-305
lines changed

apps/pretrained_compound/pretrain_gnns/README.md

+156-160
Large diffs are not rendered by default.

apps/pretrained_compound/pretrain_gnns/README_cn.md

+147-145
Large diffs are not rendered by default.

pahelix/tests/import_test.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#!/usr/bin/python
2+
#-*-coding:utf-8-*-
3+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import unittest
17+
18+
class ImportTest(unittest.TestCase):
    """Smoke test checking that the ``pahelix`` package can be imported."""

    def test_import_pahelix_alone(self):
        """Import ``pahelix`` on its own; any import-time error fails the test."""
        # The import is the whole test: there is nothing to assert beyond
        # "this statement does not raise".
        import pahelix  # noqa: F401
21+
22+
23+
if __name__ == '__main__':
24+
unittest.main()
+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#!/usr/bin/python
2+
#-*-coding:utf-8-*-
3+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import sys
17+
import unittest
18+
from rdkit import Chem
19+
from rdkit.Chem import AllChem
20+
21+
from pahelix.utils.compound_tools import smiles_to_graph_data
22+
from pahelix.utils.compound_tools import mol_to_graph_data
23+
from pahelix.utils.compound_tools import get_gasteiger_partial_charges
24+
from pahelix.utils.compound_tools import create_standardized_mol_id
25+
from pahelix.utils.compound_tools import split_rdkit_mol_obj
26+
from pahelix.utils.compound_tools import CompoundConstants
27+
28+
29+
class CompoundToolsTest(unittest.TestCase):
    """Sanity checks for helpers imported from ``pahelix.utils.compound_tools``.

    Every test uses the same SMILES string as its input molecule.
    """

    # NOTE(review): the ``add_self_loop`` / ``n_iter`` defaults on the test
    # methods below are never read; they are kept only so the method
    # signatures stay unchanged.

    def test_mol_to_graph_data(self, add_self_loop=True):
        """``mol_to_graph_data`` returns a truthy result for a parsed molecule."""
        molecule = AllChem.MolFromSmiles('CCOc1ccc2nc(S(N)(=O)=O)sc2c1')
        self.assertTrue(mol_to_graph_data(molecule))

    def test_smiles_to_graph_data(self, add_self_loop=True):
        """``smiles_to_graph_data`` returns a truthy result for a SMILES string."""
        self.assertTrue(smiles_to_graph_data('CCOc1ccc2nc(S(N)(=O)=O)sc2c1'))

    def test_get_gasteiger_partial_charges(self, n_iter=12):
        """``get_gasteiger_partial_charges`` yields 16 values for the molecule."""
        molecule = AllChem.MolFromSmiles('CCOc1ccc2nc(S(N)(=O)=O)sc2c1')
        partial_charges = get_gasteiger_partial_charges(molecule)
        # The SMILES string spells out 16 heavy atoms; presumably one charge
        # is returned per atom -- confirm against compound_tools.
        self.assertEqual(len(partial_charges), 16)

    def test_create_standardized_mol_id(self):
        """Two calls with the same SMILES string produce equal identifiers."""
        first_id = create_standardized_mol_id('CCOc1ccc2nc(S(N)(=O)=O)sc2c1')
        second_id = create_standardized_mol_id('CCOc1ccc2nc(S(N)(=O)=O)sc2c1')
        self.assertEqual(first_id, second_id)
52+
53+
54+
if __name__ == '__main__':
55+
unittest.main()
+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#!/usr/bin/python
2+
#-*-coding:utf-8-*-
3+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import sys
17+
import numpy as np
18+
import unittest
19+
20+
from pahelix.utils.data_utils import load_npz_to_data_list
21+
from pahelix.utils.data_utils import save_data_list_to_npz
22+
23+
24+
class DataUtilsTest(unittest.TestCase):
    """Round-trip test for the npz helpers from ``pahelix.utils.data_utils``."""

    def test_data_list_to_npz(self):
        """Save a list of array dicts to an npz file and load it back.

        The reloaded list must have the same length as the original, each
        reloaded record must have the same number of keys, and every array
        must compare equal element-wise.
        """
        # Function-local imports: only this test needs them.
        import os
        import tempfile

        data_list = [
            {"a": np.array([1, 23, 4])},
            {"a": np.array([2, 34, 5])},
        ]
        # Write into a throw-away directory instead of the current working
        # directory, so the test leaves no 'tmp.npz' behind and cannot clash
        # with another run that uses the same relative path.
        with tempfile.TemporaryDirectory() as tmp_dir:
            npz_file = os.path.join(tmp_dir, 'tmp.npz')
            save_data_list_to_npz(data_list, npz_file)
            reload_data_list = load_npz_to_data_list(npz_file)

            # Compare while the file still exists, in case the loader
            # returns data that is read lazily from disk.
            self.assertEqual(len(data_list), len(reload_data_list))
            for d1, d2 in zip(data_list, reload_data_list):
                self.assertEqual(len(d1), len(d2))
                for key in d1:
                    self.assertTrue((d1[key] == d2[key]).all())
38+
39+
40+
if __name__ == '__main__':
41+
unittest.main()

pahelix/utils/tests/splitters_test.py

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#!/usr/bin/python
2+
#-*-coding:utf-8-*-
3+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import sys
17+
import numpy as np
18+
import unittest
19+
20+
from pahelix.utils.splitters import \
21+
RandomSplitter, IndexSplitter, ScaffoldSplitter, RandomScaffoldSplitter
22+
from pahelix.datasets.inmemory_dataset import InMemoryDataset
23+
from pahelix.featurizers.featurizer import Featurizer
24+
25+
26+
def _build_smiles_dataset():
    """Return the 13-record SMILES fixture shared by all splitter tests.

    The fixture holds three distinct SMILES strings repeated 4, 5 and 4
    times. Each record is its own ``{'smiles': ...}`` dict, and a fresh
    ``InMemoryDataset`` is built on every call so tests do not share state.
    """
    smiles_and_counts = (
        ('CCOc1ccc2nc(S(N)(=O)=O)sc2c1', 4),
        ('CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1', 5),
        ('CCCCCCCCCCOCC(O)CN', 4),
    )
    raw_data_list = [
        {'smiles': smiles}
        for smiles, count in smiles_and_counts
        for _ in range(count)
    ]
    return InMemoryDataset(raw_data_list)


def _assert_split_covers_dataset(test_case, splitter):
    """Split the shared fixture and check that no record is lost or added.

    Args:
        test_case: the running ``unittest.TestCase``, used for the assertion.
        splitter: an object whose ``split(dataset, frac_train=...,
            frac_valid=..., frac_test=...)`` returns three sized datasets.
    """
    dataset = _build_smiles_dataset()
    train_dataset, valid_dataset, test_dataset = splitter.split(
        dataset, frac_train=0.34, frac_valid=0.33, frac_test=0.33)
    n = len(train_dataset) + len(valid_dataset) + len(test_dataset)
    # Only the total size is checked; how records are distributed between
    # the three parts is not asserted.
    test_case.assertEqual(n, len(dataset))


class RandomSplitterTest(unittest.TestCase):
    """Checks that ``RandomSplitter`` partitions the whole dataset."""

    def test_split(self):
        """The three parts together have as many records as the input."""
        _assert_split_covers_dataset(self, RandomSplitter())


class IndexSplitterTest(unittest.TestCase):
    """Checks that ``IndexSplitter`` partitions the whole dataset."""

    def test_split(self):
        """The three parts together have as many records as the input."""
        _assert_split_covers_dataset(self, IndexSplitter())


class ScaffoldSplitterTest(unittest.TestCase):
    """Checks that ``ScaffoldSplitter`` partitions the whole dataset."""

    def test_split(self):
        """The three parts together have as many records as the input."""
        _assert_split_covers_dataset(self, ScaffoldSplitter())


class RandomScaffoldSplitterTest(unittest.TestCase):
    """Checks that ``RandomScaffoldSplitter`` partitions the whole dataset."""

    def test_split(self):
        """The three parts together have as many records as the input."""
        _assert_split_covers_dataset(self, RandomScaffoldSplitter())
124+
125+
126+
if __name__ == '__main__':
127+
unittest.main()

0 commit comments

Comments
 (0)