import os
import time
from pathlib import Path
+
current_dir = Path(__file__).parent.absolute()

import torch
import pytest

from einops import rearrange

- from transformers import LlamaConfig, LlamaTokenizer
+ from transformers import LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
- from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config
+ from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config, remap_state_dict_hf_llama
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache


+ def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
+     if checkpoint_format == "meta":
+         ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
+         pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
+         pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+     else:
+         pretrained_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
+         pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
+     return pretrained_state_dict
+
+
@pytest.mark.parametrize('model_name', ["7B"])
def test_llama_state_dict(model_name):
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
@@ -41,8 +53,8 @@ def test_llama_state_dict(model_name):


@pytest.mark.parametrize('model_name', ["7B", "13B"])
- # @pytest.mark.parametrize('model_name', ["7B"])
- def test_llama_optimized(model_name):
+ @pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+ def test_llama_optimized(model_name, checkpoint_format):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -52,16 +64,17 @@ def test_llama_optimized(model_name):

    dtype = torch.float16
    device = 'cuda'
-     config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+     config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+     config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

-     ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-     pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-     pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+     pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+         checkpoint_path, model_name, config, checkpoint_format
+     )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()
@@ -111,7 +124,8 @@ def test_llama_optimized(model_name):
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
- def test_llama_parallel(model_name, world_size):
+ @pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+ def test_llama_parallel(model_name, world_size, checkpoint_format):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -122,7 +136,8 @@ def test_llama_parallel(model_name, world_size):
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
-     config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+     config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+     config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -137,10 +152,9 @@ def test_llama_parallel(model_name, world_size):
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

-     ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-     pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-     pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
-
+     pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+         checkpoint_path, model_name, config, checkpoint_format
+     )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()
@@ -196,13 +210,15 @@ def test_llama_parallel(model_name, world_size):

# @pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize('model_name', ["7B"])
- def test_llama_generation(model_name):
+ @pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+ def test_llama_generation(model_name, checkpoint_format):
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
-     config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+     config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+     config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -239,9 +255,10 @@ def test_llama_generation(model_name):
    logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
    del model_ref

-     ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-     pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-     pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+
+     pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+         checkpoint_path, model_name, config, checkpoint_format
+     )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()
@@ -291,7 +308,8 @@ def test_llama_generation(model_name):
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
- def test_llama_parallel_generation(model_name, world_size):
+ @pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
@@ -302,7 +320,8 @@ def test_llama_parallel_generation(model_name, world_size):
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
-     config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+     config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+     config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -331,10 +350,9 @@ def test_llama_parallel_generation(model_name, world_size):
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

-     ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-     pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-     pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
-
+     pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+         checkpoint_path, model_name, config, checkpoint_format
+     )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()
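
For reference, a minimal standalone sketch of the "hf" loading path that the new checkpoint_format parametrization exercises. It mirrors the else-branch of _pretrained_state_dict_from_checkpoint introduced above; only calls that appear in this diff are used, and the checkpoint directory layout (CHECKPOINT_DIR/llama with a {model_name}-hf subdirectory for the converted Hugging Face checkpoint) plus the "7B" model name are assumptions for illustration, not something the diff itself enforces.

# Sketch only, not part of the diff: mirrors the "hf" branch of the helper above.
import os
from pathlib import Path

import torch

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import (config_from_checkpoint, llama_config_to_gpt2_config,
                                     remap_state_dict_hf_llama)
from flash_attn.utils.pretrained import state_dict_from_pretrained

checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', 'checkpoints')) / 'llama'  # assumed layout
model_name = '7B'  # assumed model size

# Build the GPT-2-style config from the checkpoint's own config, as the tests now do.
config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name, 'hf'))

# Load the HF-format weights and remap them to flash-attn's GPT naming scheme.
state_dict = state_dict_from_pretrained(checkpoint_path / f'{model_name}-hf')
state_dict = remap_state_dict_hf_llama(state_dict, config)

model = GPTLMHeadModel(config, device='cuda', dtype=torch.float16)
model.load_state_dict(state_dict)
model.eval()

Either checkpoint format can then be exercised selectively via pytest's -k filter, e.g. pytest -q -s tests/models/test_llama.py -k "llama_optimized and hf".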