
Commit 3d6e55c

muellerzr authored and ArthurZucker committed
Fix model kwargs (#35875)
* Save state
* Make a failing test
* Better test
* mpt -> done, many more to go
* Rm extranious
* Bamba
* Bert
* big_bird
* biogpt
* bloom
* codegen
* ctrl
* data2vec
* dbrx
* Through up to Dbrx
* electra
* ernie
* falcon
* Fuyu/persimmon
* Include noop kwargs to base models
* Rebase
* Skip musigen
* Refactor/skip mllama
* Revert makefile
* Rm file
* Fix PT failing, need to modify rest of loss funcs to not resize
* Propagate some
* Continue
* More
* More options
* Mostly fixed
* Proved that it's the same
* Bloom is good
* Make ability to override loss func possible
* Fixup
* Clean
* Fix xglm
* Quality tests
* Skip OCR2
* Make specific loss for xglm
* Make order the same/line up 1:1
* xglm
* Skip fx output loss bloom model
* Didn't pass in pad_token_id
* Fix quality
1 parent 093bebc commit 3d6e55c


48 files changed: +365 −241 lines

src/transformers/modeling_utils.py

Lines changed: 7 additions & 0 deletions
@@ -5142,6 +5142,9 @@ def tplize(mod: torch.nn.Module) -> None:
 
     @property
     def loss_function(self):
+        if hasattr(self, "_loss_function"):
+            return self._loss_function
+
         loss_type = getattr(self, "loss_type", None)
 
         if loss_type is None or loss_type not in LOSS_MAPPING:
@@ -5152,6 +5155,10 @@ def loss_function(self):
             loss_type = "ForCausalLM"
         return LOSS_MAPPING[loss_type]
 
+    @loss_function.setter
+    def loss_function(self, value):
+        self._loss_function = value
+
     def get_compiled_call(self, compile_config: CompileConfig):
         """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
         non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
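Note: the property and setter above let a user replace a model's loss function at runtime; the setter stores the callable in `_loss_function` and the property returns it instead of looking it up in `LOSS_MAPPING`. A minimal sketch of that usage, assuming a generic causal LM checkpoint (the `"gpt2"` name and the `smoothed_ce_loss` helper are illustrative, not part of this commit); the callable just needs to accept `(logits, labels, vocab_size=..., **kwargs)`, matching how the models below now call `self.loss_function`:

```python
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

def smoothed_ce_loss(logits, labels, vocab_size=None, num_items_in_batch=None, **kwargs):
    # Next-token prediction: shift logits/labels by one and flatten,
    # mirroring the per-model code being removed in the hunks below.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, label_smoothing=0.1)

# Possible only after this commit: the new setter stores the override in `_loss_function`.
model.loss_function = smoothed_ce_loss
```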

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 1 addition & 0 deletions
@@ -1200,6 +1200,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

src/transformers/models/bamba/modular_bamba.py

Lines changed: 1 addition & 0 deletions
@@ -947,6 +947,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

src/transformers/models/bert_generation/modeling_bert_generation.py

Lines changed: 9 additions & 6 deletions
@@ -20,7 +20,6 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
@@ -734,6 +733,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -901,6 +901,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -963,18 +964,20 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            **kwargs,
         )
 
         sequence_output = outputs[0]
         prediction_scores = self.lm_head(sequence_output)
 
         lm_loss = None
         if labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[1:]

src/transformers/models/big_bird/modeling_big_bird.py

Lines changed: 9 additions & 5 deletions
@@ -1983,6 +1983,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple[torch.FloatTensor]]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -2540,6 +2541,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.FloatTensor]]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -2580,18 +2582,20 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            **kwargs,
         )
 
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
 
         lm_loss = None
         if labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]

src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 8 additions & 5 deletions
@@ -588,6 +588,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -757,6 +758,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -783,11 +785,12 @@ def forward(
 
         lm_loss = None
         if labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[1:]

src/transformers/models/bloom/modeling_bloom.py

Lines changed: 7 additions & 7 deletions
@@ -958,6 +958,8 @@ def forward(
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
+        # Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
+        num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
         if deprecated_arguments.pop("position_ids", False) is not False:
             # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
             warnings.warn(
@@ -990,14 +992,12 @@ def forward(
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            batch_size, seq_length, vocab_size = shift_logits.shape
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(
-                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                num_items_in_batch=num_items_in_batch,
             )
 
         if not return_dict:
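Note: `num_items_in_batch` is passed through so the loss can be normalized by the total number of label tokens across gradient-accumulation micro-batches rather than averaged per micro-batch. A rough sketch of that normalization, under the assumption that it mirrors (but is not) the library's `ForCausalLMLoss`:

```python
import torch.nn.functional as F

def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100):
    # Next-token prediction: shift logits/labels by one and flatten.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    if num_items_in_batch is None:
        # No global token count available: fall back to a per-batch mean.
        return F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)
    # Sum per-token losses and divide by the total label count across all
    # micro-batches, so accumulated gradients match a single large batch.
    total = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction="sum")
    return total / num_items_in_batch
```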

src/transformers/models/camembert/modeling_camembert.py

Lines changed: 7 additions & 5 deletions
@@ -1584,6 +1584,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1655,11 +1656,12 @@ def forward(
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(prediction_scores.device)
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]

src/transformers/models/codegen/modeling_codegen.py

Lines changed: 8 additions & 6 deletions
@@ -19,7 +19,6 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -450,6 +449,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -741,6 +741,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -775,12 +776,13 @@ def forward(
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
             loss = loss.to(hidden_states.dtype)
 
src/transformers/models/ctrl/modeling_ctrl.py

Lines changed: 8 additions & 6 deletions
@@ -360,6 +360,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
         r"""
         Returns:
@@ -537,6 +538,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -593,12 +595,12 @@ def forward(
 
         loss = None
         if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]

src/transformers/models/data2vec/modeling_data2vec_text.py

Lines changed: 7 additions & 7 deletions
@@ -906,6 +906,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -975,13 +976,12 @@ def forward(
 
         lm_loss = None
         if labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-
-            labels = labels.to(shifted_prediction_scores.device)
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]

src/transformers/models/dbrx/modeling_dbrx.py

Lines changed: 8 additions & 10 deletions
@@ -977,6 +977,7 @@ def forward(
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,  # NOOP kwargs, for now
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1274,6 +1275,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
+        **kwargs,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         r"""Forward function for causal language modeling.
 
@@ -1337,16 +1339,12 @@ def forward(
 
         loss = None
         if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = nn.CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         aux_loss = None
         if output_router_logits:

src/transformers/models/electra/modeling_electra.py

Lines changed: 7 additions & 5 deletions
@@ -1564,6 +1564,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1633,11 +1634,12 @@ def forward(
 
         lm_loss = None
         if labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[1:]
