Skip to content

Commit 23c0add

Browse files
committed
Implement hugging-face cache and expert slice
1 parent 03cf701 commit 23c0add

File tree

1 file changed

+48
-5
lines changed

1 file changed

+48
-5
lines changed

paddlenlp/trainer/utils/load_hf_ckpt.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,54 @@ def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
277277
if len(tensor.shape) != 1:
278278
print("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!")
279279
return tensor
280-
if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape:
280+
281+
if len(tensor.shape) == 2:
282+
num_experts, hidden_size = tensor.shape
283+
assert hidden_size == dst_shape[0], f"Shape not match: {tensor.shape} {dst_shape}"
284+
if num_experts != dst_shape[1]:
285+
print(f"Slice weight: {tensor.shape} -> {dst_shape}")
286+
tensor = tensor[:dst_shape[1]]
281287
return paddle.transpose(tensor, perm=[1, 0]).contiguous()
282288

283-
print("shape not match here")
289+
if len(tensor.shape) == 1:
290+
print(f"Slice weight: {tensor.shape} -> {dst_shape}")
291+
tensor = tensor[:dst_shape[0]]
292+
return tensor
293+
294+
print("Fatal: shape not match here:", tensor.shape, dst_shape)
284295
sys.exit()
285296

286297

298+
def hf_cache(path, cache_dir='/dev/shm'):
    """Return a path to a locally cached copy of ``path``.

    The copy is stored as ``<cache_dir>/lshrun_<basename of path>``.  Several
    processes may call this concurrently for the same file: the first one to
    create ``<cache file>.lock`` performs the copy, the others poll until the
    lock file disappears.

    Args:
        path: Source file to cache.
        cache_dir: Directory holding the cached copies.  Defaults to
            ``/dev/shm`` (typically a RAM-backed tmpfs on Linux).

    Returns:
        The cached file path.  If a peer process was copying the file and
        finished without producing the cache file, the original ``path`` is
        returned so the caller can still read the data.

    Raises:
        OSError: If this process performs the copy and it fails.  The lock
            file is removed first so waiting peers are not blocked forever.
    """
    import os
    import shutil
    import time

    print('looking up:', path)
    # NOTE(review): the cache key is only the basename, so two different
    # source directories containing the same file name share one entry.
    basename = 'lshrun_' + os.path.basename(path)
    cache_path = os.path.join(cache_dir, basename)
    lock_path = cache_path + '.lock'

    # Case 1: cache exists
    if os.path.exists(cache_path):
        print('hit cache:', cache_path)
        return cache_path

    try:
        # Exclusive creation is the inter-process lock; the handle itself is
        # not needed, so close it immediately instead of leaking it.
        open(lock_path, 'x').close()
    except FileExistsError:
        # Case 2: peer is loading
        print('waiting peer load:', lock_path)
        while os.path.exists(lock_path):
            time.sleep(0.1)
        print('peer done:', lock_path)
        if not os.path.exists(cache_path):
            # The peer released the lock without publishing the cache file
            # (its copy failed); read from the source instead.
            print('peer failed, fallback to source:', path)
            return path
    else:
        # Case 3: load it ourself
        print('copying:', lock_path)
        try:
            # Copy into the lock file, then rename it into place so readers
            # never observe a partially written cache file.
            shutil.copyfile(path, lock_path)
            os.replace(lock_path, cache_path)
        except BaseException:
            # Always release the lock (even on KeyboardInterrupt), otherwise
            # peers polling on it would spin forever.
            try:
                os.remove(lock_path)
            except FileNotFoundError:
                pass
            raise
        print('done copy:', lock_path)

    return cache_path
326+
327+
287328
def load_huggingface_ckpt(model, huggingface_ckpt_path):
288329
ckpt_pre = huggingface_ckpt_path
289330

@@ -328,8 +369,9 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
328369
check_list = []
329370
print("Start load huggingface ckpt")
330371
for i, filename in enumerate(required_files):
372+
print(f'loading {i + 1}/{len(required_files)}: {filename}')
331373
try:
332-
with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f:
374+
with safe_open(hf_cache(ckpt_pre + filename), framework="paddle", device="cpu") as f:
333375
# 加载该文件包含的所有参数
334376
pd_params = file_to_pd_param_name[filename]
335377
for pd_param in pd_params:
@@ -359,12 +401,12 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
359401
if weight_map[hf_name[0]] == filename:
360402
tensor0 = f.get_tensor(hf_name[0])
361403
with safe_open(
362-
ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu"
404+
hf_cache(ckpt_pre + weight_map[hf_name[1]]), framework="paddle", device="cpu"
363405
) as f_other:
364406
tensor1 = f_other.get_tensor(hf_name[1])
365407
else:
366408
with safe_open(
367-
ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu"
409+
hf_cache(ckpt_pre + weight_map[hf_name[0]]), framework="paddle", device="cpu"
368410
) as f_other:
369411
tensor0 = f_other.get_tensor(hf_name[0])
370412
tensor1 = f.get_tensor(hf_name[1])
@@ -376,3 +418,4 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
376418
except Exception as e:
377419
print(f"Error loading {filename}: {str(e)}")
378420
raise
421+
print("End load huggingface ckpt")

0 commit comments

Comments
 (0)