Skip to content

Commit 9e7a1f4

Browse files
authored
remove usused mem in latex_ocr head (PaddlePaddle#14803)
* adapt to npu * add device judgement
1 parent 78ec762 commit 9e7a1f4

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

ppocr/modeling/heads/rec_latexocr_head.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,10 @@ def forward(
786786
x, mask=mask, mems=mems, return_hiddens=True, seq_len=seq_len, **kwargs
787787
)
788788
x = self.norm(x)
789-
mem, x = x[:, :num_mem], x[:, num_mem:]
789+
if paddle.device.get_device().startswith("npu"):
790+
x = x[:, num_mem:]
791+
else:
792+
mem, x = x[:, :num_mem], x[:, num_mem:]
790793
out = self.to_logits(x) if not return_embeddings else x
791794
if return_mems:
792795
hiddens = intermediates.hiddens

0 commit comments

Comments
 (0)