Skip to content

Commit c11db9c

Browse files
authored
Merge pull request #107 from PaddlePaddle/demo
comments
2 parents d92160b + cf32e2f commit c11db9c

File tree

3 files changed

+20
-46
lines changed

3 files changed

+20
-46
lines changed

pahelix/model_zoo/sd_vae_model.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@
3535

3636

3737
class StateDecoder(nn.Layer):
38-
"""
39-
Description:
40-
encoder
38+
"""encoder
4139
4240
Args:
4341
max_len: the maximun length of input sequemce
@@ -79,9 +77,7 @@ def forward(self, z, n_steps=None):
7977

8078

8179
class PerpCalculator(nn.Layer):
82-
"""
83-
Description:
84-
loss type
80+
"""loss type
8581
8682
Args:
8783
true_binary: one-hot, with size=time_steps x bsize x DECISION_DIM
@@ -93,7 +89,7 @@ def __init__(self):
9389

9490
def forward(self, true_binary, rule_masks, raw_logits):
9591
"""
96-
tbd
92+
forward
9793
"""
9894
if cmd_args.loss_type == 'binary':
9995
exp_pred = paddle.exp(raw_logits) * rule_masks
@@ -124,9 +120,7 @@ def forward(self, true_binary, rule_masks, raw_logits):
124120

125121

126122
class MyPerpLoss(nn.Layer):
127-
"""
128-
Description:
129-
perplexity loss
123+
"""perplexity loss
130124
"""
131125
def __init__(self):
132126
super(MyPerpLoss, self).__init__()
@@ -154,9 +148,7 @@ def forward(self, true_binary, rule_masks, input_logits):
154148

155149

156150
class CNNEncoder(nn.Layer):
157-
"""
158-
Description:
159-
the encoder
151+
"""the encoder
160152
161153
Args:
162154
max_len: the maximum length of input
@@ -215,9 +207,7 @@ def get_encoder(model_config):
215207

216208

217209
class MolVAE(nn.Layer):
218-
"""
219-
Description:
220-
The Mol VAE model
210+
"""The Mol VAE model
221211
222212
Args:
223213
model_config: the model parameters

pahelix/model_zoo/seq_vae_model.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828

2929

3030
class VAE(nn.Layer):
31-
"""
32-
Description:
33-
The sequence VAE model
31+
"""The sequence VAE model
3432
3533
Args:
3634
vocab: the vocab object.
@@ -148,15 +146,13 @@ def forward_decoder(self, x, z):
148146
return recon_loss
149147

150148
def sample_z_prior(self, n_batch):
151-
"""
152-
Description:
153-
Sampling z ~ p(z) = N(0, I)
149+
"""Sampling z ~ p(z) = N(0, I)
154150
155151
Args:
156152
n_batch: number of batches
157-
158-
Returns:
159-
(n_batch, d_z) of floats, sample of latent z.
153+
154+
Returns:
155+
(n_batch, d_z) of floats, sample of latent z
160156
"""
161157
return paddle.randn([n_batch, self.q_mu.weight.shape[1]])
162158

@@ -170,9 +166,7 @@ def tensor2string(self, tensor):
170166
return string
171167

172168
def sample(self, n_batch, max_len=100, z=None, temp=1.0):
173-
"""
174-
Description:
175-
Generating n_batch samples in eval mode (`z` could be
169+
"""Generating n_batch samples in eval mode (`z` could be
176170
not on same device)
177171
178172
Args:

pahelix/utils/metrics/molecular_generation/metrics_.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def get_all_metrics(gen, k=None, n_jobs=1,
4747
test=None, test_scaffolds=None,
4848
ptest=None, ptest_scaffolds=None,
4949
train=None):
50-
"""
51-
Description:
52-
Computes all available metrics between test (scaffold test)
50+
"""Computes all available metrics between test (scaffold test)
5351
and generated sets of SMILES.
52+
53+
Description:
5454
Available metrics:
5555
* %valid
5656
* %unique@k
@@ -167,9 +167,7 @@ def get_all_metrics(gen, k=None, n_jobs=1,
167167

168168
def compute_intermediate_statistics(smiles, n_jobs=1, device='cpu',
169169
batch_size=512, pool=None):
170-
"""
171-
Description:
172-
The function precomputes statistics such as mean and variance for FCD, etc.
170+
""" The function precomputes statistics such as mean and variance for FCD, etc.
173171
It is useful to compute the statistics for test and scaffold test sets to
174172
speedup metrics calculation.
175173
"""
@@ -198,9 +196,7 @@ def compute_intermediate_statistics(smiles, n_jobs=1, device='cpu',
198196

199197

200198
def fraction_passes_filters(gen, n_jobs=1):
201-
"""
202-
Description:
203-
Computes the fraction of molecules that pass filters:
199+
"""Computes the fraction of molecules that pass filters:
204200
* MCF
205201
* PAINS
206202
* Only allowed atoms ('C','N','S','O','F','Cl','Br','H')
@@ -212,9 +208,7 @@ def fraction_passes_filters(gen, n_jobs=1):
212208

213209
def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
214210
gen_fps=None, p=1):
215-
"""
216-
Description:
217-
Computes internal diversity as:
211+
"""Computes internal diversity as:
218212
1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
219213
"""
220214
if gen_fps is None:
@@ -224,9 +218,7 @@ def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
224218

225219

226220
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
227-
"""
228-
Description:
229-
Computes a number of unique molecules
221+
"""Computes a number of unique molecules
230222
231223
Args:
232224
gen: list of SMILES
@@ -248,9 +240,7 @@ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
248240

249241

250242
def fraction_valid(gen, n_jobs=1):
251-
"""
252-
Description:
253-
Computes a number of valid molecules
243+
"""Computes a number of valid molecules
254244
255245
Args:
256246
gen: list of SMILES

0 commit comments

Comments
 (0)