@@ -277,13 +277,73 @@ def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
277
277
if len (tensor .shape ) != 1 :
278
278
print ("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!" )
279
279
return tensor
280
- if len (tensor .shape ) == 2 and paddle .transpose (tensor , perm = [1 , 0 ]).contiguous ().shape == dst_shape :
280
+
281
+ if len (tensor .shape ) == 2 :
282
+ num_experts , hidden_size = tensor .shape
283
+ assert hidden_size == dst_shape [0 ], f"Shape not match: { tensor .shape } { dst_shape } "
284
+ if num_experts != dst_shape [1 ]:
285
+ print (f"Slice weight: { tensor .shape } -> { dst_shape } " )
286
+ tensor = tensor [:dst_shape [1 ]]
281
287
return paddle .transpose (tensor , perm = [1 , 0 ]).contiguous ()
282
288
283
- print ("shape not match here" )
289
+ if len (tensor .shape ) == 1 :
290
+ print (f"Slice weight: { tensor .shape } -> { dst_shape } " )
291
+ tensor = tensor [:dst_shape [0 ]]
292
+ return tensor
293
+
294
+ print ("Fatal: shape not match here:" , tensor .shape , dst_shape )
284
295
sys .exit ()
285
296
286
297
298
def hf_cache(path):
    """Resolve *path* (a checkpoint shard) to a fast local copy, staging it if needed.

    Lookup order:
      1. in-memory cache under /dev/shm
      2. local disk cache
      3. wait for a peer process that is already staging the same file
      4. scp from the disk of the machine that owns this shard
      5. plain copy from the source path (with retry)

    Returns the path of a readable local copy. The ``.lock`` file created with
    exclusive mode doubles as an inter-process mutex: the winner downloads into
    the lock file and then renames it, so ``cache_path`` appears atomically.

    Args:
        path: source path of the shard; its basename is expected to look like
            ``model-<ckpt_id>-...`` (Case 4 parses ``<ckpt_id>`` from it).
    """
    print('looking up:', path)
    import os, time, subprocess
    basename = os.path.basename(path)
    cache_path = os.path.join('/dev/shm', 'lshrun_' + basename)
    disk_cache_path = os.path.join('/root/paddlejob/tmpspace/liangshuhao', basename)
    lock_path = cache_path + '.lock'
    begin = time.time()

    # Case 1: file is already staged in memory.
    if os.path.exists(cache_path):
        print('hit mem cache:', cache_path)
        return cache_path

    # Case 2: file is already staged on local disk.
    if os.path.exists(disk_cache_path):
        print('hit disk cache:', disk_cache_path)
        return disk_cache_path

    # Case 3: another process is staging the file — wait for it to finish.
    try:
        # 'x' (exclusive create) makes lock acquisition atomic; close the
        # handle immediately — only the file's existence matters (the
        # original leaked this file object).
        open(lock_path, 'x').close()
    except FileExistsError:
        print('waiting peer load:', cache_path)
        while not os.path.exists(cache_path):
            time.sleep(0.1)
        print('peer done:', cache_path)
        return cache_path

    # Case 4: fetch from the disk of the machine that owns this shard.
    ckpt_id = int(basename.split('-')[1])
    dst_rank = ckpt_id % int(os.environ['TRAINERS_NUM'])
    # NOTE(review): 'TRIANER_IP_LIST' looks like a typo of 'TRAINER_IP_LIST' —
    # confirm against what the launcher actually exports before renaming.
    dst_ip = os.environ['TRIANER_IP_LIST'].split(',')[dst_rank]
    print('fetching:', f'root@{dst_ip}:{disk_cache_path}', '->', lock_path)
    if subprocess.run(['scp', f'root@{dst_ip}:{disk_cache_path}', lock_path]).returncode == 0:
        # Rename so cache_path appears only once fully downloaded.
        subprocess.run(['mv', lock_path, cache_path], check=True)
        print(f'done fetch in {time.time() - begin:.3f}s:', cache_path)
        return cache_path

    # Case 5: fall back to a plain copy from the source path.
    print('copying:', path, '->', lock_path)
    while subprocess.run(['cp', path, lock_path]).returncode:
        print('retrying:', path, '->', lock_path)
        time.sleep(10)  # sometimes "too many open files" causes transient failures
    subprocess.run(['mv', lock_path, cache_path], check=True)
    print(f'done copy in {time.time() - begin:.3f}s:', cache_path)
    return cache_path
345
+
346
+
287
347
def load_huggingface_ckpt (model , huggingface_ckpt_path ):
288
348
ckpt_pre = huggingface_ckpt_path
289
349
@@ -328,8 +388,9 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
328
388
check_list = []
329
389
print ("Start load huggingface ckpt" )
330
390
for i , filename in enumerate (required_files ):
391
+ print (f'loading { i + 1 } /{ len (required_files )} : { filename } ' )
331
392
try :
332
- with safe_open (ckpt_pre + filename , framework = "paddle" , device = "cpu" ) as f :
393
+ with safe_open (hf_cache ( ckpt_pre + filename ) , framework = "paddle" , device = "cpu" ) as f :
333
394
# 加载该文件包含的所有参数
334
395
pd_params = file_to_pd_param_name [filename ]
335
396
for pd_param in pd_params :
@@ -359,12 +420,12 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
359
420
if weight_map [hf_name [0 ]] == filename :
360
421
tensor0 = f .get_tensor (hf_name [0 ])
361
422
with safe_open (
362
- ckpt_pre + weight_map [hf_name [1 ]], framework = "paddle" , device = "cpu"
423
+ hf_cache ( ckpt_pre + weight_map [hf_name [1 ]]) , framework = "paddle" , device = "cpu"
363
424
) as f_other :
364
425
tensor1 = f_other .get_tensor (hf_name [1 ])
365
426
else :
366
427
with safe_open (
367
- ckpt_pre + weight_map [hf_name [0 ]], framework = "paddle" , device = "cpu"
428
+ hf_cache ( ckpt_pre + weight_map [hf_name [0 ]]) , framework = "paddle" , device = "cpu"
368
429
) as f_other :
369
430
tensor0 = f_other .get_tensor (hf_name [0 ])
370
431
tensor1 = f .get_tensor (hf_name [1 ])
@@ -376,3 +437,4 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
376
437
except Exception as e :
377
438
print (f"Error loading { filename } : { str (e )} " )
378
439
raise
440
+ print ("End load huggingface ckpt" )
0 commit comments