Skip to content

Commit 1ccf688

Browse files
fix static train in formula (#14826)
1 parent 28657d4 commit 1ccf688

File tree

4 files changed

+19
-3
lines changed

4 files changed

+19
-3
lines changed

configs/rec/PP-FormuaNet/rec_pp_formulanet_l.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ Global:
1818
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
1919
max_new_tokens: &max_new_tokens 1024
2020
input_size: &input_size [768, 768]
21-
save_res_path: ./output/rec/predicts_unimernet_latexocr.txt
21+
save_res_path: ./output/rec/predicts_pp_formulanet_l.txt
2222
allow_resize_largeImg: False
2323
start_ema: True
24+
d2s_train_image_shape: [1,768,768]
2425

2526
Optimizer:
2627
name: AdamW

configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ Global:
1818
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
1919
max_new_tokens: &max_new_tokens 1024
2020
input_size: &input_size [384, 384]
21-
save_res_path: ./output/rec/predicts_unimernet_latexocr.txt
21+
save_res_path: ./output/rec/predicts_pp_formulanet_s.txt
2222
allow_resize_largeImg: False
2323
start_ema: True
24+
d2s_train_image_shape: [1,384,384]
2425

2526
Optimizer:
2627
name: AdamW

configs/rec/rec_unimernet.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ Global:
1818
rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
1919
input_size: &input_size [192, 672]
2020
max_seq_len: &max_seq_len 1024
21-
save_res_path: ./output/rec/predicts_unimernet_plus_config_latexocr.txt
21+
save_res_path: ./output/rec/predicts_unimernet.txt
2222
allow_resize_largeImg: False
23+
d2s_train_image_shape: [1,192,672]
2324

2425
Optimizer:
2526
name: AdamW

ppocr/modeling/architectures/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def apply_to_static(model, config, logger):
5050
"SVTR",
5151
"SVTR_HGNet",
5252
"LaTeXOCR",
53+
"UniMERNet",
54+
"PP-FormulaNet-S",
55+
"PP-FormulaNet-L",
5356
]
5457
if config["Architecture"]["algorithm"] in ["Distillation"]:
5558
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
@@ -127,6 +130,16 @@ def apply_to_static(model, config, logger):
127130
InputSpec(shape=[None, None], dtype="float32"),
128131
]
129132
]
133+
elif algo in ["UniMERNet", "PP-FormulaNet-S", "PP-FormulaNet-L"]:
134+
specs = [
135+
[
136+
InputSpec(
137+
[None] + config["Global"]["d2s_train_image_shape"], dtype="float32"
138+
),
139+
InputSpec(shape=[None, None], dtype="float32"),
140+
InputSpec(shape=[None, None], dtype="float32"),
141+
]
142+
]
130143
model = to_static(model, input_spec=specs)
131144
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
132145
return model

0 commit comments

Comments
 (0)