Skip to content

Commit f957410

Browse files
authored
Merge pull request #235 from Birdylx/fix_export
Fix ppyoloe_seg export
2 parents 286c104 + 8e4ec69 commit f957410

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

ppdet/engine/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,10 @@ def export(self, output_dir='output_inference'):
11831183
if hasattr(self.cfg, 'export') and 'fuse_conv_bn' in self.cfg[
11841184
'export'] and self.cfg['export']['fuse_conv_bn']:
11851185
self.model = fuse_conv_bn(self.model)
1186+
1187+
# enable export_mode
1188+
for layer in self.model.sublayers(include_self=True):
1189+
layer.in_export_mode = True
11861190

11871191
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
11881192
save_dir = os.path.join(output_dir, model_name)

ppdet/modeling/heads/ppyoloe_ins_head.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,11 @@ def post_process(self,
613613
custom_ceil(mask_logits.shape[-1] / scale_factor[0][1])
614614
],
615615
mode='bilinear',
616-
align_corners=False)[..., :round(ori_h.item()), :round(ori_w.item())] # due to npu numeric error, we need to take round of img size.
617-
# align_corners=False)[..., :int(ori_h), :int(ori_w)] # TODO: only for export
616+
align_corners=False)
617+
if self.in_export_mode:
618+
mask_logits = mask_logits[..., :int(ori_h), :int(ori_w)]
619+
else:
620+
mask_logits = mask_logits[..., :round(ori_h.item()), :round(ori_w.item())] # due to npu numeric error, we need to take round of img size.
618621
masks = mask_logits.squeeze(0)
619622
mask_pred = masks > self.mask_thr_binary
620623

0 commit comments

Comments
 (0)