1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import paddle
16
+ from pgl .utils .data import Dataset , Dataloader
17
+ import argparse
18
+ import sys
19
+ #sys.path.append('.')
20
+ from tsnet import TSNet
21
+
22
+ from utils_no_de import *
23
+ from rdkit import Chem
24
+ import pandas as pd
25
+ import numpy as np
26
+
27
+ from sklearn .metrics import (roc_auc_score , average_precision_score , f1_score , roc_curve ,
28
+ precision_score , recall_score , auc , cohen_kappa_score ,
29
+ balanced_accuracy_score , precision_recall_curve , accuracy_score )
30
+ from scipy .stats import pearsonr
31
+ from sklearn .utils import shuffle
32
+
33
def train(model, data_loader, lincs, loss_fn, opt):
    """Run one optimisation pass over ``data_loader`` and return the mean loss.

    Args:
        model: network invoked as ``model(g1, g2, gm1, gm2, cell, lincs, n)``
            where ``n`` is the number of labels in the batch.
        data_loader: iterable yielding ``(g1, g2, gm1, gm2, cell, lbs)``
            batches; ``g1``/``g2`` must expose ``.tensor()``.
        lincs: object forwarded unchanged to the model on every batch.
        loss_fn: callable ``loss_fn(preds, labels)`` returning a loss tensor.
        opt: optimizer exposing ``step()`` and ``clear_grad()``.

    Returns:
        ``np.mean`` of the per-batch loss values collected during the pass.
    """
    model.train()
    batch_losses = []
    for g1, g2, gm1, gm2, cell, lbs in data_loader:
        # Convert the collated batch into tensors the model can consume.
        g1 = g1.tensor()
        g2 = g2.tensor()
        gm1 = paddle.to_tensor(gm1, 'int64')
        gm2 = paddle.to_tensor(gm2, 'int64')
        cell = paddle.to_tensor(cell, 'float32')
        lbs = paddle.to_tensor(lbs, 'int64')

        preds = model(g1, g2, gm1, gm2, cell, lincs, len(lbs))
        loss = loss_fn(preds, lbs)

        # Standard update: back-propagate, apply the step, reset gradients
        # so they do not accumulate into the next batch.
        loss.backward()
        opt.step()
        opt.clear_grad()

        batch_losses.append(loss.numpy())

    return np.mean(batch_losses)
56
+
57
def eva(model, data_loader, lincs, loss_fn):
    """Evaluate ``model`` on ``data_loader`` without updating parameters.

    Args:
        model: network invoked as ``model(g1, g2, gm1, gm2, cell, lincs, n)``.
        data_loader: iterable yielding ``(g1, g2, gm1, gm2, cell, lbs)``
            batches; ``g1``/``g2`` must expose ``.tensor()``.
        lincs: object forwarded unchanged to the model on every batch.
        loss_fn: callable ``loss_fn(preds, labels)`` returning a loss tensor.

    Returns:
        A tuple ``(total_pred, total_lb, mean_loss)``: the raw model outputs
        and the labels, each concatenated along axis 0 over all batches, and
        ``np.mean`` of the per-batch losses.

    Raises:
        ValueError: from ``np.concatenate`` if ``data_loader`` yields no batch.
    """
    model.eval()
    batch_preds, batch_labels, batch_losses = [], [], []

    # No backward pass happens here, so disable gradient recording to avoid
    # keeping an autograd graph alive for every evaluation batch.
    with paddle.no_grad():
        for g1, g2, gm1, gm2, cell, lbs in data_loader:
            g1 = g1.tensor()
            g2 = g2.tensor()
            gm1 = paddle.to_tensor(gm1, 'int64')
            gm2 = paddle.to_tensor(gm2, 'int64')
            cell = paddle.to_tensor(cell, 'float32')
            lbs = paddle.to_tensor(lbs, 'int64')

            preds = model(g1, g2, gm1, gm2, cell, lincs, len(lbs))
            loss = loss_fn(preds, lbs)

            batch_losses.append(loss.numpy())
            batch_preds.append(preds.numpy())
            batch_labels.append(lbs.numpy())

    total_pred = np.concatenate(batch_preds, 0)
    total_lb = np.concatenate(batch_labels, 0)

    return total_pred, total_lb, np.mean(batch_losses)
80
+
81
def test_auc(model, data_loader, lincs, criterion):
    """Evaluate the model and summarise the result as classification metrics.

    Runs ``eva`` to obtain raw outputs, labels and loss, applies a softmax to
    the outputs and takes column 1 as the score for label ``1`` (so the model
    output is assumed to have at least two columns), then thresholds that
    score at 0.5 to obtain hard labels.

    Args:
        model: network forwarded to ``eva``.
        data_loader: batch iterable forwarded to ``eva``.
        lincs: object forwarded to ``eva``.
        criterion: loss callable forwarded to ``eva``.

    Returns:
        ``(roc_auc, pr_auc, loss, accuracy, balanced_accuracy, precision,
        recall, kappa)`` in that order.
    """
    logits, labels, loss = eva(model, data_loader, lincs, criterion)

    softmaxed = paddle.nn.functional.softmax(paddle.to_tensor(logits))
    scores = softmaxed.numpy()[:, 1]
    hard_labels = [int(score > 0.5) for score in scores]

    # Threshold-dependent metrics use the hard labels.
    accuracy = accuracy_score(labels, hard_labels)
    balanced = balanced_accuracy_score(labels, hard_labels)
    prec = precision_score(labels, hard_labels)
    rec = recall_score(labels, hard_labels)
    kappa = cohen_kappa_score(labels, hard_labels)

    # Ranking metrics use the continuous scores.
    pr_precision, pr_recall, _ = precision_recall_curve(labels, scores)
    roc_area = roc_auc_score(labels, scores)
    pr_area = auc(pr_recall, pr_precision)

    return roc_area, pr_area, loss, accuracy, balanced, prec, rec, kappa
93
+
94
+
95
+
96
def Pred(model, lincs, data_loader):
    """Return the softmax score of column 1 for every sample in ``data_loader``.

    Args:
        model: network invoked as ``model(g1, g2, gm1, gm2, cell, lincs, n)``.
        lincs: object forwarded unchanged to the model on every batch.
        data_loader: iterable yielding ``(g1, g2, gm1, gm2, cell, lbs)``
            batches; ``lbs`` is only used for its length (the batch size).

    Returns:
        A 1-D numpy array holding, per sample, column 1 of the softmax over
        the model output (the output is assumed to have at least two columns).

    Raises:
        ValueError: from ``np.concatenate`` if ``data_loader`` yields no batch.
    """
    model.eval()
    batch_preds = []

    # Inference only: skip gradient recording so no autograd graph is kept.
    with paddle.no_grad():
        for g1, g2, gm1, gm2, cell, lbs in data_loader:
            g1 = g1.tensor()
            g2 = g2.tensor()
            gm1 = paddle.to_tensor(gm1, 'int64')
            gm2 = paddle.to_tensor(gm2, 'int64')
            cell = paddle.to_tensor(cell, 'float32')

            preds = model(g1, g2, gm1, gm2, cell, lincs, len(lbs))
            batch_preds.append(preds.numpy())

    total_pred = np.concatenate(batch_preds, 0)
    total_prob = paddle.nn.functional.softmax(paddle.to_tensor(total_pred)).numpy()[:, 1]

    return total_prob
117
+
118
def _make_loader(ddi_frame, rna, batch_size):
    """Wrap a drug-pair frame in a ``DDsData`` dataset and a ``Dataloader``."""
    cell_feats = join_cell(ddi_frame, rna)
    dataset = DDsData(ddi_frame['drug1'].values,
                      ddi_frame['drug2'].values,
                      cell_feats,
                      ddi_frame['label'].values)
    return Dataloader(dataset, batch_size=batch_size, num_workers=4, collate_fn=collate)


def main(args):
    """Train TSNet on a drug-pair file and optionally evaluate it every epoch.

    Args:
        args: namespace with the attributes
            ddi: drug drug synergy file (CSV, used for training).
            ddi_test: optional CSV evaluated after every epoch; it is read
                with the same ``drug1``/``drug2``/``label`` columns as ``ddi``.
                When missing or empty, the per-epoch evaluation is skipped.
            rna: cell line gene expression file.
            lincs: gene embeddings.
            dropout: dropout rate for transformer blocks.
            epochs: training epochs.
            batch_size: batch size used for every loader.
            lr: learning rate.
    """
    # paddle.set_device('cpu')
    ddi = pd.read_csv(args.ddi)
    rna = pd.read_csv(args.rna, index_col=0)
    lincs = pd.read_csv(args.lincs, index_col=0, header=None).values
    lincs = paddle.to_tensor(lincs, 'float32')

    ##############independent validation############
    # Sketch of the 5-fold cross validation split this script was derived
    # from (kept for reference, not executed):
    #   NUM_CROSS = 5
    #   ddi_shuffle = shuffle(ddi)
    #   data_size = len(ddi)
    #   fold_num = int(data_size / NUM_CROSS)
    #   for fold in range(NUM_CROSS):
    #       ddi_test = ddi_shuffle.iloc[fold*fold_num:fold_num * (fold + 1), :]
    #       ddi_train_before = ddi_shuffle.iloc[:fold*fold_num, :]
    #       ddi_train_after = ddi_shuffle.iloc[fold_num * (fold + 1):, :]
    #       ddi_train = pd.concat([ddi_train_before, ddi_train_after])

    ddi_train = ddi.copy()
    loader_tr = _make_loader(ddi_train, rna, args.batch_size)

    # The evaluation loader used to be commented out while the epoch loop
    # still referenced it, which raised NameError after the first epoch.
    # Build it from --ddi_test when that file is supplied.
    ddi_test_path = getattr(args, 'ddi_test', None)
    loader_test = None
    if ddi_test_path:
        loader_test = _make_loader(pd.read_csv(ddi_test_path), rna, args.batch_size)

    model = TSNet(num_drug_feat=78,
                  num_L_feat=978,
                  num_cell_feat=rna.shape[1],
                  num_drug_out=128,
                  coarsed_heads=4,
                  fined_heads=4,
                  coarse_hidd=64,
                  fine_hidd=64,
                  dropout=args.dropout)
    opt = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    loss_fn = paddle.nn.CrossEntropyLoss()

    for e in range(args.epochs):
        train_loss = train(model, loader_tr, lincs, loss_fn, opt)
        print('Epoch {}---training loss:{}'.format(e, train_loss))

        if loader_test is None:
            # No held-out file was given; there is nothing to evaluate on.
            continue

        t_auc, test_prauc, test_loss, acc, bacc, prec, tpr, kappa = test_auc(
            model, loader_test, lincs, loss_fn)
        print('---Testing loss:{:.4f}, AUC:{:.4f}, PRAUC:{:.4f}, ACC:{:.4f}, BACC:{:.4f}, PREC:{:.4f}, TPR:{:.4f}, KAPPA:{:.4f}'
              .format(test_loss, t_auc, test_prauc, acc, bacc, prec, tpr, kappa))

    # paddle.save(model.state_dict(), 'Results/xx.pdparams'.format(e+1))
    # model_params = paddle.load('Results/xx.pdparams')
    # model.set_state_dict(model_params)
190
+
191
+
192
if __name__ == '__main__':
    # Command-line entry point: declare the options, echo the parsed
    # namespace, then hand it to main().
    parser = argparse.ArgumentParser()
    # parser.add_argument("--cuda", action='store_true', default=False)

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--dropout", dict(type=float, default=0.6)),
        ("--epochs", dict(type=int, default=50)),
        ("--batch_size", dict(type=int, default=32)),
        ("--lr", dict(type=float, default=5e-6)),
        ("--lincs", dict(type=str, default='../data/gene_vector.csv')),
        ("--ddi", dict(type=str, help='using SMILES represent drugs',
                       default='../data/ddi_dupave.csv')),
        ("--ddi_test", dict(type=str)),
        ("--rna", dict(type=str, default='../rna.csv')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    print(args)
    main(args)
0 commit comments