@@ -129,6 +129,7 @@ def _compute_dynamic_ntk_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     if config is not None and len(rope_kwargs) > 0:
         raise ValueError(
             "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
@@ -249,6 +250,7 @@ def _compute_longrope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     # No need to keep BC with longrope, unreleased when this new pattern was created.
     if len(rope_kwargs) > 0:
         raise ValueError(
@@ -293,6 +295,50 @@ def _compute_longrope_parameters(
     return inv_freq, attention_factor


+def _compute_llama3_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in inv_freq:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
+    return inv_freq, attention_factor
+
+
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
 # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
 # parameterizations, as long as the callable has the same signature.
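
> For readers skimming the hunk above: the piecewise logic scales each inverse frequency according to its wavelength, leaving the high-frequency band untouched, dividing the low-frequency band by `factor`, and interpolating in between. Below is a minimal standalone sketch of that same three-band rule, illustrative only, using a toy `head_dim = 8` and the default Llama 3.1 values cited in the inline comments:
>
> ```python
> import math
>
> import torch
>
> # Default RoPE inverse frequencies for a toy head_dim of 8 and base 10000
> head_dim = 8
> inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
>
> factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
>
> low_freq_wavelen = old_context_len / low_freq_factor    # 8192: longer wavelengths are fully rescaled
> high_freq_wavelen = old_context_len / high_freq_factor  # 2048: shorter wavelengths are kept as-is
>
> for freq in inv_freq:
>     wavelen = 2 * math.pi / freq
>     if wavelen < high_freq_wavelen:      # high-frequency band: unchanged
>         scaled = freq
>     elif wavelen > low_freq_wavelen:     # low-frequency band: divide by `factor`
>         scaled = freq / factor
>     else:                                # mid band: linear interpolation between the two regimes
>         smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
>         scaled = (1 - smooth) * freq / factor + smooth * freq
>     print(f"wavelen={wavelen:10.1f}  freq={freq:.6f} -> scaled={scaled:.6f}")
> ```
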
@@ -302,6 +348,7 @@ def _compute_longrope_parameters(
     "dynamic": _compute_dynamic_ntk_parameters,
     "yarn": _compute_yarn_parameters,
     "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
 }

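
> As the comment preceding this dict notes, `ROPE_INIT_FUNCTIONS` can be extended with custom parameterizations. A hedged sketch of what registering one could look like, assuming the dict is importable from `transformers.modeling_rope_utils` and that the config carries `rope_theta`, `hidden_size`, and `num_attention_heads` (as `LlamaConfig` does); `_compute_my_rope_parameters` and `"my_rope"` are hypothetical names, not part of this PR:
>
> ```python
> from typing import Optional, Tuple
>
> import torch
> from transformers import PretrainedConfig
> from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
>
>
> def _compute_my_rope_parameters(
>     config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
> ) -> Tuple["torch.Tensor", float]:
>     # Same signature as the built-in init functions; trivial example that just
>     # reproduces the default frequencies from the config's base and head_dim.
>     base = config.rope_theta
>     head_dim = config.hidden_size // config.num_attention_heads
>     inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float().to(device) / head_dim))
>     return inv_freq, 1.0  # attention scaling factor of 1.0 (no cos/sin post-processing)
>
>
> # Dispatched whenever a config sets rope_scaling["rope_type"] == "my_rope"
> ROPE_INIT_FUNCTIONS["my_rope"] = _compute_my_rope_parameters
> ```
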
@@ -339,6 +386,20 @@ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
         raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"original_max_position_embeddings"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
 def _validate_yarn_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
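
> For reference, a `rope_scaling` dict the new dynamic validator accepts; note that `original_max_position_embeddings` is optional here (per `optional_keys`), unlike in the llama3 validator added further down:
>
> ```python
> rope_scaling = {"rope_type": "dynamic", "factor": 2.0, "original_max_position_embeddings": 4096}
> ```
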
@@ -374,7 +435,8 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
     required_keys = {"rope_type", "short_factor", "long_factor"}
-    optional_keys = {"attention_factor", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys, optional_keys)

@@ -417,13 +479,50 @@ def _validate_longrope_parameters(config: PretrainedConfig):
         )


+def _validate_llama3_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    if low_freq_factor is None or not isinstance(low_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+    if high_freq_factor is None or not isinstance(high_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+    if high_freq_factor < low_freq_factor:
+        raise ValueError(
+            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
+            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
+        )
+
+    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
+            f"{original_max_position_embeddings}"
+        )
+    if original_max_position_embeddings >= config.max_position_embeddings:
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
+            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
+        )
+
+
 # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
 ROPE_VALIDATION_FUNCTIONS = {
     "default": _validate_default_rope_parameters,
     "linear": _validate_linear_scaling_rope_parameters,
-    "dynamic": _validate_linear_scaling_rope_parameters,  # `dynamic` has the same validation pattern as `linear`
+    "dynamic": _validate_dynamic_scaling_rope_parameters,
     "yarn": _validate_yarn_parameters,
     "longrope": _validate_longrope_parameters,
+    "llama3": _validate_llama3_parameters,
 }

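
> A sketch of a config that satisfies `_validate_llama3_parameters` as written above: all required keys present, floats where floats are checked, and `original_max_position_embeddings` strictly below `max_position_embeddings`. The values mirror the Llama 3.1 defaults cited in the compute function; using `LlamaConfig` here is an assumption about the consuming model, and validation runs only on a transformers version that includes this PR:
>
> ```python
> from transformers import LlamaConfig
>
> config = LlamaConfig(
>     max_position_embeddings=131072,  # must be > original_max_position_embeddings
>     rope_scaling={
>         "rope_type": "llama3",
>         "factor": 8.0,                             # float >= 1
>         "low_freq_factor": 1.0,                    # float
>         "high_freq_factor": 4.0,                   # float, must be >= low_freq_factor
>         "original_max_position_embeddings": 8192,  # int, < max_position_embeddings
>     },
> )
> ```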