
Commit 7a8567d

Rectify code of the LayoutLM series models and adjust it to amp_level mode.

1 parent: c698646

17 files changed: +69 -119 lines

configs/kie/layoutlmv3/README.md

Lines changed: 1 addition & 2 deletions
@@ -144,7 +144,7 @@ Apart from the dataset setting, please also check the following important args:
 system:
   mode:
   distribute: False # `True` for distributed training, `False` for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -157,7 +157,6 @@ model:
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 ...
 train:
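
This commit replaces the per-model `use_float16` switch with the system-level `amp_level` setting: `O0` keeps the network in float32, while `O3` casts it to float16 under MindSpore's automatic mixed precision. A minimal sketch of how such a setting is typically applied, assuming a generic stand-in network rather than mindocr's actual trainer:

```python
# Hedged sketch: applying an `amp_level` value with MindSpore's amp API.
# `network` is a placeholder cell, not the real SER/RE model.
from mindspore import amp, nn

network = nn.Dense(16, 7)  # stand-in for the LayoutLM-series model

amp_level = "O3"  # "O0" = pure float32, "O3" = cast the network to float16
network = amp.auto_mixed_precision(network, amp_level=amp_level)
```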

configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: "O0"
+  amp_level: "O3"
   seed: 42
   log_interval: 10
   val_start_epoch: 50
@@ -17,7 +17,6 @@ model:
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 
 postprocess:
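
With this config updated, training would be launched as usual, e.g. `python tools/train.py --config configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml` (assuming the repo's standard training entry point); no `use_float16` flag is needed anymore.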

configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -16,11 +16,9 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: True
-    use_float16: True
   head:
     name: RelationExtractionHead
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 
 postprocess:

configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -15,12 +15,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: True
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 
 postprocess:

configs/kie/vi_layoutxlm/README.md

Lines changed: 1 addition & 3 deletions
@@ -159,7 +159,7 @@ Apart from the dataset setting, please also check the following important args:
 system:
   mode:
   distribute: False # `True` for distributed training, `False` for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -171,12 +171,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 ...
 train:

configs/kie/vi_layoutxlm/README_CN.md

Lines changed: 1 addition & 3 deletions
@@ -156,7 +156,7 @@ eval:
 system:
   mode:
   distribute: False # True for distributed training, False for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -168,12 +168,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 ...
 train:

configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: "O0"
+  amp_level: "O3"
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -16,11 +16,9 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head:
     name: RelationExtractionHead
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 
 postprocess:

configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -15,12 +15,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 
 postprocess:

mindocr/models/backbones/layoutlmv3/configuration.py

Lines changed: 1 addition & 2 deletions
@@ -3,9 +3,8 @@
 
 @dataclass
 class LayoutLMv3PretrainedConfig:
-    def __init__(self, use_float16=False):
+    def __init__(self):
         pretrained_config = {
-            "use_float16": use_float16,
             "fast_qkv": False,
             "vocab_size": 250002,
             "hidden_size": 768,

mindocr/models/backbones/layoutlmv3/layoutlmv3.py

Lines changed: 12 additions & 9 deletions
@@ -175,7 +175,7 @@ def construct(
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
-            attention_scores = attention_scores + attention_mask.astype(self.dense_dtype)
+            attention_scores = attention_scores + attention_mask.astype(attention_scores.dtype)
 
         # Normalize the attention scores to probabilities.
         # Use the trick of the CogView paper to stablize training
@@ -227,11 +227,8 @@ def __init__(self, config):
         self.has_relative_attention_bias = config.has_relative_attention_bias
         self.has_spatial_attention_bias = config.has_spatial_attention_bias
         self.patch_size = config.patch_size
-        self.use_float16 = config.use_float16
-        self.dense_dtype = mstype.float32
-        if self.use_float16 is True:
-            self.dense_dtype = mstype.float16
-        self.min = finfo(self.dense_dtype)
+        self.float32_min = finfo(mstype.float32)
+        self.float16_min = finfo(mstype.float16)
         self.out_channels = 1
         self.use_visual_backbone = True
 
@@ -342,7 +339,13 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape, dtype
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely. # fp16 compatibility
         extended_attention_mask = extended_attention_mask.astype(dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * self.min
+
+        if dtype == mstype.float32:
+            minimum = self.float32_min
+        elif dtype == mstype.float16:
+            minimum = self.float16_min
+
+        extended_attention_mask = (1.0 - extended_attention_mask) * minimum
         return extended_attention_mask
 
     def get_head_mask(self, head_mask, num_hidden_layers: int, is_attention_chunked: bool = False):
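
The extended mask maps keep-positions (1) to 0 and padding (0) to the dtype minimum, so padded positions receive a large negative bias before softmax and end up with near-zero probability. A toy numpy sketch of the same arithmetic:

```python
# Toy numpy sketch of the extended-attention-mask arithmetic.
import numpy as np

mask = np.array([1, 1, 0, 0], dtype=np.float16)  # 1 = attend, 0 = padding
minimum = np.finfo(np.float16).min               # -65504.0
extended = (1.0 - mask) * minimum                # [0, 0, -65504, -65504]

scores = np.zeros(4, dtype=np.float16) + extended
probs = np.exp(scores) / np.exp(scores).sum()
print(probs)  # [0.5, 0.5, 0.0, 0.0] -- padded positions are masked out
```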
@@ -518,7 +521,7 @@ def construct(
 
 
 @register_backbone
-def layoutlmv3(use_float16: bool = True, **kwargs):
-    pretrained_config = LayoutLMv3PretrainedConfig(use_float16)
+def layoutlmv3(**kwargs):
+    pretrained_config = LayoutLMv3PretrainedConfig()
     model = LayoutLMv3Model(pretrained_config)
     return model
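
With `use_float16` gone from the backbone factory, `layoutlmv3()` always builds a float32 network, and any half-precision cast comes from the system-level `amp_level`. A hedged usage sketch (the import path is assumed from the file paths in this commit):

```python
# Hedged sketch: build the backbone in float32, then let amp set precision.
from mindspore import amp
from mindocr.models.backbones.layoutlmv3.layoutlmv3 import layoutlmv3

model = layoutlmv3()  # previously: layoutlmv3(use_float16=True)
model = amp.auto_mixed_precision(model, amp_level="O3")  # matches the configs
```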
