Skip to content

Commit dfe9d2a

Browse files
authored
Refine warpctc for benchmark (#1412)
* Refine warpctc for benchmark test=document_fix * refine var name * just for test=document_fix
1 parent 791cd50 commit dfe9d2a

File tree

2 files changed

+29
-22
lines changed

2 files changed

+29
-22
lines changed

api/dynamic_tests_v2/warpctc.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,23 @@
1616

1717
class PaddleWarpctc(PaddleDynamicAPIBenchmarkBase):
1818
def build_graph(self, config):
19-
logits_length = self.variable(name='logits_length', shape=config.logits_length_shape,
20-
dtype=config.logits_length_dtype,
19+
log_probs = self.variable(name='log_probs', shape=config.log_probs_shape, dtype=config.log_probs_dtype)
20+
labels = self.variable(name='labels', shape=config.labels_shape, dtype=config.labels_dtype)
21+
input_lengths = self.variable(name='input_lengths', shape=config.input_lengths_shape,
22+
dtype=config.input_lengths_dtype,
2123
value=np.array([config.max_seq_length]*config.batch_size).astype("int64"))
22-
label_length = self.variable(name='label_length', shape=config.label_length_shape,
23-
dtype=config.label_length_dtype,
24+
label_lengths = self.variable(name='label_lengths', shape=config.label_lengths_shape,
25+
dtype=config.label_lengths_dtype,
2426
value=np.array([config.max_label_length]*config.batch_size).astype("int64"))
25-
logits = self.variable(name='logits', shape=config.logits_shape, dtype=config.logits_dtype)
26-
label = self.variable(name='label', shape=config.label_shape, dtype=config.label_dtype)
27-
result = paddle.fluid.layers.warpctc(input=logits, label=label, input_length=logits_length,
28-
label_length=label_length, blank=config.blank, norm_by_times=config.norm_by_times)
29-
self.feed_list = [logits_length, label_length, logits, label]
27+
28+
result = paddle.nn.functional.ctc_loss(log_probs=log_probs, labels=labels, input_lengths=input_lengths,
29+
label_lengths=label_lengths, blank=config.blank,
30+
reduction=config.reduction, norm_by_times=config.norm_by_times)
31+
32+
self.feed_list = [log_probs, labels, input_lengths, label_lengths]
3033
self.fetch_list = [result]
3134
if config.backward:
32-
self.append_gradients(result, [logits])
35+
self.append_gradients(result, [log_probs])
3336

3437
if __name__ == '__main__':
3538
test_main(

api/tests_v2/configs/warpctc.json

+16-12
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,34 @@
1313
"type": "int32",
1414
"value": "200"
1515
},
16-
"logits_length": {
17-
"dtype": "int64",
18-
"shape": "[16L]",
19-
"type": "Variable"
20-
},
21-
"label_length": {
22-
"dtype": "int64",
23-
"shape": "[16L]",
24-
"type": "Variable"
25-
},
26-
"logits": {
16+
"log_probs": {
2717
"dtype": "float32",
2818
"shape": "[400L, 16L, 6L]",
2919
"type": "Variable"
3020
},
31-
"label": {
21+
"labels": {
3222
"dtype": "int32",
3323
"shape": "[16L, 200L]",
3424
"type": "Variable"
3525
},
26+
"input_lengths": {
27+
"dtype": "int64",
28+
"shape": "[16L]",
29+
"type": "Variable"
30+
},
31+
"label_lengths": {
32+
"dtype": "int64",
33+
"shape": "[16L]",
34+
"type": "Variable"
35+
},
3636
"blank": {
3737
"type": "int32",
3838
"value": "0"
3939
},
40+
"reduction": {
41+
"type": "string",
42+
"value": "none"
43+
},
4044
"norm_by_times": {
4145
"type": "bool",
4246
"value": "False"

0 commit comments

Comments
 (0)