1
+ #!/usr/bin/python
2
+ #-*-coding:utf-8-*-
3
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import sys
17
+ import numpy as np
18
+ import unittest
19
+
20
+ from pahelix .utils .splitters import \
21
+ RandomSplitter , IndexSplitter , ScaffoldSplitter , RandomScaffoldSplitter
22
+ from pahelix .datasets .inmemory_dataset import InMemoryDataset
23
+ from pahelix .featurizers .featurizer import Featurizer
24
+
25
+
26
+ class RandomSplitterTest (unittest .TestCase ):
27
+ def test_split (self ):
28
+ raw_data_list = [
29
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
30
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
31
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
32
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
33
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
34
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
35
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
36
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
37
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
38
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
39
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
40
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
41
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
42
+ ]
43
+ dataset = InMemoryDataset (raw_data_list )
44
+ splitter = RandomSplitter ()
45
+ train_dataset , valid_dataset , test_dataset = splitter .split (
46
+ dataset , frac_train = 0.34 , frac_valid = 0.33 , frac_test = 0.33 )
47
+ n = len (train_dataset ) + len (valid_dataset ) + len (test_dataset )
48
+ self .assertEqual (n , len (dataset ))
49
+
50
+
51
+ class IndexSplitterTest (unittest .TestCase ):
52
+ def test_split (self ):
53
+ raw_data_list = [
54
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
55
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
56
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
57
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
58
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
59
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
60
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
61
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
62
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
63
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
64
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
65
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
66
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
67
+ ]
68
+ dataset = InMemoryDataset (raw_data_list )
69
+ splitter = IndexSplitter ()
70
+ train_dataset , valid_dataset , test_dataset = splitter .split (
71
+ dataset , frac_train = 0.34 , frac_valid = 0.33 , frac_test = 0.33 )
72
+ n = len (train_dataset ) + len (valid_dataset ) + len (test_dataset )
73
+ self .assertEqual (n , len (dataset ))
74
+
75
+
76
+ class ScaffoldSplitterTest (unittest .TestCase ):
77
+ def test_split (self ):
78
+ raw_data_list = [
79
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
80
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
81
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
82
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
83
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
84
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
85
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
86
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
87
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
88
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
89
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
90
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
91
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
92
+ ]
93
+ dataset = InMemoryDataset (raw_data_list )
94
+ splitter = ScaffoldSplitter ()
95
+ train_dataset , valid_dataset , test_dataset = splitter .split (
96
+ dataset , frac_train = 0.34 , frac_valid = 0.33 , frac_test = 0.33 )
97
+ n = len (train_dataset ) + len (valid_dataset ) + len (test_dataset )
98
+ self .assertEqual (n , len (dataset ))
99
+
100
+
101
+ class RandomScaffoldSplitterTest (unittest .TestCase ):
102
+ def test_split (self ):
103
+ raw_data_list = [
104
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
105
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
106
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
107
+ {'smiles' : 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1' },
108
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
109
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
110
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
111
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
112
+ {'smiles' : 'CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1' },
113
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
114
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
115
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
116
+ {'smiles' : 'CCCCCCCCCCOCC(O)CN' },
117
+ ]
118
+ dataset = InMemoryDataset (raw_data_list )
119
+ splitter = RandomScaffoldSplitter ()
120
+ train_dataset , valid_dataset , test_dataset = splitter .split (
121
+ dataset , frac_train = 0.34 , frac_valid = 0.33 , frac_test = 0.33 )
122
+ n = len (train_dataset ) + len (valid_dataset ) + len (test_dataset )
123
+ self .assertEqual (n , len (dataset ))
124
+
125
+
126
+ if __name__ == '__main__' :
127
+ unittest .main ()
0 commit comments