Skip to content

Commit 21aae87

Browse files
authored
fix paddle zero output for release3.0b2 and develop (#1244)
1 parent cb36b86 commit 21aae87

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

paddlemix/models/llava/llava_arch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def prepare_inputs_labels_for_multimodal(
344344
continue
345345
image_token_indices = (
346346
[-1]
347-
+ paddle.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
347+
+ paddle.squeeze(paddle.where(cur_input_ids == IMAGE_TOKEN_INDEX)).unsqueeze(-1).tolist()
348348
+ [cur_input_ids.shape[0]]
349349
)
350350

0 commit comments

Comments
 (0)