Skip to content

Commit dcf2e37

Browse files
committed
update
1 parent 44b005a commit dcf2e37

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

python/paddle/distributed/passes/pass_utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,45 @@ def set_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs):
273273
num_micro_batches, job_types, sub_programs, jobs
274274
)
275275
else:
276-
raise NotImplementedError(
277-
"The function set_skip_gc_vars is not supported in the old IR."
276+
return _set_skip_gc_vars_in_old_ir(
277+
num_micro_batches, job_types, sub_programs, jobs
278+
)
279+
280+
281+
def _set_skip_gc_vars_in_old_ir(
282+
num_micro_batches, job_types, sub_programs, jobs
283+
):
284+
assert num_micro_batches >= 1, "num_micro_batches needs to be >= 1"
285+
type_to_program = dict(zip(job_types, sub_programs))
286+
287+
# step1: Get all vars of every sub_program that are non-persistable and not in op's no_need_buffer.
288+
type_to_required_vars = {}
289+
for type, program in type_to_program.items():
290+
type_to_required_vars[type] = _get_required_vars_of_program(program)
291+
292+
# step2: Set `skip_gc_vars` for each job
293+
suffixed_required_vars = [set() for i in range(num_micro_batches)]
294+
num_jobs = len(jobs)
295+
for job_id in reversed(range(num_jobs)):
296+
job = jobs[job_id]
297+
job_type = job.type()
298+
required_vars = type_to_required_vars[job_type]
299+
micro_batch_id = job.micro_batch_id()
300+
skip_gc_vars = required_vars & suffixed_required_vars[micro_batch_id]
301+
logger.debug(
302+
f"Skip gc vars for {job_type}-({micro_batch_id}): {skip_gc_vars}"
278303
)
279304

305+
if job_type in ["backward", "backward_w"]:
306+
assert (
307+
len(skip_gc_vars) == 0
308+
), f"When enabling pipeline parallelism strategy, the skip_gc_vars for {job_type} subprogram must be empty, but it is {skip_gc_vars}."
309+
310+
job.set_skip_gc_vars(skip_gc_vars)
311+
suffixed_required_vars[micro_batch_id] |= required_vars
312+
313+
return type_to_program
314+
280315

281316
def _set_skip_gc_vars_in_pir(num_micro_batches, job_types, sub_programs, jobs):
282317
assert num_micro_batches >= 1, "num_micro_batches needs to be >= 1"

0 commit comments

Comments
 (0)