diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py
index 2d425c17dfe67..e9455238a98ba 100644
--- a/python/paddle/distributed/fleet/recompute/__init__.py
+++ b/python/paddle/distributed/fleet/recompute/__init__.py
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .recompute import recompute, recompute_sequential  # noqa: F401
+from .recompute import (  # noqa: F401
+    custom_state_manager,
+    recompute,
+    recompute_sequential,
+)
 from .recompute_hybrid import recompute_hybrid  # noqa: F401
 
 __all__ = []
diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
index bc109cf94dad6..fd3a39c89d718 100644
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -115,8 +115,42 @@ def check_recompute_necessary(inputs):
         )
 
 
+class CustomStatesManager:
+    """Registry for a user-defined getter/setter of extra RNG-like state."""
+
+    def __init__(self):
+        """Start with no custom state functions registered."""
+        self.custom_get_state_func = None
+        self.custom_set_state_func = None
+
+    def set_custom_get_state_func(self, custom_get_state_func):
+        assert_msg = (
+            "The custom_state_manager does not support duplicate settings."
+        )
+        assert self.custom_get_state_func is None, assert_msg
+        self.custom_get_state_func = custom_get_state_func
+
+    def set_custom_set_state_func(self, custom_set_state_func):
+        assert_msg = (
+            "The custom_state_manager does not support duplicate settings."
+        )
+        assert self.custom_set_state_func is None, assert_msg
+        self.custom_set_state_func = custom_set_state_func
+
+
+custom_state_manager = CustomStatesManager()
+
+
 @contextlib.contextmanager
-def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
+def switch_rng_state_tracker(
+    rng_state,
+    tracker,
+    numpy_state,
+    random_state,
+    custom_state=None,
+    custom_get_state_func=None,
+    custom_set_state_func=None,
+):
     orig_rng_state = paddle.get_rng_state()
     orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
     paddle.set_rng_state(rng_state)
@@ -126,6 +160,12 @@ def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
     orig_random_state = random.getstate()
     np.random.set_state(numpy_state)
     random.setstate(random_state)
+
+    if custom_state is not None:
+        assert custom_get_state_func is not None
+        assert custom_set_state_func is not None
+        orig_custom_state = custom_get_state_func()
+        custom_set_state_func(custom_state)
     try:
         yield
     finally:
@@ -134,11 +174,21 @@ def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
         np.random.set_state(orig_numpy_state)
         random.setstate(orig_random_state)
 
+        if custom_state is not None:
+            custom_set_state_func(orig_custom_state)
+
 
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(
-        ctx, run_function, preserve_rng_state, offload_indices, *args, **kwargs
+        ctx,
+        run_function,
+        preserve_rng_state,
+        offload_indices,
+        custom_get_state_func,
+        custom_set_state_func,
+        *args,
+        **kwargs,
     ):
         # store for recomputing
         ctx.run_function = run_function
@@ -159,6 +209,9 @@ def forward(
             )
             ctx.fwd_numpy_state = np.random.get_state()
             ctx.fwd_random_state = random.getstate()
+            ctx.fwd_custom_state = custom_get_state_func()
+            ctx.custom_get_state_func = custom_get_state_func
+            ctx.custom_set_state_func = custom_set_state_func
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
@@ -268,6 +321,9 @@ def backward(ctx, *args):
                 ctx.fwd_rng_state_tracker,
                 ctx.fwd_numpy_state,
                 ctx.fwd_random_state,
+                ctx.fwd_custom_state,
+                ctx.custom_get_state_func,
+                ctx.custom_set_state_func,
             ):
                 with paddle.amp.auto_cast(
                     enable=ctx.is_fw_autocast,
@@ -342,7 +398,12 @@ def backward(ctx, *args):
 
 
 def _recompute_without_reentrant(
-    function, preserve_rng_state=True, *args, **kwargs
+    function,
+    custom_get_state_func,
+    custom_set_state_func,
+    preserve_rng_state=True,
+    *args,
+    **kwargs,
 ):
     """
     recompute without reentrant, that means use hook to implement the recompute function rather than re-entrant autograd.
@@ -370,6 +431,7 @@ def _recompute_without_reentrant(
         )
         fwd_numpy_state = np.random.get_state()
         fwd_random_state = random.getstate()
+        fwd_custom_state = custom_get_state_func()
 
     tracer = framework._dygraph_tracer()
     is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
@@ -445,6 +507,9 @@ def inner_unpack(inner_x):
             fwd_cuda_rng_state_tracker,
             fwd_numpy_state,
             fwd_random_state,
+            fwd_custom_state,
+            custom_get_state_func,
+            custom_set_state_func,
         ):
             with paddle.set_grad_enabled(True):
                 with paddle.amp.auto_cast(
@@ -600,6 +665,14 @@ def recompute(function, *args, **kwargs):
     # whether to use reentrant method to implement recompute
     use_reentrant = kwargs.pop('use_reentrant', True)
+    if custom_state_manager.custom_get_state_func is None:
+        assert custom_state_manager.custom_set_state_func is None
+        custom_get_state_func = lambda x=None: None
+        custom_set_state_func = lambda x=None: None
+    else:
+        custom_get_state_func = custom_state_manager.custom_get_state_func
+        custom_set_state_func = custom_state_manager.custom_set_state_func
+
     if not in_dynamic_mode():
         from paddle.distributed.auto_parallel.interface import (
             recompute as static_auto_recompute,
         )
@@ -644,10 +717,22 @@ def recompute(function, *args, **kwargs):
                 raise ValueError("Unknown parameter kind.")
 
         return RecomputeFunction.apply(
-            function, preserve, offload_indices, *input_args
+            function,
+            preserve,
+            offload_indices,
+            custom_get_state_func,
+            custom_set_state_func,
+            *input_args,
         )
     else:
-        return _recompute_without_reentrant(function, preserve, *args, **kwargs)
+        return _recompute_without_reentrant(
+            function,
+            custom_get_state_func,
+            custom_set_state_func,
+            preserve,
+            *args,
+            **kwargs,
+        )
 
 
 def recompute_sequential(
diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
index ded2893403732..a5dd84f7e023c 100644
--- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
+++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
@@ -27,6 +27,7 @@
 from ..meta_parallel.pp_utils import utils
 from .recompute import (
     check_recompute_necessary,
+    custom_state_manager,
     detach_variable,
     switch_rng_state_tracker,
 )
@@ -101,6 +102,8 @@ def forward(
         mp_group,
         offload,
         partition,
+        custom_get_state_func,
+        custom_set_state_func,
         *args,
         **kwargs,
     ):
@@ -114,6 +117,9 @@ def forward(
         ctx.fwd_rng_state_tracker = get_rng_state_tracker().get_states_tracker()
         ctx.fwd_numpy_state = np.random.get_state()
         ctx.fwd_random_state = random.getstate()
+        ctx.fwd_custom_state = custom_get_state_func()
+        ctx.custom_get_state_func = custom_get_state_func
+        ctx.custom_set_state_func = custom_set_state_func
 
         # save config info
         ctx.mp_group = mp_group
@@ -223,6 +229,9 @@ def backward(ctx, *args):
             ctx.fwd_rng_state_tracker,
             ctx.fwd_numpy_state,
             ctx.fwd_random_state,
+            ctx.fwd_custom_state,
+            ctx.custom_get_state_func,
+            ctx.custom_set_state_func,
         ):
             if ctx.is_fw_autocast:
                 with paddle.amp.auto_cast(
@@ -307,9 +316,25 @@ def recompute_hybrid(
     if framework._dygraph_tracer()._has_grad:
         check_recompute_necessary(args)
 
+    if custom_state_manager.custom_get_state_func is None:
+        assert custom_state_manager.custom_set_state_func is None
+        custom_get_state_func = lambda x=None: None
+        custom_set_state_func = lambda x=None: None
+    else:
+        custom_get_state_func = custom_state_manager.custom_get_state_func
+        custom_set_state_func = custom_state_manager.custom_set_state_func
+
     all_outputs = []
     _HPRecomputeFunction.apply(
-        function, all_outputs, mp_group, offload, partition, *args, **kwargs
+        function,
+        all_outputs,
+        mp_group,
+        offload,
+        partition,
+        custom_get_state_func,
+        custom_set_state_func,
+        *args,
+        **kwargs,
     )
 
     if len(all_outputs) == 1:
diff --git a/test/collective/fleet/test_dygraph_recompute_for_eager.py b/test/collective/fleet/test_dygraph_recompute_for_eager.py
index 0c45a490609a4..9475c8a1493e4 100644
--- a/test/collective/fleet/test_dygraph_recompute_for_eager.py
+++ b/test/collective/fleet/test_dygraph_recompute_for_eager.py
@@ -413,5 +413,130 @@ def test_recompute_inputs_with_tuple(self):
             self.assertEqual(grad_ref, grad)
 
 
+from paddle.distributed.fleet.recompute import custom_state_manager
+
+
+class RandomManager:
+    def __init__(self):
+        self.global_random = random.Random(11)
+
+        custom_state_manager.set_custom_get_state_func(self.get_states)
+        custom_state_manager.set_custom_set_state_func(self.set_states)
+
+    def get_states(self):
+        return self.global_random.getstate()
+
+    def set_states(self, packed_states):
+        self.global_random.setstate(packed_states)
+
+
+class RandomLayer(paddle.nn.Layer):
+    def __init__(self, random_handle=None):
+        super().__init__()
+        self.random_handle = random_handle
+
+    def forward(self, input):
+        if self.random_handle is not None:
+            random_val = self.random_handle.global_random.random()
+        else:
+            random_val = random.random()
+        return random_val * input
+
+
+class Random_fc_net(paddle.nn.Layer):
+    def __init__(self, input_size, random_handle=None, use_recompute=False):
+        super().__init__()
+        self.use_recompute = use_recompute
+        self.runfunc0 = get_fc_block(0, input_size, is_last=False)
+        self.random0 = RandomLayer(random_handle)
+        self.runfunc1 = get_fc_block(1, input_size, is_last=False)
+        self.random1 = RandomLayer(random_handle)
+        self.runfunc2 = get_fc_block(2, input_size, is_last=False)
+        self.random2 = RandomLayer(random_handle)
+        self.runfunc3 = get_fc_block(3, input_size, is_last=False)
+        self.random3 = RandomLayer(random_handle)
+        self.runfunc4 = get_fc_block(4, input_size, is_last=True)
+
+    def forward_impl(self, input):
+        o = self.runfunc0(input)
+        o = self.random0(o)
+        o = self.runfunc1(o)
+        o = self.random1(o)
+        o = self.runfunc2(o)
+        o = self.random2(o)
+        o = self.runfunc3(o)
+        o = self.random3(o)
+        o = self.runfunc4(o)
+        return o
+
+    def forward(self, input):
+        if self.use_recompute:
+            return recompute(self.forward_impl, input)
+        else:
+            return self.forward_impl(input)
+
+
+class TestRandomStatesInRecompute(unittest.TestCase):
+    def run_model(self, batch_size, input_size, random_handle, use_recompute):
+        gen = paddle.seed(10)
+        gen.manual_seed(10)
+        np.random.seed(10)
+        random.seed(10)
+        model = Random_fc_net(
+            input_size=input_size,
+            random_handle=random_handle,
+            use_recompute=use_recompute,
+        )
+        loss_fn = paddle.nn.MSELoss(reduction='mean')
+        optimizer = paddle.optimizer.SGD(
+            learning_rate=0.01, parameters=model.parameters()
+        )
+        loss_ = []
+        for step in range(10):
+            x_data = np.random.randn(batch_size, input_size).astype(np.float32)
+            x = paddle.to_tensor(x_data)
+            x.stop_gradient = False
+            y_pred = model(x)
+            loss = y_pred.mean()
+            loss_.append(np.asarray(loss).tolist())
+            loss.backward()
+            optimizer.step()
+            optimizer.clear_grad()
+        return loss_
+
+    def test_custom_random_manager(self):
+        input_size = 1024
+        batch_size = 2
+
+        random_handle = RandomManager()
+
+        random_handle.global_random.seed(42)
+        loss_recompute = self.run_model(
+            batch_size, input_size, random_handle, use_recompute=True
+        )
+
+        random_handle.global_random.seed(42)
+        loss_ref = self.run_model(
+            batch_size, input_size, random_handle, use_recompute=False
+        )
+
+        for loss1, loss2 in zip(loss_recompute, loss_ref):
+            self.assertEqual(loss1, loss2)
+
+    def test_normal_random(self):
+        input_size = 1024
+        batch_size = 2
+
+        loss_recompute = self.run_model(
+            batch_size, input_size, random_handle=None, use_recompute=True
+        )
+        loss_ref = self.run_model(
+            batch_size, input_size, random_handle=None, use_recompute=False
+        )
+
+        for loss1, loss2 in zip(loss_recompute, loss_ref):
+            self.assertEqual(loss1, loss2)
+
+
 if __name__ == '__main__':
     unittest.main()
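Usage sketch: the snippet below mirrors the new test and shows how user-managed RNG-like state is wired into recompute. Only `custom_state_manager`, its `set_custom_get_state_func` / `set_custom_set_state_func` registration calls, and `recompute` come from this change; the `Scale` layer and the module-level `_gen` generator are illustrative names. Hooks should be registered once per process, since `CustomStatesManager` asserts on duplicate registration; when nothing is registered, `recompute` falls back to no-op lambdas and behaves as before.

```python
import random

import paddle
from paddle.distributed.fleet.recompute import custom_state_manager, recompute

# Process-level generator whose state should be captured together with the
# builtin Paddle/NumPy/random states when a recomputed block runs twice.
_gen = random.Random(11)

# Register once; the getter takes no arguments, the setter takes the state.
custom_state_manager.set_custom_get_state_func(_gen.getstate)
custom_state_manager.set_custom_set_state_func(_gen.setstate)


class Scale(paddle.nn.Layer):  # illustrative stochastic layer
    def forward(self, x):
        # Draws from _gen; with the hooks registered, the recomputed forward
        # replays the same draw as the original forward.
        return _gen.random() * x


layer = Scale()
x = paddle.randn([2, 8])
x.stop_gradient = False
y = recompute(layer, x)  # custom state is snapshotted in the forward pass
y.mean().backward()  # and restored while the block is recomputed
```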