@@ -1278,102 +1278,101 @@ def _initialize(self, mode, init_parameters=True):
                 paddle.distributed.ParallelEnv().dev_id
             )

-        if self._in_pir_mode:
-            # FIXME(ljz) avoid shared same tensor more than once in different mode
-            if mode != "train":
-                return
-            # TODO(2024-Q2)
-            # 1. unify random control
-            # 2. initialization of non-parameter buffer
-            # 3. run startup program for pir
-            # 4. lazy init adaption
-            # 5. amp init adaption
-            # 6. vpp init adaption
-
-            # self._init_lr(self._pir_dense_main_progs[mode])
-            self.program_helper.init_pir(
-                self._pir_dist_main_progs[mode], self._place
-            )
-            changed_output_op_list = []
-            if self._executor is None:
-                self._executor = paddle.static.Executor(self._place)
-                startup_prog = self._startup_progs[mode].clone()
-                dist_main_prog = self._pir_dist_main_progs[mode]
-                name_map_value = {}
-                for op in dist_main_prog.global_block().ops:
-                    if op.name() == "pd_op.data":
-                        var_name = op.str_attr("name")
-                        assert (
-                            var_name not in name_map_value
-                        ), f"The value {var_name} in {op} is already exist"
-                        name_map_value[var_name] = op.result(0)
-                del_ops = []
-                block = startup_prog.global_block()
-                for op in block.ops:
-                    if op.name() == "builtin.set_parameter":
-                        var_name = op.str_attr("parameter_name")
-                    elif op.name() == "builtin.shadow_output":
-                        var_name = op.str_attr("output_name")
-                    else:
-                        continue
-                    scope_var = global_scope().find_var(var_name)
-                    if scope_var and scope_var.get_tensor()._is_initialized():
-                        param = op.operand_source(0)
-                        initial_op = param.get_defining_op()
-                        new_param = block.add_kwarg(var_name, param.type())
-                        new_param.persistable = True
-                        new_param.place_attr = scope_var.get_tensor()._place()
-                        param.replace_all_uses_with(new_param)
-                        del_ops.append(op)
-                        del_ops.append(initial_op)
-                    elif var_name in name_map_value:
-                        local_shape = name_map_value[var_name]._local_shape
-                        global_shape = name_map_value[var_name].shape
-                        if local_shape != global_shape:
-                            src_value = op.operand_source(0)
-                            assert src_value.shape == global_shape
-                            dst_dist_attr = name_map_value[var_name].dist_attr()
-                            if not src_value.is_dist():
-                                src_dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
-                                    dst_dist_attr.process_mesh,
-                                    [-1] * len(src_value.shape),
-                                    {},
-                                )
-                                src_value.set_type(
-                                    paddle.base.libpaddle.pir.cvt_to_dist_type(
-                                        src_value.type(), src_dist_attr
-                                    )
-                                )
-                            pir.set_insertion_point_after(
-                                src_value.get_defining_op()
+        if not self._in_pir_mode:
+            raise NotImplementedError("_initialize() only support PIR now.")
+
+        # FIXME(ljz) avoid shared same tensor more than once in different mode
+        if mode != "train":
+            return
+        # TODO(2024-Q2)
+        # 1. unify random control
+        # 2. initialization of non-parameter buffer
+        # 3. run startup program for pir
+        # 4. lazy init adaption
+        # 5. amp init adaption
+        # 6. vpp init adaption
+
+        # self._init_lr(self._pir_dense_main_progs[mode])
+        self.program_helper.init_pir(
+            self._pir_dist_main_progs[mode], self._place
+        )
+        changed_output_op_list = []
+        if self._executor is None:
+            self._executor = paddle.static.Executor(self._place)
+            startup_prog = self._startup_progs[mode].clone()
+            dist_main_prog = self._pir_dist_main_progs[mode]
+            name_map_value = {}
+            for op in dist_main_prog.global_block().ops:
+                if op.name() == "pd_op.data":
+                    var_name = op.str_attr("name")
+                    assert (
+                        var_name not in name_map_value
+                    ), f"The value {var_name} in {op} is already exist"
+                    name_map_value[var_name] = op.result(0)
+            del_ops = []
+            block = startup_prog.global_block()
+            for op in block.ops:
+                if op.name() == "builtin.set_parameter":
+                    var_name = op.str_attr("parameter_name")
+                elif op.name() == "builtin.shadow_output":
+                    var_name = op.str_attr("output_name")
+                else:
+                    continue
+                scope_var = global_scope().find_var(var_name)
+                if scope_var and scope_var.get_tensor()._is_initialized():
+                    param = op.operand_source(0)
+                    initial_op = param.get_defining_op()
+                    new_param = block.add_kwarg(var_name, param.type())
+                    new_param.persistable = True
+                    new_param.place_attr = scope_var.get_tensor()._place()
+                    param.replace_all_uses_with(new_param)
+                    del_ops.append(op)
+                    del_ops.append(initial_op)
+                elif var_name in name_map_value:
+                    local_shape = name_map_value[var_name]._local_shape
+                    global_shape = name_map_value[var_name].shape
+                    if local_shape != global_shape:
+                        src_value = op.operand_source(0)
+                        assert src_value.shape == global_shape
+                        dst_dist_attr = name_map_value[var_name].dist_attr()
+                        if not src_value.is_dist():
+                            src_dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
+                                dst_dist_attr.process_mesh,
+                                [-1] * len(src_value.shape),
+                                {},
                             )
-                            reshard_var = paddle._C_ops.reshard_v2(
-                                src_value, dst_dist_attr
+                            src_value.set_type(
+                                paddle.base.libpaddle.pir.cvt_to_dist_type(
+                                    src_value.type(), src_dist_attr
+                                )
                             )
-                            if src_value.persistable:
-                                src_value.persistable = False
-                            changed_output_op_list.append(op)
-                            op.operand(0).set_source(reshard_var)
-                for del_op in del_ops:
-                    del_op.erase()
-
-                set_all_ops_op_role(startup_prog.global_block(), OpRole.Forward)
-                ReshardPasses.apply_reshard_pass(startup_prog)
-                paddle.base.libpaddle.pir.apply_dist2dense_pass(startup_prog)
-                remove_unuseful_comm_op_pass(startup_prog)
-
-                for op in changed_output_op_list:
-                    op.operand_source(0).persistable = True
-                self._executor.run(startup_prog)
-                if self._job_plan is not None:
-                    # pipeline scheduling should be enabled after running
-                    # startup program, otherwise the startup program cannot
-                    # run correctly.
-                    self._executor._set_plan(self._job_plan)
-            return
-
-        else:
-            raise NotImplementedError("_initialize() only support PIR now.")
+                        pir.set_insertion_point_after(
+                            src_value.get_defining_op()
+                        )
+                        reshard_var = paddle._C_ops.reshard_v2(
+                            src_value, dst_dist_attr
+                        )
+                        if src_value.persistable:
+                            src_value.persistable = False
+                        changed_output_op_list.append(op)
+                        op.operand(0).set_source(reshard_var)
+            for del_op in del_ops:
+                del_op.erase()
+
+            set_all_ops_op_role(startup_prog.global_block(), OpRole.Forward)
+            ReshardPasses.apply_reshard_pass(startup_prog)
+            paddle.base.libpaddle.pir.apply_dist2dense_pass(startup_prog)
+            remove_unuseful_comm_op_pass(startup_prog)

+            for op in changed_output_op_list:
+                op.operand_source(0).persistable = True
+            self._executor.run(startup_prog)
+            if self._job_plan is not None:
+                # pipeline scheduling should be enabled after running
+                # startup program, otherwise the startup program cannot
+                # run correctly.
+                self._executor._set_plan(self._job_plan)
+        return

     # distributed training combined with prim mechanism (prim is behind of distributed)
     # for local main subprogram after distributed partition,
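Taken as a whole, the hunk is a guard-clause refactor rather than a behavioral change: the old `if self._in_pir_mode: ... else: raise NotImplementedError(...)` wrapper becomes an early `raise` for the non-PIR path, and the PIR startup-initialization body moves up one indentation level. A minimal sketch of the same pattern is below; `Engine` and `_init_pir_startup` are hypothetical stand-ins, not the real auto-parallel engine API.

```python
class Engine:
    """Toy stand-in for the auto-parallel engine; only the control flow matters."""

    def __init__(self, in_pir_mode: bool):
        self._in_pir_mode = in_pir_mode

    # Old shape: the whole PIR body nested inside `if self._in_pir_mode:`,
    # with the non-PIR path handled by a trailing `else`.
    def _initialize_old(self, mode):
        if self._in_pir_mode:
            if mode != "train":
                return
            self._init_pir_startup(mode)
            return
        else:
            raise NotImplementedError("_initialize() only support PIR now.")

    # New shape: reject the non-PIR path up front, then run the same body
    # dedented by one level. Behavior is unchanged.
    def _initialize_new(self, mode):
        if not self._in_pir_mode:
            raise NotImplementedError("_initialize() only support PIR now.")
        if mode != "train":
            return
        self._init_pir_startup(mode)

    def _init_pir_startup(self, mode):
        # Placeholder for the startup-program work shown in the hunk
        # (init_pir, reshard/dist2dense passes, executor.run, plan setup).
        pass


if __name__ == "__main__":
    Engine(in_pir_mode=True)._initialize_new("train")
```

The early-return guard keeps the long PIR body one level shallower, which is why almost every line in the hunk changes even though the statements themselves are moved rather than rewritten.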