@@ -273,10 +273,45 @@ def set_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs):
273
273
num_micro_batches , job_types , sub_programs , jobs
274
274
)
275
275
else :
276
- raise NotImplementedError (
277
- "The function set_skip_gc_vars is not supported in the old IR."
276
+ return _set_skip_gc_vars_in_old_ir (
277
+ num_micro_batches , job_types , sub_programs , jobs
278
+ )
279
+
280
+
281
+ def _set_skip_gc_vars_in_old_ir (
282
+ num_micro_batches , job_types , sub_programs , jobs
283
+ ):
284
+ assert num_micro_batches >= 1 , "num_micro_batches needs to be >= 1"
285
+ type_to_program = dict (zip (job_types , sub_programs ))
286
+
287
+ # step1: Get all vars of every sub_program that are non-persistable and not in op's no_need_buffer.
288
+ type_to_required_vars = {}
289
+ for type , program in type_to_program .items ():
290
+ type_to_required_vars [type ] = _get_required_vars_of_program (program )
291
+
292
+ # step2: Set `skip_gc_vars` for each job
293
+ suffixed_required_vars = [set () for i in range (num_micro_batches )]
294
+ num_jobs = len (jobs )
295
+ for job_id in reversed (range (num_jobs )):
296
+ job = jobs [job_id ]
297
+ job_type = job .type ()
298
+ required_vars = type_to_required_vars [job_type ]
299
+ micro_batch_id = job .micro_batch_id ()
300
+ skip_gc_vars = required_vars & suffixed_required_vars [micro_batch_id ]
301
+ logger .debug (
302
+ f"Skip gc vars for { job_type } -({ micro_batch_id } ): { skip_gc_vars } "
278
303
)
279
304
305
+ if job_type in ["backward" , "backward_w" ]:
306
+ assert (
307
+ len (skip_gc_vars ) == 0
308
+ ), f"When enabling pipeline parallelism strategy, the skip_gc_vars for { job_type } subprogram must be empty, but it is { skip_gc_vars } ."
309
+
310
+ job .set_skip_gc_vars (skip_gc_vars )
311
+ suffixed_required_vars [micro_batch_id ] |= required_vars
312
+
313
+ return type_to_program
314
+
280
315
281
316
def _set_skip_gc_vars_in_pir (num_micro_batches , job_types , sub_programs , jobs ):
282
317
assert num_micro_batches >= 1 , "num_micro_batches needs to be >= 1"
0 commit comments