Skip to content

Commit 0fd02e5

Browse files
committed
Implement hugging-face cache and expert slice
1 parent 03cf701 commit 0fd02e5

File tree

1 file changed

+67
-5
lines changed

1 file changed

+67
-5
lines changed

paddlenlp/trainer/utils/load_hf_ckpt.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,73 @@ def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
277277
if len(tensor.shape) != 1:
278278
print("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!")
279279
return tensor
280-
if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape:
280+
281+
if len(tensor.shape) == 2:
282+
num_experts, hidden_size = tensor.shape
283+
assert hidden_size == dst_shape[0], f"Shape not match: {tensor.shape} {dst_shape}"
284+
if num_experts != dst_shape[1]:
285+
print(f"Slice weight: {tensor.shape} -> {dst_shape}")
286+
tensor = tensor[:dst_shape[1]]
281287
return paddle.transpose(tensor, perm=[1, 0]).contiguous()
282288

283-
print("shape not match here")
289+
if len(tensor.shape) == 1:
290+
print(f"Slice weight: {tensor.shape} -> {dst_shape}")
291+
tensor = tensor[:dst_shape[0]]
292+
return tensor
293+
294+
print("Fatal: shape not match here:", tensor.shape, dst_shape)
284295
sys.exit()
285296

286297

298+
def hf_cache(
    path,
    *,
    mem_cache_dir='/dev/shm',
    disk_cache_dir='/root/paddlejob/tmpspace/liangshuhao',
):
    """Return a local path for the checkpoint file ``path``, caching it if needed.

    The lookup proceeds through these cases, in order:

    1. ``<mem_cache_dir>/lshrun_<basename>`` already exists: return it.
    2. ``<disk_cache_dir>/<basename>`` already exists: return it.
    3. Another process holds ``<cache file>.lock``: wait until the cache file
       appears and return it. If the lock disappears without the cache file
       appearing (the owner failed), take over the fetch.
    4. Try to ``scp`` the file from the disk cache of a peer machine chosen
       from the shard id embedded in the file name.
    5. Otherwise copy the file from ``path`` itself, retrying until ``cp``
       succeeds.

    Args:
        path: Source location of the checkpoint file.
        mem_cache_dir: Directory holding the in-memory cache copies.
        disk_cache_dir: Directory holding the on-disk cache copies; the same
            path is used on the remote machine for the ``scp`` fetch.

    Returns:
        Path of a local copy of the file (either the in-memory cache entry or
        the on-disk cache entry).

    Raises:
        OSError: If the lock file cannot be created or a helper binary cannot
            be launched. The lock file is removed before the error propagates.
    """
    import os
    import subprocess
    import time

    print('looking up:', path)
    basename = os.path.basename(path)
    cache_path = os.path.join(mem_cache_dir, 'lshrun_' + basename)
    disk_cache_path = os.path.join(disk_cache_dir, basename)
    # The lock file doubles as the download target; it is renamed onto
    # cache_path once complete, so cache_path never holds a partial file.
    lock_path = cache_path + '.lock'
    begin = time.time()

    # Case 1: the file is already in the in-memory cache.
    if os.path.exists(cache_path):
        print('hit mem cache:', cache_path)
        return cache_path

    # Case 2: the file is already in the on-disk cache.
    if os.path.exists(disk_cache_path):
        print('hit disk cache:', disk_cache_path)
        return disk_cache_path

    # Case 3: acquire the lock, or wait for the process that holds it.
    while True:
        try:
            # Mode 'x' fails if the file exists, which makes creation the
            # lock acquisition. Close the handle right away so it is not leaked.
            with open(lock_path, 'x'):
                pass
            break
        except FileExistsError:
            print('waiting peer load:', cache_path)
            while not os.path.exists(cache_path) and os.path.exists(lock_path):
                time.sleep(0.1)
            if os.path.exists(cache_path):
                print('peer done:', cache_path)
                return cache_path
            # The lock vanished without a cache file: the owner failed and
            # released it, so try to become the owner ourselves.
            print('peer released lock without result, retrying:', cache_path)

    try:
        # A peer may have finished between our last check and taking the lock.
        if os.path.exists(cache_path):
            os.remove(lock_path)
            print('peer done:', cache_path)
            return cache_path

        # Case 4: fetch from the disk cache of another machine.
        # Presumably the basename looks like "model-00001-of-000NN.safetensors";
        # names without an integer second field skip this case.
        # NOTE(review): 'TRIANER_IP_LIST' looks like a misspelling of
        # 'TRAINER_IP_LIST' -- confirm against the launcher environment.
        try:
            ckpt_id = int(basename.split('-')[1])
            dst_rank = ckpt_id % int(os.environ['TRAINERS_NUM'])
            dst_ip = os.environ['TRIANER_IP_LIST'].split(',')[dst_rank]
        except (IndexError, ValueError, KeyError, ZeroDivisionError) as err:
            print('skip peer fetch:', repr(err))
            dst_ip = None

        if dst_ip is not None:
            remote = f'root@{dst_ip}:{disk_cache_path}'
            print('fetching:', remote, '->', lock_path)
            if subprocess.run(['scp', remote, lock_path]).returncode == 0:
                os.replace(lock_path, cache_path)
                print(f'done fetch in {time.time() - begin:.3f}s:', cache_path)
                return cache_path

        # Case 5: copy from the source location.
        print('copying:', path, '->', lock_path)
        while subprocess.run(['cp', path, lock_path]).returncode:
            print('retrying:', path, '->', lock_path)
            time.sleep(10)  # sometimes too many open files cause error
        os.replace(lock_path, cache_path)
        print(f'done copy in {time.time() - begin:.3f}s:', cache_path)
        return cache_path
    except BaseException:
        # Release the lock on any failure so waiting peers and later runs do
        # not block forever on a stale lock file.
        try:
            os.remove(lock_path)
        except FileNotFoundError:
            pass
        raise
345+
346+
287347
def load_huggingface_ckpt(model, huggingface_ckpt_path):
288348
ckpt_pre = huggingface_ckpt_path
289349

@@ -328,8 +388,9 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
328388
check_list = []
329389
print("Start load huggingface ckpt")
330390
for i, filename in enumerate(required_files):
391+
print(f'loading {i + 1}/{len(required_files)}: {filename}')
331392
try:
332-
with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f:
393+
with safe_open(hf_cache(ckpt_pre + filename), framework="paddle", device="cpu") as f:
333394
# 加载该文件包含的所有参数
334395
pd_params = file_to_pd_param_name[filename]
335396
for pd_param in pd_params:
@@ -359,12 +420,12 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
359420
if weight_map[hf_name[0]] == filename:
360421
tensor0 = f.get_tensor(hf_name[0])
361422
with safe_open(
362-
ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu"
423+
hf_cache(ckpt_pre + weight_map[hf_name[1]]), framework="paddle", device="cpu"
363424
) as f_other:
364425
tensor1 = f_other.get_tensor(hf_name[1])
365426
else:
366427
with safe_open(
367-
ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu"
428+
hf_cache(ckpt_pre + weight_map[hf_name[0]]), framework="paddle", device="cpu"
368429
) as f_other:
369430
tensor0 = f_other.get_tensor(hf_name[0])
370431
tensor1 = f.get_tensor(hf_name[1])
@@ -376,3 +437,4 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
376437
except Exception as e:
377438
print(f"Error loading {filename}: {str(e)}")
378439
raise
440+
print("End load huggingface ckpt")

0 commit comments

Comments
 (0)