Commit 26b179c

Llama 3.1 conversion

1 parent 7d92009 commit 26b179c
File tree

3 files changed: +228 −59 lines changed

src/transformers/modeling_rope_utils.py

Lines changed: 101 additions & 2 deletions
@@ -129,6 +129,7 @@ def _compute_dynamic_ntk_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     if config is not None and len(rope_kwargs) > 0:
         raise ValueError(
             "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
@@ -249,6 +250,7 @@ def _compute_longrope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     # No need to keep BC with longrope, unreleased when this new pattern was created.
     if len(rope_kwargs) > 0:
         raise ValueError(
@@ -293,6 +295,50 @@ def _compute_longrope_parameters(
     return inv_freq, attention_factor


+def _compute_llama3_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in inv_freq:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
+    return inv_freq, attention_factor
+
+
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
 # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
 # parameterizations, as long as the callable has the same signature.
@@ -302,6 +348,7 @@ def _compute_longrope_parameters(
     "dynamic": _compute_dynamic_ntk_parameters,
     "yarn": _compute_yarn_parameters,
     "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
 }


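Illustrative aside, not part of the diff: the per-frequency loop in `_compute_llama3_parameters` splits the default RoPE inverse frequencies into three wavelength bands — short wavelengths are kept as-is, long wavelengths are divided by `factor`, and the band in between is smoothly interpolated. The sketch below reproduces the same three-band logic in vectorized form; `head_dim`, `rope_theta` and the scaling constants are illustrative assumptions taken from the code comments, not read from a real config.

import math

import torch

# Assumed Llama-3.1-style constants (illustrative only)
head_dim, rope_theta = 128, 500000.0
factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192

# Default RoPE inverse frequencies, roughly what `_compute_default_rope_parameters` returns
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
wavelen = 2 * math.pi / inv_freq

low_freq_wavelen = old_context_len / low_freq_factor    # 8192.0: longer wavelengths are fully scaled
high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0: shorter wavelengths are left untouched

smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
scaled = torch.where(wavelen < high_freq_wavelen, inv_freq, inv_freq / factor)
new_inv_freq = torch.where(
    (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
    (1 - smooth) * inv_freq / factor + smooth * inv_freq,  # interpolate inside the middle band
    scaled,
)
assert new_inv_freq.shape == inv_freq.shape  # one value per rotary dimension pair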
@@ -339,6 +386,20 @@ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
         raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"original_max_position_embeddings"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
 def _validate_yarn_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
@@ -374,7 +435,8 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
    required_keys = {"rope_type", "short_factor", "long_factor"}
-    optional_keys = {"attention_factor", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys, optional_keys)

@@ -417,13 +479,50 @@ def _validate_longrope_parameters(config: PretrainedConfig):
         )


+def _validate_llama3_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    if low_freq_factor is None or not isinstance(low_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+    if high_freq_factor is None or not isinstance(high_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+    if high_freq_factor < low_freq_factor:
+        raise ValueError(
+            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
+            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
+        )
+
+    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
+            f"{original_max_position_embeddings}"
+        )
+    if original_max_position_embeddings >= config.max_position_embeddings:
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
+            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
+        )
+
+
 # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
 ROPE_VALIDATION_FUNCTIONS = {
     "default": _validate_default_rope_parameters,
     "linear": _validate_linear_scaling_rope_parameters,
-    "dynamic": _validate_linear_scaling_rope_parameters,  # `dynamic` has the same validation pattern as `linear`
+    "dynamic": _validate_dynamic_scaling_rope_parameters,
     "yarn": _validate_yarn_parameters,
     "longrope": _validate_longrope_parameters,
+    "llama3": _validate_llama3_parameters,
 }

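Illustrative aside, not part of the diff: with the new validator, a `llama3` entry must carry exactly the five required keys. For orientation, a `rope_scaling` payload that would pass `_validate_llama3_parameters`, using the values quoted in the code comments and assuming the surrounding config's `max_position_embeddings` is larger than 8192 (e.g. 131072):

# Illustrative payload; extra keys are rejected by `_check_received_keys`, and
# `original_max_position_embeddings` must stay below `config.max_position_embeddings`.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,                             # must be a float >= 1
    "low_freq_factor": 1.0,                    # must be a float
    "high_freq_factor": 4.0,                   # must not be smaller than low_freq_factor
    "original_max_position_embeddings": 8192,  # must be an int < config.max_position_embeddings
}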
src/transformers/models/llama/configuration_llama.py

Lines changed: 19 additions & 12 deletions
@@ -73,25 +73,28 @@ class LlamaConfig(PretrainedConfig):
             End of stream token id.
         pretraining_tp (`int`, *optional*, defaults to 1):
             Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
-            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
-            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
-            issue](https://github.com/pytorch/pytorch/issues/76232).
+            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
-            `max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
-            to determine which scaling to apply.
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
             Expected contents:
                 `rope_type` (`str`):
-                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
-                    with 'default' being the original RoPE implementation.
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
                 `factor` (`float`, *optional*):
                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                     most scaling types, a `factor` of x will enable the model to handle sequences of length x *
-                    `max_position_embeddings`.
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
                 `attention_factor` (`float`, *optional*):
                     Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                     computation. If unspecified, it defaults to value recommended by the implementation, using the
@@ -104,12 +107,16 @@ class LlamaConfig(PretrainedConfig):
                     ramp function. If unspecified, it defaults to 1.
                 `short_factor` (`List[float]`, *optional*):
                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                     size divided by the number of attention heads divided by 2
                 `long_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                     size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
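Illustrative aside, not part of the diff: a sketch of how the documented `rope_scaling` fields fit together in practice. The model sizes below are arbitrary placeholders rather than a real checkpoint; only the `rope_scaling` structure follows the docstring above, and the dict is expected to be validated when the config is instantiated.

from transformers import LlamaConfig

config = LlamaConfig(
    hidden_size=2048,                # placeholder sizes, not a real checkpoint
    num_attention_heads=16,
    num_hidden_layers=4,
    max_position_embeddings=131072,  # extended context the scaled RoPE should cover
    rope_theta=500000.0,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,  # pre-training context length
    },
)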
