Commit d3b8627

[VLMs] split out "get placeholder mask" to helper (#39777)
* batch update all models
* update
* forgot about llava onevision
* update
* fix tests
* delete file
* typo
* fix emu3 once and forever
* update cohere2 vision as well
1 parent a115b67 commit d3b8627
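
Before the per-file diffs: every model below repeated the same inline merge logic, which the new `get_placeholder_mask` helper now owns. The sketch below is a minimal, self-contained rendering of that logic; the token id and tensor shapes are illustrative, not taken from the commit.

import torch

IMAGE_TOKEN_ID = 32000  # illustrative placeholder id, not from any real config

# Toy shapes: batch 1, sequence 6, hidden size 4; one image with 3 patch features.
input_ids = torch.tensor([[5, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 7, 9]])
inputs_embeds = torch.zeros(1, 6, 4)
image_features = torch.randn(1, 3, 4)

# 1) Mark placeholder positions (the helper falls back to comparing embeddings
#    when input_ids is None).
special_image_mask = input_ids == IMAGE_TOKEN_ID
n_image_tokens = special_image_mask.sum()

# 2) Expand the mask over the hidden dimension so it indexes scalars.
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)

# 3) Validate that the placeholder slots hold exactly as many scalars as the
#    image features provide, then splice the features in.
n_image_features = image_features.shape[0] * image_features.shape[1]
if inputs_embeds[special_image_mask].numel() != image_features.numel():
    raise ValueError(f"tokens: {n_image_tokens}, features: {n_image_features}")

inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)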

File tree

52 files changed, +1369 -1068 lines

Note: large commits hide some content by default; only a subset of the 52 changed files is shown below.

src/transformers/models/aria/modeling_aria.py

Lines changed: 27 additions & 17 deletions
@@ -978,6 +978,30 @@ def get_image_features(
         image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
         return image_features

+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder
+        token count is equal to the length of multimodal features. If the lengths are different, an error is raised.
+        """
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        n_image_features = image_features.shape[0] * image_features.shape[1]
+        if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        return special_image_mask
+
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -1007,29 +1031,15 @@ def forward(

         # 2. Merge text and images
         if pixel_values is not None and inputs_embeds.shape[1] != 1:
-            if input_ids is None:
-                special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-                )
-                special_image_mask = special_image_mask.all(-1)
-            else:
-                special_image_mask = input_ids == self.config.image_token_id
-
-            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
                 pixel_mask=pixel_mask,
                 vision_feature_layer=self.config.vision_feature_layer,
             )
-            n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
-            n_image_features = n_images * n_features_per_image
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+            )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(

src/transformers/models/aria/modular_aria.py

Lines changed: 3 additions & 17 deletions
@@ -1431,29 +1431,15 @@ def forward(

         # 2. Merge text and images
         if pixel_values is not None and inputs_embeds.shape[1] != 1:
-            if input_ids is None:
-                special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-                )
-                special_image_mask = special_image_mask.all(-1)
-            else:
-                special_image_mask = input_ids == self.config.image_token_id
-
-            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
                 pixel_mask=pixel_mask,
                 vision_feature_layer=self.config.vision_feature_layer,
             )
-            n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
-            n_image_features = n_images * n_features_per_image
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+            )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(

src/transformers/models/aya_vision/modeling_aya_vision.py

Lines changed: 28 additions & 18 deletions
@@ -32,7 +32,7 @@
 from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.generic import check_model_inputs
 from ..auto import AutoModel
 from .configuration_aya_vision import AyaVisionConfig
@@ -242,6 +242,30 @@ def get_image_features(
         image_features = self.multi_modal_projector(selected_image_feature)
         return image_features

+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder
+        token count is equal to the length of multimodal features. If the lengths are different, an error is raised.
+        """
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        n_image_features = image_features.shape[0] * image_features.shape[1]
+        if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        return special_image_mask
+
     @check_model_inputs
     @auto_docstring
     def forward(
@@ -279,24 +303,10 @@ def forward(
                 vision_feature_layer=vision_feature_layer,
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )
-
-            if input_ids is None:
-                special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-                )
-                special_image_mask = special_image_mask.all(-1)
-            else:
-                special_image_mask = input_ids == self.config.image_token_id
-
-            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-
-            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_features = image_features.shape[0] * image_features.shape[1]
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+            )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(
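
Two details worth noting in the aya_vision hunks (and the analogous ones below): the validation that the removed inline code ran only outside torch.compile (guarded by `is_torchdynamo_compiling()`) now runs unconditionally inside the helper, and the helper keeps a fallback for callers that pass `inputs_embeds` without `input_ids`, recovering placeholder positions by exact embedding comparison. A minimal sketch of that fallback, with a toy embedding table and a hypothetical token id:

import torch

embed = torch.nn.Embedding(10, 4)
IMAGE_TOKEN_ID = 3  # hypothetical id for illustration

input_ids = torch.tensor([[1, 3, 3, 2]])
inputs_embeds = embed(input_ids)  # shape (1, 4, 4)

# A position is a placeholder iff its embedding row equals the image-token
# embedding in every hidden dimension, hence the .all(-1) reduction. Exact
# float equality holds because both sides come from the same embedding table.
image_token_embed = embed(torch.tensor(IMAGE_TOKEN_ID))
mask_from_embeds = (inputs_embeds == image_token_embed).all(-1)

assert torch.equal(mask_from_embeds, input_ids == IMAGE_TOKEN_ID)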

src/transformers/models/aya_vision/modular_aya_vision.py

Lines changed: 4 additions & 18 deletions
@@ -32,7 +32,7 @@
 from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...processing_utils import Unpack
-from ...utils import auto_docstring, is_torchdynamo_compiling, logging
+from ...utils import auto_docstring, logging
 from ...utils.generic import check_model_inputs
 from .configuration_aya_vision import AyaVisionConfig

@@ -200,24 +200,10 @@ def forward(
                 vision_feature_layer=vision_feature_layer,
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )
-
-            if input_ids is None:
-                special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-                )
-                special_image_mask = special_image_mask.all(-1)
-            else:
-                special_image_mask = input_ids == self.config.image_token_id
-
-            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-
-            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_features = image_features.shape[0] * image_features.shape[1]
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+            )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(

src/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 32 additions & 18 deletions
@@ -1455,6 +1455,21 @@ def get_qformer_features(

         return query_outputs

+    def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
+        """
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        return special_image_mask
+
     @auto_docstring
     def forward(
         self,
@@ -1545,16 +1560,8 @@ def forward(
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

-        if input_ids is None:
-            special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-            )
-            special_image_mask = special_image_mask.all(-1)
-        else:
-            special_image_mask = input_ids == self.config.image_token_id
-
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
         language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
         inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
             special_image_mask, language_model_inputs
         )
@@ -1938,6 +1945,21 @@ def get_image_features(
             return language_model_inputs, vision_outputs, query_outputs
         return language_model_inputs

+    def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
+        """
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        return special_image_mask
+
     @auto_docstring
     def forward(
         self,
@@ -2042,16 +2064,8 @@ def forward(
         if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

-        if input_ids is None:
-            special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-            )
-            special_image_mask = special_image_mask.all(-1)
-        else:
-            special_image_mask = input_ids == self.config.image_token_id
-
-        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
         language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
         inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
             special_image_mask, language_model_inputs
         )
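
Note that the BLIP-2 variant of the helper is mask-only: it takes no `image_features` argument and skips the token/feature count validation the other models perform; it just builds and expands the mask before the query-token embeddings (`language_model_inputs`) are scattered in.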

src/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 27 additions & 17 deletions
@@ -35,7 +35,6 @@
     TransformersKwargs,
     auto_docstring,
     can_return_tuple,
-    is_torchdynamo_compiling,
     logging,
 )
 from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
@@ -888,6 +887,30 @@ def get_image_features(self, pixel_values: torch.FloatTensor):
         vision_embeddings = self.get_input_embeddings()(image_tokens)
         return vision_embeddings

+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder
+        token count is equal to the length of multimodal features. If the lengths are different, an error is raised.
+        """
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        n_image_features = image_features.shape[0] * image_features.shape[1]
+        if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        return special_image_mask
+
     @auto_docstring
     def forward(
         self,
@@ -924,23 +947,10 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if pixel_values is not None:
-            if input_ids is None:
-                special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-                )
-                special_image_mask = special_image_mask.all(-1)
-            else:
-                special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-
-            n_image_tokens_in_text = special_image_mask.sum()
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-
             image_embeds = self.get_image_features(pixel_values)
-            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel():
-                n_image_features = image_embeds.shape[0] * image_embeds.shape[1]
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
-                )
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+            )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

         # torch.jit.trace() doesn't support cache objects in the output
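
Chameleon's copy of the helper differs from the others only in where the placeholder id lives: it reads `self.vocabulary_mapping.image_token_id` rather than `self.config.image_token_id`. Each model re-declares the helper with its own token-id lookup, in keeping with the library's one-model-one-file layout.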

src/transformers/models/cohere2_vision/modeling_cohere2_vision.py

Lines changed: 27 additions & 10 deletions
@@ -197,6 +197,30 @@ def get_image_features(
         image_features = self.multi_modal_projector(selected_image_feature)
         return image_features

+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder
+        token count is equal to the length of multimodal features. If the lengths are different, an error is raised.
+        """
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        n_image_features = image_features.shape[0] * image_features.shape[1]
+        if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        return special_image_mask
+
     @check_model_inputs
     @auto_docstring
     def forward(
@@ -225,16 +249,9 @@ def forward(
         if pixel_values is not None:
             image_features = self.get_image_features(pixel_values, image_num_patches=image_num_patches)
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-
-            if input_ids is None:
-                special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-                )
-                special_image_mask = special_image_mask.all(-1)
-            else:
-                special_image_mask = input_ids == self.config.image_token_id
-
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+            )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(
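
Across all the files shown, `forward()` now converges on the same sequence: compute the image features, move them to the embeddings' device and dtype, obtain the placeholder mask from the new helper (with count validation where the model defines it), and `masked_scatter` the features into `inputs_embeds`.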
