
Commit 46cb554

update
1 parent b8bdfbc commit 46cb554

File tree

2 files changed: +93 −283 lines

python/paddle/distributed/auto_parallel/static/engine.py

Lines changed: 93 additions & 94 deletions
@@ -1278,102 +1278,101 @@ def _initialize(self, mode, init_parameters=True):
                 paddle.distributed.ParallelEnv().dev_id
             )
 
-        if self._in_pir_mode:
-            # FIXME(ljz) avoid shared same tensor more than once in different mode
-            if mode != "train":
-                return
-            # TODO(2024-Q2)
-            # 1. unify random control
-            # 2. initialization of non-parameter buffer
-            # 3. run startup program for pir
-            # 4. lazy init adaption
-            # 5. amp init adaption
-            # 6. vpp init adaption
-
-            # self._init_lr(self._pir_dense_main_progs[mode])
-            self.program_helper.init_pir(
-                self._pir_dist_main_progs[mode], self._place
-            )
-            changed_output_op_list = []
-            if self._executor is None:
-                self._executor = paddle.static.Executor(self._place)
-                startup_prog = self._startup_progs[mode].clone()
-                dist_main_prog = self._pir_dist_main_progs[mode]
-                name_map_value = {}
-                for op in dist_main_prog.global_block().ops:
-                    if op.name() == "pd_op.data":
-                        var_name = op.str_attr("name")
-                        assert (
-                            var_name not in name_map_value
-                        ), f"The value {var_name} in {op} is already exist"
-                        name_map_value[var_name] = op.result(0)
-                del_ops = []
-                block = startup_prog.global_block()
-                for op in block.ops:
-                    if op.name() == "builtin.set_parameter":
-                        var_name = op.str_attr("parameter_name")
-                    elif op.name() == "builtin.shadow_output":
-                        var_name = op.str_attr("output_name")
-                    else:
-                        continue
-                    scope_var = global_scope().find_var(var_name)
-                    if scope_var and scope_var.get_tensor()._is_initialized():
-                        param = op.operand_source(0)
-                        initial_op = param.get_defining_op()
-                        new_param = block.add_kwarg(var_name, param.type())
-                        new_param.persistable = True
-                        new_param.place_attr = scope_var.get_tensor()._place()
-                        param.replace_all_uses_with(new_param)
-                        del_ops.append(op)
-                        del_ops.append(initial_op)
-                    elif var_name in name_map_value:
-                        local_shape = name_map_value[var_name]._local_shape
-                        global_shape = name_map_value[var_name].shape
-                        if local_shape != global_shape:
-                            src_value = op.operand_source(0)
-                            assert src_value.shape == global_shape
-                            dst_dist_attr = name_map_value[var_name].dist_attr()
-                            if not src_value.is_dist():
-                                src_dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
-                                    dst_dist_attr.process_mesh,
-                                    [-1] * len(src_value.shape),
-                                    {},
-                                )
-                                src_value.set_type(
-                                    paddle.base.libpaddle.pir.cvt_to_dist_type(
-                                        src_value.type(), src_dist_attr
-                                    )
-                                )
-                            pir.set_insertion_point_after(
-                                src_value.get_defining_op()
+        if not self._in_pir_mode:
+            raise NotImplementedError("_initialize() only support PIR now.")
+
+        # FIXME(ljz) avoid shared same tensor more than once in different mode
+        if mode != "train":
+            return
+        # TODO(2024-Q2)
+        # 1. unify random control
+        # 2. initialization of non-parameter buffer
+        # 3. run startup program for pir
+        # 4. lazy init adaption
+        # 5. amp init adaption
+        # 6. vpp init adaption
+
+        # self._init_lr(self._pir_dense_main_progs[mode])
+        self.program_helper.init_pir(
+            self._pir_dist_main_progs[mode], self._place
+        )
+        changed_output_op_list = []
+        if self._executor is None:
+            self._executor = paddle.static.Executor(self._place)
+            startup_prog = self._startup_progs[mode].clone()
+            dist_main_prog = self._pir_dist_main_progs[mode]
+            name_map_value = {}
+            for op in dist_main_prog.global_block().ops:
+                if op.name() == "pd_op.data":
+                    var_name = op.str_attr("name")
+                    assert (
+                        var_name not in name_map_value
+                    ), f"The value {var_name} in {op} is already exist"
+                    name_map_value[var_name] = op.result(0)
+            del_ops = []
+            block = startup_prog.global_block()
+            for op in block.ops:
+                if op.name() == "builtin.set_parameter":
+                    var_name = op.str_attr("parameter_name")
+                elif op.name() == "builtin.shadow_output":
+                    var_name = op.str_attr("output_name")
+                else:
+                    continue
+                scope_var = global_scope().find_var(var_name)
+                if scope_var and scope_var.get_tensor()._is_initialized():
+                    param = op.operand_source(0)
+                    initial_op = param.get_defining_op()
+                    new_param = block.add_kwarg(var_name, param.type())
+                    new_param.persistable = True
+                    new_param.place_attr = scope_var.get_tensor()._place()
+                    param.replace_all_uses_with(new_param)
+                    del_ops.append(op)
+                    del_ops.append(initial_op)
+                elif var_name in name_map_value:
+                    local_shape = name_map_value[var_name]._local_shape
+                    global_shape = name_map_value[var_name].shape
+                    if local_shape != global_shape:
+                        src_value = op.operand_source(0)
+                        assert src_value.shape == global_shape
+                        dst_dist_attr = name_map_value[var_name].dist_attr()
+                        if not src_value.is_dist():
+                            src_dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
+                                dst_dist_attr.process_mesh,
+                                [-1] * len(src_value.shape),
+                                {},
                             )
-                            reshard_var = paddle._C_ops.reshard_v2(
-                                src_value, dst_dist_attr
+                            src_value.set_type(
+                                paddle.base.libpaddle.pir.cvt_to_dist_type(
+                                    src_value.type(), src_dist_attr
+                                )
                             )
-                            if src_value.persistable:
-                                src_value.persistable = False
-                                changed_output_op_list.append(op)
-                            op.operand(0).set_source(reshard_var)
-                for del_op in del_ops:
-                    del_op.erase()
-
-                set_all_ops_op_role(startup_prog.global_block(), OpRole.Forward)
-                ReshardPasses.apply_reshard_pass(startup_prog)
-                paddle.base.libpaddle.pir.apply_dist2dense_pass(startup_prog)
-                remove_unuseful_comm_op_pass(startup_prog)
-
-                for op in changed_output_op_list:
-                    op.operand_source(0).persistable = True
-                self._executor.run(startup_prog)
-                if self._job_plan is not None:
-                    # pipeline scheduling should be enabled after running
-                    # startup program, otherwise the startup program cannot
-                    # run correctly.
-                    self._executor._set_plan(self._job_plan)
-            return
-
-        else:
-            raise NotImplementedError("_initialize() only support PIR now.")
+                        pir.set_insertion_point_after(
+                            src_value.get_defining_op()
+                        )
+                        reshard_var = paddle._C_ops.reshard_v2(
+                            src_value, dst_dist_attr
+                        )
+                        if src_value.persistable:
+                            src_value.persistable = False
+                            changed_output_op_list.append(op)
+                        op.operand(0).set_source(reshard_var)
+            for del_op in del_ops:
+                del_op.erase()
+
+            set_all_ops_op_role(startup_prog.global_block(), OpRole.Forward)
+            ReshardPasses.apply_reshard_pass(startup_prog)
+            paddle.base.libpaddle.pir.apply_dist2dense_pass(startup_prog)
+            remove_unuseful_comm_op_pass(startup_prog)
+
+            for op in changed_output_op_list:
+                op.operand_source(0).persistable = True
+            self._executor.run(startup_prog)
+            if self._job_plan is not None:
+                # pipeline scheduling should be enabled after running
+                # startup program, otherwise the startup program cannot
+                # run correctly.
+                self._executor._set_plan(self._job_plan)
+        return
 
         # distributed training combined with prim mechanism (prim is behind of distributed)
         # for local main subprogram after distributed partition,
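For readers skimming the hunk: it inverts the `if self._in_pir_mode:` wrapper into an early guard clause (`if not self._in_pir_mode: raise NotImplementedError(...)`) and dedents the PIR initialization body by one level; the startup-program rewriting logic itself is unchanged. A minimal sketch of the same guard-clause pattern, using a toy class with illustrative names (not Paddle's real API):

    class EngineSketch:
        """Toy stand-in for the engine; attribute and method names are hypothetical."""

        def __init__(self, in_pir_mode: bool):
            self._in_pir_mode = in_pir_mode

        def _initialize(self, mode: str) -> None:
            # Guard clause replaces the old `if self._in_pir_mode: ... else: raise`
            # wrapper, so the main body sits one indentation level shallower.
            if not self._in_pir_mode:
                raise NotImplementedError("_initialize() only support PIR now.")

            if mode != "train":
                return
            # ... PIR-specific startup logic would follow here, unchanged by the refactor ...

    EngineSketch(in_pir_mode=True)._initialize("train")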
