Skip to content

Commit 8aef12c

Browse files
Merge pull request #173 from PaddlePaddle/dds_trans
Dds trans
2 parents 3368b93 + 4cebe18 commit 8aef12c

File tree

7 files changed

+585
-1
lines changed

7 files changed

+585
-1
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# DTSyn(Dual-Transformer neural network predicting Synergistic pairs)
2+
3+
[中文版本](./README_cn.md) [English Version](./README.md)
4+
5+
* [Background](#background)
6+
* [Dataset](#dataset)
7+
* [ddi](#ddi)
8+
* [lincs](#lincs)
9+
* [rna](#rna)
10+
* [Usage](#usage)
11+
* [training and evaluation](#training&evaluation)
12+
* [Reference](#reference)
13+
14+
## background
15+
Drug combinations, compared to monotherapies, have the potential to increase efficacy, reduce host toxicity and overcome drug resistance. However, screening for novel synergistic drug pairs is challenging because the space of potential combinations is enormous. Furthermore, the lack of understanding of the mechanisms of action (MoA) limits the application of drug combinations. Our model uses transformers at different levels of granularity to capture biological interactions from different dimensions.
16+
17+
## dataset
18+
The drug combination (synergy) data files are stored in the `data` directory.
19+
### training data
20+
```sh
21+
cd data && wget https://baidu-nlp.bj.bcebos.com/PaddleHelix/datasets/drug_synergy_datasets/DTSyn.tgz && tar xzvf DTSyn.tgz
22+
```
23+
24+
## usage
25+
We use `main.py` for illustration,
26+
the cmdline is as follows:
27+
```
28+
CUDA_VISIBLE_DEVICES=0 python3 main.py \
29+
    --ddi ./data/ddi.csv \
30+
    --lincs ./data/gene_vector.csv \
31+
    --rna ./data/rna.csv \
32+
    --epochs 150
33+
```
34+
35+
## Reference
36+
**DTSyn**
37+
> @article{jing2022DTSyn,
38+
title={DTSyn: a dual-transformer-based neural network to predict synergistic drug combinations},
39+
author={Jing Hu and Jie Gao and Xiaomin Fang and Zijing Liu and Fan Wang and Weili Huang and Hua Wu and Guodong Zhao},
40+
journal={preprint on bioRxiv}
41+
}
42+
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# DTSyn(Dual-Transformer neural network predicting Synergistic pairs)
2+
3+
[中文版本](./README_cn.md) [English Version](./README.md)
4+
5+
* [背景](#背景)
6+
* [数据集](#数据集)
7+
* [ddi](#ddi)
8+
* [lincs](#lincs)
9+
* [rna](#rna)
10+
* [使用说明](#使用说明)
11+
* [训练与评估](#训练与评估)
12+
* [引用](#引用)
13+
14+
## 背景
15+
药物联用可以解决单药使用面临的耐药、毒副作用过大等问题。当前双药联合使用同时还面临着组合爆炸、机理不明确等问题。本模型通过借助transformer结构从不同粒度出发捕获不同角度的生物学互作信息。
16+
## 数据集
17+
药物协同的分值文件放在 `data` 文件夹下。
18+
### 训练集
19+
```sh
20+
cd data && wget https://baidu-nlp.bj.bcebos.com/PaddleHelix/datasets/drug_synergy_datasets/DTSyn.tgz && tar xzvf DTSyn.tgz
21+
```
22+
23+
## 使用说明
24+
为了方便展示,我们构建了一个脚本, `main.py`
25+
用法如下:
26+
```
27+
CUDA_VISIBLE_DEVICES=0 python3 main.py \
28+
    --ddi ./data/ddi.csv \
29+
    --lincs ./data/gene_vector.csv \
30+
    --rna ./data/rna.csv \
31+
    --epochs 150
32+
```
33+
34+
## 引用
35+
**DTSyn**
36+
> @article{jing2022DTSyn,
37+
title={DTSyn: a dual-transformer-based neural network to predict synergistic drug combinations},
38+
author={Jing Hu and Jie Gao and Xiaomin Fang and Zijing Liu and Fan Wang and Weili Huang and Hua Wu and Guodong Zhao},
39+
journal={preprint on bioRxiv}
40+
}

apps/drug_drug_synergy/DTSyn/__init__.py

Whitespace-only changes.

apps/drug_drug_synergy/DTSyn/main.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
from pgl.utils.data import Dataset, Dataloader
17+
import argparse
18+
import sys
19+
#sys.path.append('.')
20+
from tsnet import TSNet
21+
22+
from utils_no_de import *
23+
from rdkit import Chem
24+
import pandas as pd
25+
import numpy as np
26+
27+
from sklearn.metrics import (roc_auc_score, average_precision_score, f1_score, roc_curve,
28+
precision_score, recall_score, auc, cohen_kappa_score,
29+
balanced_accuracy_score, precision_recall_curve, accuracy_score)
30+
from scipy.stats import pearsonr
31+
from sklearn.utils import shuffle
32+
33+
def train(model, data_loader, lincs, loss_fn, opt):
    """
    Run one optimisation pass over `data_loader` and return the mean batch loss.

    Args:
        model: network called as
            model(g1, g2, gm1, gm2, cell, lincs, batch_size).
        data_loader: iterable yielding (g1, g2, gm1, gm2, cell, lbs) batches.
        lincs: passed unchanged to the model on every batch.
        loss_fn: callable taking (predictions, labels) and returning a loss
            that supports .backward() and .numpy().
        opt: optimizer exposing step() and clear_grad().

    Returns:
        np.mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []
    for g1, g2, gm1, gm2, cell, lbs in data_loader:
        # g1/g2 must expose .tensor() (presumably PGL graphs -- TODO confirm).
        g1 = g1.tensor()
        g2 = g2.tensor()
        gm1 = paddle.to_tensor(gm1, 'int64')
        gm2 = paddle.to_tensor(gm2, 'int64')
        cell = paddle.to_tensor(cell, 'float32')
        lbs = paddle.to_tensor(lbs, 'int64')

        # The model receives the number of samples in this batch explicitly.
        preds = model(g1, g2, gm1, gm2, cell, lincs, len(lbs))
        loss = loss_fn(preds, lbs)

        loss.backward()
        opt.step()
        # Reset gradients so they do not accumulate into the next batch.
        opt.clear_grad()

        batch_losses.append(loss.numpy())

    return np.mean(batch_losses)
56+
57+
def eva(model, data_loader, lincs, loss_fn):
    """
    Evaluate `model` on every batch of `data_loader` without updating it.

    Args:
        model: network called as
            model(g1, g2, gm1, gm2, cell, lincs, batch_size).
        data_loader: iterable yielding (g1, g2, gm1, gm2, cell, lbs) batches.
        lincs: passed unchanged to the model on every batch.
        loss_fn: callable taking (predictions, labels) and returning a loss
            that supports .numpy().

    Returns:
        (total_pred, total_lb, mean_loss): the raw model outputs and the
        labels, each concatenated along axis 0 over all batches, and the
        np.mean of the per-batch losses.
    """
    model.eval()
    total_pred, total_lb = [], []
    total_loss = []

    # Nothing is back-propagated here, so disable gradient recording to
    # avoid keeping the autograd graph alive for every batch.
    with paddle.no_grad():
        for g1, g2, gm1, gm2, cell, lbs in data_loader:
            # g1/g2 must expose .tensor() (presumably PGL graphs -- TODO confirm).
            g1 = g1.tensor()
            g2 = g2.tensor()
            gm1 = paddle.to_tensor(gm1, 'int64')
            gm2 = paddle.to_tensor(gm2, 'int64')
            cell = paddle.to_tensor(cell, 'float32')
            lbs = paddle.to_tensor(lbs, 'int64')

            preds = model(g1, g2, gm1, gm2, cell, lincs, len(lbs))
            loss = loss_fn(preds, lbs)

            total_loss.append(loss.numpy())
            total_pred.append(preds.numpy())
            total_lb.append(lbs.numpy())

    total_pred = np.concatenate(total_pred, 0)
    total_lb = np.concatenate(total_lb, 0)

    return total_pred, total_lb, np.mean(total_loss)
80+
81+
def test_auc(model, data_loader, lincs, criterion):
    """
    Evaluate `model` with `eva` and summarise it as binary-classification metrics.

    The raw outputs are passed through softmax; column 1 is used as the
    positive-class score and thresholded at 0.5 to obtain hard labels.

    Returns:
        Tuple (ROC-AUC, PR-AUC, loss, accuracy, balanced accuracy, precision,
        recall, Cohen's kappa), in that order.
    """
    logits, labels, loss = eva(model, data_loader, lincs, criterion)

    probabilities = paddle.nn.functional.softmax(paddle.to_tensor(logits)).numpy()
    positive_prob = probabilities[:, 1]
    hard_labels = [1 if p > 0.5 else 0 for p in positive_prob]

    # Threshold-free metrics are computed from the positive-class scores.
    roc_auc = roc_auc_score(labels, positive_prob)
    precision, recall, _ = precision_recall_curve(labels, positive_prob)
    pr_auc = auc(recall, precision)

    # Threshold-dependent metrics are computed from the 0/1 predictions.
    accuracy = accuracy_score(labels, hard_labels)
    balanced_accuracy = balanced_accuracy_score(labels, hard_labels)
    precision_at_half = precision_score(labels, hard_labels)
    recall_at_half = recall_score(labels, hard_labels)
    kappa = cohen_kappa_score(labels, hard_labels)

    return (roc_auc, pr_auc, loss, accuracy, balanced_accuracy,
            precision_at_half, recall_at_half, kappa)
93+
94+
95+
96+
def Pred(model, lincs, data_loader):
    """
    Return the positive-class score for every sample in `data_loader`.

    Note the argument order (model, lincs, data_loader) differs from
    `train` / `eva`.

    Args:
        model: network called as
            model(g1, g2, gm1, gm2, cell, lincs, batch_size).
        lincs: passed unchanged to the model on every batch.
        data_loader: iterable yielding (g1, g2, gm1, gm2, cell, lbs) batches;
            `lbs` is only used for its length (the batch size).

    Returns:
        Column 1 of the softmax over the concatenated model outputs.
    """
    model.eval()
    total_pred = []

    # Inference only: disable gradient recording to save memory and time.
    with paddle.no_grad():
        for g1, g2, gm1, gm2, cell, lbs in data_loader:
            # g1/g2 must expose .tensor() (presumably PGL graphs -- TODO confirm).
            g1 = g1.tensor()
            g2 = g2.tensor()
            gm1 = paddle.to_tensor(gm1, 'int64')
            gm2 = paddle.to_tensor(gm2, 'int64')
            cell = paddle.to_tensor(cell, 'float32')

            preds = model(g1, g2, gm1, gm2, cell, lincs, len(lbs))
            total_pred.append(preds.numpy())

    total_pred = np.concatenate(total_pred, 0)
    total_prob = paddle.nn.functional.softmax(paddle.to_tensor(total_pred)).numpy()[:, 1]

    return total_prob
117+
118+
def _make_loader(ddi_frame, rna, batch_size):
    """Build a Dataloader over the drug pairs, cell features and labels of `ddi_frame`."""
    cell = join_cell(ddi_frame, rna)
    dataset = DDsData(ddi_frame['drug1'].values,
                      ddi_frame['drug2'].values,
                      cell,
                      ddi_frame['label'].values)
    return Dataloader(dataset, batch_size=batch_size, num_workers=4, collate_fn=collate)


def main(args):
    """
    Train the TSNet model and, when a test file is supplied, evaluate it after
    every epoch.

    Args (attributes read from `args`):
        -ddi: drug drug synergy file used for training.
        -ddi_test: optional drug drug synergy file used for evaluation; when it
            is missing or None, the per-epoch evaluation is skipped.
            Assumed to have the same columns as -ddi
            ('drug1', 'drug2', 'label') -- TODO confirm.
        -rna: cell line gene expression file.
        -lincs: gene embeddings.
        -dropout: dropout rate for transformer blocks.
        -epochs: training epochs.
        -batch_size
        -lr: learning rate.
    """
    ddi = pd.read_csv(args.ddi)
    rna = pd.read_csv(args.rna, index_col=0)
    lincs = pd.read_csv(args.lincs, index_col=0, header=None).values
    lincs = paddle.to_tensor(lincs, 'float32')

    # The whole --ddi table is used for training. For k-fold cross validation,
    # split `ddi` (e.g. after sklearn.utils.shuffle) and build one loader per
    # split with _make_loader.
    ddi_train = ddi.copy()
    loader_tr = _make_loader(ddi_train, rna, args.batch_size)

    # `loader_test` used to be referenced without ever being defined, which
    # raised NameError after the first training epoch. Build it from
    # --ddi_test when given; getattr keeps callers that pass a namespace
    # without this attribute working.
    ddi_test_path = getattr(args, 'ddi_test', None)
    loader_test = None
    if ddi_test_path:
        ddi_test = pd.read_csv(ddi_test_path)
        loader_test = _make_loader(ddi_test, rna, args.batch_size)
    else:
        print('No --ddi_test file given; skipping per-epoch evaluation.')

    model = TSNet(num_drug_feat=78,
                  num_L_feat=978,
                  num_cell_feat=rna.shape[1],
                  num_drug_out=128,
                  coarsed_heads=4,
                  fined_heads=4,
                  coarse_hidd=64,
                  fine_hidd=64,
                  dropout=args.dropout)
    opt = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    loss_fn = paddle.nn.CrossEntropyLoss()

    for e in range(args.epochs):
        train_loss = train(model, loader_tr, lincs, loss_fn, opt)
        print('Epoch {}---training loss:{}'.format(e, train_loss))

        if loader_test is None:
            continue

        t_auc, test_prauc, test_loss, acc, bacc, prec, tpr, kappa = test_auc(model, loader_test, lincs, loss_fn)
        print('---Testing loss:{:.4f}, AUC:{:.4f}, PRAUC:{:.4f}, ACC:{:.4f}, BACC:{:.4f}, PREC:{:.4f}, TPR:{:.4f}, KAPPA:{:.4f}'
              .format(test_loss, t_auc, test_prauc, acc, bacc, prec, tpr, kappa))
186+
187+
#paddle.save(model.state_dict(), 'Results/xx.pdparams'.format(e+1))
188+
#model_params = paddle.load('Results/xx.pdparams')
189+
#model.set_state_dict(model_params)
190+
191+
192+
if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), in the
    # order they should appear in --help.
    # NOTE(review): the --rna default points at '../rna.csv' while the other
    # data defaults live under '../data/' -- confirm this is intended.
    cli_options = (
        ("--dropout", dict(type=float, default=0.6)),
        ("--epochs", dict(type=int, default=50)),
        ("--batch_size", dict(type=int, default=32)),
        ("--lr", dict(type=float, default=5e-6)),
        ("--lincs", dict(type=str, default='../data/gene_vector.csv')),
        ("--ddi", dict(type=str, help='using SMILES represent drugs',
                       default='../data/ddi_dupave.csv')),
        ("--ddi_test", dict(type=str)),
        ("--rna", dict(type=str, default='../rna.csv')),
    )

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    print(args)
    main(args)

0 commit comments

Comments
 (0)