Commit 0f7853c

enable loading hf llama checkpoints for training (PaddlePaddle#446)
* prelim.
* add hf conversion fn.
* mlp.
* change name.
* fix bug.
* inverse permute.
* change comment.
* revert style changes.
* fix.
* add doc.
* revert.
* enable load safe.
* fix safe load.
* fix import.
* fix typing-related lints.
* fix ckpt loading logic.
* make single gpu work.
* test with parallel.
* ckpt format.
* enable pretrained state dict.
* remove unused imports.
* remove unused.
* mark idea related.
1 parent c60851a commit 0f7853c

File tree

4 files changed: +182 −42 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -19,3 +19,6 @@ var/
 *.egg-info/
 .installed.cfg
 *.egg
+
+# IDE-related
+.idea/

flash_attn/models/llama.py

Lines changed: 100 additions & 6 deletions
@@ -1,15 +1,15 @@
 # Copyright (c) 2023, Tri Dao.
 
-import math
 import json
+import math
+import os
 import re
-from pathlib import Path
-
 from collections import OrderedDict
+from pathlib import Path
+from typing import Union
 
 import torch
 import torch.nn.functional as F
-
 from transformers import GPT2Config, LlamaConfig
 
 
@@ -74,10 +74,91 @@ def key_mapping_attn(key):
                       r'transformer.layers.\1.mixer.out_proj.', key)
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
 
+    state_dict.pop("transformer.rope.freqs", None)
+
     return state_dict
 
 
-def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig:
+def remap_state_dict_hf_llama(state_dict, config):
+    # Embedding
+    def key_mapping_emb(key):
+        return re.sub(r'^model.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
+
+    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
+    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
+    # It's possible that vocab_size is padded to be a multiple of 8, for example.
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
+                  * pad_vocab_size_multiple)
+    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
+        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
+    )
+
+    # LM head
+    if getattr(config, 'tie_word_embeddings'):
+        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
+    else:
+        output_embeddings = state_dict.pop('lm_head.weight')
+        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
+        # differently.
+        vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
+                      * pad_vocab_size_multiple)
+        # It's possible that vocab_size is padded to be a multiple of 8, for example.
+        state_dict['lm_head.weight'] = F.pad(
+            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
+        )
+
+    # MLP
+    for l in range(config.n_layer):
+        # Fusing weights this way based on difference in the following:
+        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
+        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
+        w1 = state_dict.pop(f'model.layers.{l}.mlp.gate_proj.weight')
+        w3 = state_dict.pop(f'model.layers.{l}.mlp.up_proj.weight')
+        state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
+
+    def key_mapping_mlp(key):
+        return re.sub(r'^model.layers.(\d+).mlp.down_proj.',
+                      r'transformer.layers.\1.mlp.fc2.', key)
+
+    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
+
+    # LayerNorm
+    def key_mapping_ln(key):
+        key = re.sub(r'^model.norm.', r'transformer.ln_f.', key)
+        key = re.sub(r'^model.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
+        key = re.sub(r'^model.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
+        return key
+
+    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
+
+    def inv_permute(w):
+        # Inverse of permute implemented in:
+        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
+        return w.reshape(
+            config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd
+        ).transpose(1, 2).reshape(config.n_embd, config.n_embd)
+
+    # Attention
+    for l in range(config.n_layer):
+        Wq = state_dict.pop(f'model.layers.{l}.self_attn.q_proj.weight')
+        Wk = state_dict.pop(f'model.layers.{l}.self_attn.k_proj.weight')
+        Wv = state_dict.pop(f'model.layers.{l}.self_attn.v_proj.weight')
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
+            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
+        )
+        # We don't store these
+        state_dict.pop(f'model.layers.{l}.self_attn.rotary_emb.inv_freq', None)
+
+    def key_mapping_attn(key):
+        return re.sub(r'^model.layers.(\d+).self_attn.o_proj.',
+                      r'transformer.layers.\1.mixer.out_proj.', key)
+
+    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+    return state_dict
+
+
+def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
     """Load a LlamaConfig from a checkpoint path."""
     with open(Path(checkpoint_path) / model_name / 'params.json') as f:
         params = json.load(f)
@@ -88,7 +169,20 @@ def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig
     return config
 
 
-def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
+def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
+    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf' / "config.json")
+
+
+def config_from_checkpoint(
+    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
+) -> LlamaConfig:
+    if checkpoint_format == "meta":
+        return config_from_meta_checkpoint(checkpoint_path, model_name)
+    else:
+        return config_from_hf_checkpoint(checkpoint_path, model_name)
+
+
+def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]:
     # Need to sort, otherwise we mess up the ordering and the weights are wrong
     return [torch.load(path, map_location='cpu')
             for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
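
Taken together, these additions give one entry point for either checkpoint layout: `config_from_checkpoint(..., checkpoint_format="meta"|"hf")` for the config, and `remap_state_dict_hf_llama` to translate HF weight names and layouts into the flash-attn GPT naming. A minimal sketch of the HF path (it mirrors the test helper added in tests/models/test_llama.py below; the local directory layout `checkpoints/llama/7B-hf` is an assumption, not something this commit creates):

```python
# Sketch: load an HF-format LLaMA checkpoint into flash-attn's GPTLMHeadModel.
# Paths are illustrative; <checkpoint_path>/<model_name>-hf is assumed to hold
# config.json plus the HF weight files.
from pathlib import Path

import torch

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import (
    config_from_checkpoint, llama_config_to_gpt2_config, remap_state_dict_hf_llama,
)
from flash_attn.utils.pretrained import state_dict_from_pretrained

checkpoint_path, model_name = Path("checkpoints/llama"), "7B"

# "hf" routes to config_from_hf_checkpoint, which reads <model_name>-hf/config.json.
config = llama_config_to_gpt2_config(
    config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="hf")
)

# Load the raw HF state dict, then remap keys (and the q/k rotary layout) to our format.
state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
state_dict = remap_state_dict_hf_llama(state_dict, config)

model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
model.load_state_dict(state_dict)
model.eval()
```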

flash_attn/utils/pretrained.py

Lines changed: 36 additions & 11 deletions
@@ -1,24 +1,49 @@
-import torch
+import os
+from functools import partial
 
-from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from transformers.utils import is_remote_url
-from transformers.modeling_utils import load_state_dict
+import torch
+from safetensors.torch import load_file as safe_load_file
+from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 
 
 def state_dict_from_pretrained(model_name, device=None, dtype=None):
     # If not fp32, then we don't want to load directly to the GPU
     mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
     is_sharded = False
-    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
-                                        _raise_exceptions_for_missing_entries=False)
-    if resolved_archive_file is None:
+    load_safe = False
+    resolved_archive_file = None
+
+    weights_path = os.path.join(model_name, WEIGHTS_NAME)
+    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
+    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
+    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
+
+    if os.path.isfile(weights_path):
+        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+    elif os.path.isfile(weights_index_path):
         resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
                                             _raise_exceptions_for_missing_entries=False)
-        if resolved_archive_file is not None:
-            is_sharded = True
+        is_sharded = True
+    elif os.path.isfile(safe_weights_path):
+        resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        load_safe = True
+    elif os.path.isfile(safe_weights_index_path):
+        resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        is_sharded = True
+        load_safe = True
+
     if resolved_archive_file is None:
         raise EnvironmentError(f"Model name {model_name} was not found.")
+
+    if load_safe:
+        loader = partial(safe_load_file, device=mapped_device)
+    else:
+        loader = partial(torch.load, map_location=mapped_device)
+
     if is_sharded:
         # resolved_archive_file becomes a list of files that point to the different
         # checkpoint shards in this case.
@@ -27,9 +52,9 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         state_dict = {}
         for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
+            state_dict.update(loader(sharded_file))
     else:
-        state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+        state_dict = loader(resolved_archive_file)
     # Convert dtype before moving to GPU to save memory
     if dtype is not None:
         state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
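
With this change, `state_dict_from_pretrained` inspects the checkpoint directory to decide which loader to use: safetensors files (`SAFE_WEIGHTS_NAME` / `SAFE_WEIGHTS_INDEX_NAME`) go through `safetensors.torch.load_file`, while `.bin` checkpoints (`WEIGHTS_NAME` / `WEIGHTS_INDEX_NAME`) still go through `torch.load`; sharded checkpoints are merged shard by shard either way. Note that the resolution now keys off `os.path.isfile`, so `model_name` is expected to be a local checkpoint directory here. A short usage sketch (the directory name is hypothetical):

```python
# Sketch: one call covers both .bin and .safetensors checkpoints; the loader is
# picked from the files actually present in the directory, per the branches above.
import torch

from flash_attn.utils.pretrained import state_dict_from_pretrained

# Hypothetical local directory containing either pytorch_model*.bin(.index.json)
# or model*.safetensors(.index.json) files.
state_dict = state_dict_from_pretrained(
    "checkpoints/llama/7B-hf", device="cuda", dtype=torch.float16
)
print(len(state_dict), "tensors;", next(iter(state_dict)))
```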

tests/models/test_llama.py

Lines changed: 43 additions & 25 deletions
@@ -8,24 +8,36 @@
 import os
 import time
 from pathlib import Path
+
 current_dir = Path(__file__).parent.absolute()
 
 import torch
 import pytest
 
 from einops import rearrange
 
-from transformers import LlamaConfig, LlamaTokenizer
+from transformers import LlamaTokenizer
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
 from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
-from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config
+from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config, remap_state_dict_hf_llama
 from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
 from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache
 
 
+def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
+    if checkpoint_format == "meta":
+        ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
+        pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
+        pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    else:
+        pretrained_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
+        pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
+    return pretrained_state_dict
+
+
 @pytest.mark.parametrize('model_name', ["7B"])
 def test_llama_state_dict(model_name):
     checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
@@ -41,8 +53,8 @@ def test_llama_state_dict(model_name):
 
 
 @pytest.mark.parametrize('model_name', ["7B", "13B"])
-# @pytest.mark.parametrize('model_name', ["7B"])
-def test_llama_optimized(model_name):
+@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+def test_llama_optimized(model_name, checkpoint_format):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -52,16 +64,17 @@ def test_llama_optimized(model_name):
 
     dtype = torch.float16
     device = 'cuda'
-    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+    config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -111,7 +124,8 @@ def test_llama_optimized(model_name):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
 @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('model_name', ["13B"])
-def test_llama_parallel(model_name, world_size):
+@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+def test_llama_parallel(model_name, world_size, checkpoint_format):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -122,7 +136,8 @@ def test_llama_parallel(model_name, world_size):
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
 
     dtype = torch.float16
-    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+    config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -137,10 +152,9 @@ def test_llama_parallel(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()
 
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
-
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format
+    )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
@@ -196,13 +210,15 @@ def test_llama_parallel(model_name, world_size):
 
 # @pytest.mark.parametrize('model_name', ["7B", "13B"])
 @pytest.mark.parametrize('model_name', ["7B"])
-def test_llama_generation(model_name):
+@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+def test_llama_generation(model_name, checkpoint_format):
     checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
 
     dtype = torch.float16
     device = 'cuda'
-    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+    config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -239,9 +255,10 @@ def test_llama_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
     del model_ref
 
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -291,7 +308,8 @@ def test_llama_generation(model_name):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
 @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('model_name', ["13B"])
-def test_llama_parallel_generation(model_name, world_size):
+@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
+def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     """Check that our implementation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -302,7 +320,8 @@ def test_llama_parallel_generation(model_name, world_size):
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
 
     dtype = torch.float16
-    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
+    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
+    config = llama_config_to_gpt2_config(config)
     config.use_flash_attn = False
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
@@ -331,10 +350,9 @@ def test_llama_parallel_generation(model_name, world_size):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)
 
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
-
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format
+    )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
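
A subtle piece of the HF remap exercised by these tests is `inv_permute`, which undoes the head-wise rotary re-layout that HF's `convert_llama_weights_to_hf.py` applies to the q/k projection weights. A quick standalone round-trip check one could run alongside these tests, with toy dimensions (the `permute` here follows the reshape/transpose pattern of that conversion script, but treat its exact form as an assumption rather than a verbatim copy):

```python
# Sketch: verify that inv_permute (as in remap_state_dict_hf_llama) inverts the
# rotary re-layout applied when LLaMA weights are converted to the HF format.
import torch

n_head, head_dim = 4, 8          # toy sizes; e.g. LLaMA-7B uses n_head=32, head_dim=128
dim = n_head * head_dim

def permute(w):
    # Assumed HF-conversion layout: interleave the two rotary halves within each head.
    return w.view(n_head, head_dim // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

def inv_permute(w):
    # Mirrors the helper inside remap_state_dict_hf_llama, with config fields inlined.
    return w.reshape(n_head, 2, head_dim // 2, dim).transpose(1, 2).reshape(dim, dim)

w = torch.randn(dim, dim)
assert torch.equal(inv_permute(permute(w)), w)  # round trip recovers the original weights
```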
