|
16 | 16 |
|
17 | 17 | class PaddleWarpctc(PaddleDynamicAPIBenchmarkBase):
|
18 | 18 | def build_graph(self, config):
|
19 |
| - logits_length = self.variable(name='logits_length', shape=config.logits_length_shape, |
20 |
| - dtype=config.logits_length_dtype, |
| 19 | + log_probs = self.variable(name='log_probs', shape=config.log_probs_shape, dtype=config.log_probs_dtype) |
| 20 | + labels = self.variable(name='labels', shape=config.labels_shape, dtype=config.labels_dtype) |
| 21 | + input_lengths = self.variable(name='input_lengths', shape=config.input_lengths_shape, |
| 22 | + dtype=config.input_lengths_dtype, |
21 | 23 | value=np.array([config.max_seq_length]*config.batch_size).astype("int64"))
|
22 |
| - label_length = self.variable(name='label_length', shape=config.label_length_shape, |
23 |
| - dtype=config.label_length_dtype, |
| 24 | + label_lengths = self.variable(name='label_lengths', shape=config.label_lengths_shape, |
| 25 | + dtype=config.label_lengths_dtype, |
24 | 26 | value=np.array([config.max_label_length]*config.batch_size).astype("int64"))
|
25 |
| - logits = self.variable(name='logits', shape=config.logits_shape, dtype=config.logits_dtype) |
26 |
| - label = self.variable(name='label', shape=config.label_shape, dtype=config.label_dtype) |
27 |
| - result = paddle.fluid.layers.warpctc(input=logits, label=label, input_length=logits_length, |
28 |
| - label_length=label_length, blank=config.blank, norm_by_times=config.norm_by_times) |
29 |
| - self.feed_list = [logits_length, label_length, logits, label] |
| 27 | + |
| 28 | + result = paddle.nn.functional.ctc_loss(log_probs=log_probs, labels=labels, input_lengths=input_lengths, |
| 29 | + label_lengths=label_lengths, blank=config.blank, |
| 30 | + reduction=config.reduction, norm_by_times=config.norm_by_times) |
| 31 | + |
| 32 | + self.feed_list = [log_probs, labels, input_lengths, label_lengths] |
30 | 33 | self.fetch_list = [result]
|
31 | 34 | if config.backward:
|
32 |
| - self.append_gradients(result, [logits]) |
| 35 | + self.append_gradients(result, [log_probs]) |
33 | 36 |
|
34 | 37 | if __name__ == '__main__':
|
35 | 38 | test_main(
|
|
0 commit comments