diff --git a/ppocr/modeling/heads/rec_latexocr_head.py b/ppocr/modeling/heads/rec_latexocr_head.py
index 1b01d9d6dc6..edd27f3bf74 100644
--- a/ppocr/modeling/heads/rec_latexocr_head.py
+++ b/ppocr/modeling/heads/rec_latexocr_head.py
@@ -907,14 +907,7 @@ def generate(
             x = out[:, -self.max_seq_len :]
             mask = mask[:, -self.max_seq_len :]
             logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
-            if filter_logits_fn in {top_k, top_p}:
-                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
-
-                probs = F.softmax(filtered_logits / temperature, axis=-1)
-            else:
-                raise NotImplementedError("The filter_logits_fn is not supported ")
-
-            sample = paddle.multinomial(probs, 1)
+            sample = paddle.argmax(logits, axis=1).reshape([-1, 1])
             out = paddle.concat((out, sample), axis=-1)
             pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
             mask = paddle.concat((mask, pad_mask), axis=1)
@@ -966,12 +959,7 @@ def generate_export(
             logits = self.net(x, mask=mask, context=context, seq_len=i_idx, **kwargs)[
                 :, -1, :
             ]
-            if filter_logits_fn in {top_k, top_p}:
-                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
-
-                probs = F.softmax(filtered_logits / temperature, axis=-1)
-
-                sample = paddle.multinomial(probs, 1)
+            sample = paddle.argmax(logits, axis=1).reshape([-1, 1])
             out = paddle.concat((out, sample), axis=-1)
             pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
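
Both hunks make the same change in `generate` and `generate_export`: the temperature-scaled multinomial draw over top-k/top-p filtered logits is replaced by a deterministic `argmax` over the raw logits. As a minimal standalone sketch of the two decoding steps being swapped (the `top_k` helper and the random `logits` tensor below are illustrative stand-ins, not code from this patch):

```python
import paddle
import paddle.nn.functional as F

paddle.seed(0)
# Stand-in for self.net(x, mask=mask, **kwargs)[:, -1, :]: [batch, vocab] logits.
logits = paddle.randn([2, 8000])

def top_k(logits, thres=0.9):
    # Illustrative top-k filter: keep roughly the top (1 - thres) fraction of
    # the vocabulary and push everything else to -inf before the softmax.
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = paddle.topk(logits, k)
    filtered = paddle.full_like(logits, float("-inf"))
    return paddle.put_along_axis(filtered, ind, val, axis=-1)

# Old path (removed): stochastic, depends on temperature and the filter.
probs = F.softmax(top_k(logits) / 1.0, axis=-1)  # temperature = 1.0
stochastic_sample = paddle.multinomial(probs, 1)  # [batch, 1], differs run to run

# New path (added): greedy, deterministic, no sampling op in the decode step.
greedy_sample = paddle.argmax(logits, axis=1).reshape([-1, 1])  # [batch, 1]
```

Greedy decoding trades the diversity controls (`temperature`, `filter_thres`, top-k/top-p) for reproducible output, and keeping `paddle.multinomial` out of the loop is presumably also friendlier to static-graph export in `generate_export`.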