Skip to content

Commit f74a1e0

Browse files
hedonglonghedonglong
hedonglong
authored and
hedonglong
committed
Add the processed data and checkpoint download link
1 parent edecd65 commit f74a1e0

File tree

6 files changed

+126
-51
lines changed

6 files changed

+126
-51
lines changed

apps/pretrained_compound/ChemRL/GEM-2/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ You can download the PCQM4Mv2 dataset from ogb website:
3232

3333
https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m-v2.zip
3434

35+
# Processed Data
36+
You can download the processed PCQM4Mv2 dataset with RDKit-generated 3D information from:
37+
https://baidu-nlp.bj.bcebos.com/PaddleHelix/datasets/compound_datasets/pcqm4mv2_gem2.tgz
38+
And then use tar to extract the data.
39+
```bash
40+
mkdir -p ../data
41+
tar xzf pcqm4mv2_gem2.tgz -C ../data
42+
```
43+
3544
# How to run
3645
## Introduction to related configs
3746
You can adjust the json files in the config folder to change the training settings.
@@ -58,6 +67,12 @@ The models will be saved under `./model`.
5867

5968
It will take around 60 minutes to finish one epoch on 16 A100 cards with total batch size of 512.
6069

70+
## Run inference
71+
To reproduce the result on the OGB leaderboard, you can download the checkpoint from:
72+
https://baidu-nlp.bj.bcebos.com/PaddleHelix/models/molecular_modeling/gem2_l12_c256.pdparams
73+
Then put it under the local `./model` folder and run the inference command:
74+
sh scripts/inference.sh
75+
6176

6277
## Citing this work
6378

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#!/bin/bash
# Evaluate a GEM-2 checkpoint by running train_gem2.py with --inference.
# The settings at the bottom of the file select the configs, the cached
# dataset and the checkpoint (init_model) to load.

# Move to the parent of the directory holding this script so that the
# relative paths below resolve no matter where the script is invoked from.
# Abort if either step fails instead of running from the wrong directory.
cd "$(dirname "$0")" || exit 1
cd .. || exit 1

# Launch train_gem2.py in inference mode using the variables defined below.
# All expansions are quoted so that paths containing spaces stay intact.
train_pcqm(){
    mkdir -p "log/$exp_name" "model/$exp_name"
    python3 train_gem2.py \
            --inference \
            --batch_size="$batch_size" \
            --num_workers=10 \
            --max_epoch=150 \
            --dataset_config="$dataset_config" \
            --data_cache_dir="$data_cache_dir" \
            --model_config="$model_config" \
            --encoder_config="$encoder_config" \
            --train_config="$train_config" \
            --init_model="$init_model" \
            --start_step="$start_step" \
            --model_dir="./model/$exp_name" \
            --log_dir="./log/$exp_name"
}


# exp_name is empty, so logs and models go directly under ./log and ./model.
exp_name=""
batch_size=32
dataset_config="configs/dataset_configs/pcqmv2.json"
data_cache_dir="../data/pcqm4m-v2-rdkit3d"
model_config="configs/model_configs/mol_regr-optimus-mae.json"
encoder_config="configs/model_configs/opt3d_l12_c256.json"
init_model="./model/gem2_l12_c256.pdparams"
train_config="configs/train_configs/lr4e-4-mid40.json"
start_step=0
train_pcqm

apps/pretrained_compound/ChemRL/GEM-2/scripts/train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ train_pcqm(){
2222

2323
}
2424

25-
25+
exp_name="gem2_l12_c256_4e-4"
2626
echo "$exp_name"
2727

2828

apps/pretrained_compound/ChemRL/GEM-2/src/config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def make_updated_config(base_config, updated_dict):
8484

8585
"optimus_block_num": 12,
8686
"optimus_block": {
87-
"node_dropout_rate": 0.05,
87+
"first_body_axial_attention_dropout": 0.05,
8888
"pair_dropout_rate": 0.05,
89-
"node_attention": {
89+
"first_body_axial_attention": {
9090
"use_pair_layer_norm": True,
9191
"num_head": 8,
9292
"dropout_rate": 0.05,
@@ -96,15 +96,15 @@ def make_updated_config(base_config, updated_dict):
9696
"hidden_factor": 4,
9797
"dropout_rate": 0.1
9898
},
99-
"outer_product": {
99+
"low2high": {
100100
"inner_channel": 32
101101
},
102-
"triangle_attention_start_node": {
102+
"second_body_first_axis": {
103103
"num_head": 8,
104104
"dropout_rate": 0.05,
105105
"is_start": True
106106
},
107-
"triangle_attention_end_node": {
107+
"second_body_second_axis": {
108108
"num_head": 8,
109109
"dropout_rate": 0.05,
110110
"is_start": False

apps/pretrained_compound/ChemRL/GEM-2/src/optimus.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ def forward(self, batch):
122122
return results
123123

124124

125-
class NodeAttention(nn.Layer):
125+
class FirstBodyAxialAttention(nn.Layer):
126126
"""Compute self-attention over columns of a 2D input."""
127127
def __init__(self, model_config, global_config):
128-
super(NodeAttention, self).__init__()
128+
super(FirstBodyAxialAttention, self).__init__()
129129
self.model_config = model_config
130130

131131
node_channel = global_config.node_channel
@@ -253,9 +253,9 @@ def forward(self, x):
253253
return x
254254

255255

256-
class OuterProductMean(nn.Layer):
256+
class Low2HighModule(nn.Layer):
257257
def __init__(self, model_config, global_config):
258-
super(OuterProductMean, self).__init__()
258+
super(Low2HighModule, self).__init__()
259259
node_channel = global_config.node_channel
260260
pair_channel = global_config.pair_channel
261261
inner_channel = model_config.inner_channel
@@ -285,9 +285,9 @@ def forward(self, node_acts, node_mask):
285285
return act
286286

287287

288-
class TriangleAttentionWithAngle(nn.Layer):
288+
class SecondBodyAxialAttentionWithAngle(nn.Layer):
289289
def __init__(self, model_config, global_config):
290-
super(TriangleAttentionWithAngle, self).__init__()
290+
super(SecondBodyAxialAttentionWithAngle, self).__init__()
291291
pair_channel = global_config.pair_channel
292292
triple_channel = global_config.triple_channel
293293
self.num_head = model_config.num_head
@@ -342,9 +342,9 @@ def forward(self, pair_acts, triple_acts, bias):
342342
return out
343343

344344

345-
class TriangleAttentionWithAngleBias(nn.Layer):
345+
class SecondBodyAxialAttentionWithAngleBias(nn.Layer):
346346
def __init__(self, model_config, global_config):
347-
super(TriangleAttentionWithAngleBias, self).__init__()
347+
super(SecondBodyAxialAttentionWithAngleBias, self).__init__()
348348
pair_channel = global_config.pair_channel
349349
triple_channel = global_config.triple_channel
350350
self.num_head = model_config.num_head
@@ -395,15 +395,15 @@ def forward(self, pair_acts, triple_acts, bias):
395395
return out
396396

397397

398-
class TriangleAttention(nn.Layer):
398+
class SecondBodyAxialAttention(nn.Layer):
399399
def __init__(self, model_config, global_config):
400-
super(TriangleAttention, self).__init__()
400+
super(SecondBodyAxialAttention, self).__init__()
401401
self.is_start = model_config.is_start
402402

403403
if model_config.get('angle_as_bias', False):
404-
self.attn_mod = TriangleAttentionWithAngleBias(model_config, global_config)
404+
self.attn_mod = SecondBodyAxialAttentionWithAngleBias(model_config, global_config)
405405
else:
406-
self.attn_mod = TriangleAttentionWithAngle(model_config, global_config)
406+
self.attn_mod = SecondBodyAxialAttentionWithAngle(model_config, global_config)
407407

408408
def forward(self, pair_acts, triple_acts, triple_mask):
409409
"""
@@ -431,29 +431,29 @@ def __init__(self, model_config, global_config):
431431
pair_channel = global_config.pair_channel
432432

433433
### node track
434-
self.node_attn = NodeAttention(
435-
model_config.node_attention, global_config)
436-
self.node_attn_dropout = nn.Dropout(model_config.node_dropout_rate)
434+
self.first_body_axial_attention = FirstBodyAxialAttention(
435+
model_config.first_body_axial_attention, global_config)
436+
self.first_body_axial_attention_dropout = nn.Dropout(model_config.first_body_axial_attention_dropout)
437437

438438
self.node_ffn = FeedForwardNetwork(
439439
model_config.node_ffn, node_channel)
440-
self.node_ffn_dropout = nn.Dropout(model_config.node_dropout_rate)
440+
self.node_ffn_dropout = nn.Dropout(model_config.first_body_axial_attention_dropout)
441441

442-
### outer
443-
self.outer_product = OuterProductMean(
444-
model_config.outer_product, global_config)
445-
self.outer_product_dropout = nn.Dropout(model_config.pair_dropout_rate)
442+
### low2high
443+
self.low2high = Low2HighModule(
444+
model_config.low2high, global_config)
445+
self.low2high_dropout = nn.Dropout(model_config.pair_dropout_rate)
446446

447447
### pair track
448448
self.pair_before_ln = nn.LayerNorm(pair_channel)
449449

450-
self.triangle_attn_start = TriangleAttention(
451-
model_config.triangle_attention_start_node, global_config)
452-
self.triangle_attn_start_dropout = nn.Dropout(model_config.pair_dropout_rate)
450+
self.second_body_first_axis = SecondBodyAxialAttention(
451+
model_config.second_body_first_axis, global_config)
452+
self.second_body_first_axis_dropout = nn.Dropout(model_config.pair_dropout_rate)
453453

454-
self.triangle_attn_end = TriangleAttention(
455-
model_config.triangle_attention_end_node, global_config)
456-
self.triangle_attn_end_dropout = nn.Dropout(model_config.pair_dropout_rate)
454+
self.second_body_second_axis = SecondBodyAxialAttention(
455+
model_config.second_body_second_axis, global_config)
456+
self.second_body_second_axis_dropout = nn.Dropout(model_config.pair_dropout_rate)
457457

458458
self.pair_ffn = FeedForwardNetwork(
459459
model_config.pair_ffn, pair_channel)
@@ -476,24 +476,24 @@ def forward(self, node_acts, pair_acts, triple_acts, mask_dict):
476476
triple_mask = mask_dict['triple']
477477

478478
# node track
479-
residual = self.node_attn(node_acts, pair_acts, node_mask, pair_mask)
480-
node_acts += self.node_attn_dropout(residual)
479+
residual = self.first_body_axial_attention(node_acts, pair_acts, node_mask, pair_mask)
480+
node_acts += self.first_body_axial_attention_dropout(residual)
481481

482482
residual = self.node_ffn(node_acts)
483483
node_acts += self.node_ffn_dropout(residual)
484484

485485
# outer
486-
outer = self.outer_product(node_acts, node_mask)
487-
pair_acts += self.outer_product_dropout(outer)
486+
outer = self.low2high(node_acts, node_mask)
487+
pair_acts += self.low2high_dropout(outer)
488488

489489
# pair track
490490
pair_acts = self.pair_before_ln(pair_acts)
491491

492-
residual = self.triangle_attn_start(pair_acts, triple_acts, triple_mask)
493-
pair_acts += self.triangle_attn_start_dropout(residual)
492+
residual = self.second_body_first_axis(pair_acts, triple_acts, triple_mask)
493+
pair_acts += self.second_body_first_axis_dropout(residual)
494494

495-
residual = self.triangle_attn_end(pair_acts, triple_acts, triple_mask)
496-
pair_acts += self.triangle_attn_end_dropout(residual)
495+
residual = self.second_body_second_axis(pair_acts, triple_acts, triple_mask)
496+
pair_acts += self.second_body_second_axis_dropout(residual)
497497

498498
residual = self.pair_ffn(pair_acts)
499499
pair_acts += self.pair_ffn_dropout(residual)

apps/pretrained_compound/ChemRL/GEM-2/train_gem2.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,13 @@ def get_optimizer(args, train_config, model):
7575
return optimizer, scheduler
7676

7777

78-
def get_train_steps_per_epoch(dataset_len):
79-
if args.DEBUG:
80-
return 20
81-
# add as argument
82-
min_data_len = paddle.to_tensor(dataset_len)
83-
from paddle.distributed import ReduceOp
84-
dist.all_reduce(min_data_len, ReduceOp.MIN)
85-
dataset_len = min_data_len.numpy()[0]
86-
logging.info(f'min dataset len: {dataset_len}')
78+
def get_train_steps_per_epoch(dataset_len, args):
    """
    Return the number of training steps to run in one epoch.

    Args:
        dataset_len: number of samples in the local training dataset.
        args: parsed command line arguments; ``args.distributed`` and
            ``args.batch_size`` are read.

    Returns:
        ``dataset_len // args.batch_size - 5``, never less than zero. In
        distributed mode ``dataset_len`` is first replaced by the minimum
        value over all trainers.
    """
    if args.distributed:
        from paddle.distributed import ReduceOp
        min_data_len = paddle.to_tensor(dataset_len)
        dist.all_reduce(min_data_len, ReduceOp.MIN)
        # ``item()`` handles both 0-D and single-element 1-D tensors, while
        # ``numpy()[0]`` raises on a 0-D tensor.
        dataset_len = int(min_data_len.item())
        logging.info(f'min dataset len: {dataset_len}')
    # NOTE(review): the 5-step margin is presumably there so that no trainer
    # runs out of batches before the others - confirm with the data loader.
    # Clamp at zero so a tiny or empty dataset never yields a negative count.
    return max(0, int(dataset_len // args.batch_size) - 5)
8886

8987

@@ -135,7 +133,10 @@ def load_data(args, trainer_id, trainer_num, model_config, dataset_config, trans
135133
if args.DEBUG:
136134
train_npz_files = train_npz_files[:16]
137135
valid_npz_files = valid_npz_files[:8]
138-
train_dataset = InMemoryDataset(npz_data_files=train_npz_files[trainer_id::trainer_num])
136+
if args.inference:
137+
train_dataset = []
138+
else:
139+
train_dataset = InMemoryDataset(npz_data_files=train_npz_files[trainer_id::trainer_num])
139140
valid_dataset = InMemoryDataset(npz_data_files=valid_npz_files[trainer_id::trainer_num])
140141
if model_config.data.get('post_transform', False):
141142
logging.info('post transform')
@@ -227,6 +228,24 @@ def evaluate(args, epoch_id, model, test_dataset, collate_fn):
227228
return mean_mae
228229

229230

231+
def adjust_dropout(model_config, encoder_config, last_ck_path):
    """
    Rebuild the model with the encoder dropout rates set to zero.

    The dropout fields of ``encoder_config`` are overwritten in place, a new
    ``MolRegressionModel`` is created from ``model_config`` and the modified
    ``encoder_config``, and its weights are loaded from ``last_ck_path``.

    Returns:
        The newly created model with the checkpoint weights loaded.
    """
    block_config = encoder_config.optimus_block

    # rates stored directly on the encoder / block configs
    encoder_config.init_dropout_rate = 0
    block_config.first_body_axial_attention_dropout = 0
    block_config.pair_dropout_rate = 0

    # rates stored on the sub-module configs of the block
    sub_module_names = (
        'first_body_axial_attention',
        'node_ffn',
        'second_body_first_axis',
        'second_body_second_axis',
        'pair_ffn',
    )
    for sub_name in sub_module_names:
        getattr(block_config, sub_name).dropout_rate = 0

    # a fresh model picks up the zeroed rates; weights come from the checkpoint
    model = MolRegressionModel(model_config, encoder_config)
    state_dict = paddle.load(last_ck_path)
    model.set_state_dict(state_dict)
    print('Load state_dict from %s' % last_ck_path)
    return model
247+
248+
230249
def main(args):
231250
"""
232251
Call the configuration function of the model, build the model and load data, then start training.
@@ -284,6 +303,10 @@ def _read_json(path):
284303
ema_start_step = 0 if args.DEBUG else 30
285304

286305
optimizer, scheduler = get_optimizer(args, train_config, model)
306+
if args.inference:
307+
mean_mae = evaluate(args, 0, model, valid_dataset, collate_fn)
308+
print(f"mean mae : {mean_mae}")
309+
exit(0)
287310

288311
### start train
289312
data_writer = None
@@ -292,7 +315,7 @@ def _read_json(path):
292315
data_writer = SummaryWriter(f'{args.log_dir}/tensorboard_log_dir', max_queue=0)
293316
except Exception as ex:
294317
print(f'Create data_writer failed: {ex}')
295-
train_steps = get_train_steps_per_epoch(len(train_dataset))
318+
train_steps = get_train_steps_per_epoch(len(train_dataset), args)
296319
print("train_steps per epoch : ", train_steps)
297320
mean_mae_list = []
298321
for _ in range(args.start_step):
@@ -301,6 +324,9 @@ def _read_json(path):
301324
## ema register
302325
if epoch_id >= ema_start_step and not ema.is_registered:
303326
ema.register()
327+
328+
if epoch_id == 69:
329+
model = adjust_dropout(model_config, encoder_config, f'./{args.model_dir}/epoch_{epoch_id - 1}.pdparams')
304330

305331
## train
306332
s_time = time.time()
@@ -362,6 +388,7 @@ def _read_json(path):
362388
parser.add_argument("--model_config", type=str)
363389
parser.add_argument("--encoder_config", type=str)
364390
parser.add_argument("--train_config", type=str)
391+
parser.add_argument("--inference", action='store_true', default=False)
365392
parser.add_argument("--init_model", type=str)
366393
parser.add_argument("--start_step", type=int)
367394
parser.add_argument("--model_dir", type=str, default="./debug_models")

0 commit comments

Comments
 (0)