@@ -277,13 +277,54 @@ def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
     if len(tensor.shape) != 1:
         print("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!")
         return tensor
-    if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape:
+
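+    # 2-D case: the trailing dim must match dst_shape[0]; if the leading
+    # (num_experts) dim differs, keep only the first dst_shape[1] rows before
+    # transposing into dst_shape.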
+    if len(tensor.shape) == 2:
+        num_experts, hidden_size = tensor.shape
+        assert hidden_size == dst_shape[0], f"Shape not match: {tensor.shape} {dst_shape}"
+        if num_experts != dst_shape[1]:
+            print(f"Slice weight: {tensor.shape} -> {dst_shape}")
+            tensor = tensor[:dst_shape[1]]
         return paddle.transpose(tensor, perm=[1, 0]).contiguous()

-    print("shape not match here")
+    if len(tensor.shape) == 1:
+        print(f"Slice weight: {tensor.shape} -> {dst_shape}")
+        tensor = tensor[:dst_shape[0]]
+        return tensor
+
+    print("Fatal: shape not match here:", tensor.shape, dst_shape)
     sys.exit()


+def hf_cache(path):
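+    # Stage `path` in /dev/shm once per node so concurrent processes reuse the
+    # cached copy instead of each re-reading the original file.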
+    print('looking up:', path)
+    import os, time, subprocess
+    basename = 'lshrun_' + os.path.basename(path)
+    cache_path = os.path.join('/dev/shm', basename)
+    lock_path = cache_path + '.lock'
+
+    # Case 1: cache exists
+    if os.path.exists(cache_path):
+        print('hit cache:', cache_path)
+        return cache_path
+
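+    # Exclusive creation of the lock file decides who copies: the winner
+    # performs the copy, everyone else polls until the lock disappears.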
+    try:
+        open(lock_path, 'x')
+    except FileExistsError:
+        # Case 2: peer is loading
+        print('waiting peer load:', lock_path)
+        while os.path.exists(lock_path):
+            time.sleep(0.1)
+        print('peer done:', lock_path)
+    else:
+        # Case 3: load it ourself
+        print('copying:', lock_path)
+        subprocess.run(['cp', path, lock_path], check=True)
+        subprocess.run(['mv', lock_path, cache_path], check=True)
+        print('done copy:', lock_path)
+
+    return cache_path
+
+
 def load_huggingface_ckpt(model, huggingface_ckpt_path):
     ckpt_pre = huggingface_ckpt_path
@@ -328,8 +369,9 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
     check_list = []
     print("Start load huggingface ckpt")
     for i, filename in enumerate(required_files):
+        print(f'loading {i + 1}/{len(required_files)}: {filename}')
         try:
-            with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f:
+            with safe_open(hf_cache(ckpt_pre + filename), framework="paddle", device="cpu") as f:
                 # Load all parameters contained in this file
                 pd_params = file_to_pd_param_name[filename]
                 for pd_param in pd_params:
@@ -359,12 +401,12 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
                     if weight_map[hf_name[0]] == filename:
                         tensor0 = f.get_tensor(hf_name[0])
                         with safe_open(
-                            ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu"
+                            hf_cache(ckpt_pre + weight_map[hf_name[1]]), framework="paddle", device="cpu"
                         ) as f_other:
                             tensor1 = f_other.get_tensor(hf_name[1])
                     else:
                         with safe_open(
-                            ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu"
+                            hf_cache(ckpt_pre + weight_map[hf_name[0]]), framework="paddle", device="cpu"
                         ) as f_other:
                             tensor0 = f_other.get_tensor(hf_name[0])
                             tensor1 = f.get_tensor(hf_name[1])
@@ -376,3 +418,4 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
         except Exception as e:
             print(f"Error loading {filename}: {str(e)}")
             raise
+    print("End load huggingface ckpt")