
Commit 4385ad1

[Distributed] Support custom state setter and getter for recompute to save custom states (#72670)
* [Distributed] Support custom state setter and getter for recompute to save custom states
* add ut
1 parent 3139a0a commit 4385ad1
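
The change routes a process-wide custom_state_manager through both recompute entry points (recompute and recompute_hybrid), so that state held outside Paddle's own RNG trackers, e.g. a user-owned random.Random, is captured in the forward pass and restored before the recomputed forward pass in backward. Below is a minimal usage sketch, modeled on the unit test added in this commit and assuming a Paddle build that includes it; the RandomSource class, the block function, and the tensor shapes are illustrative, not part of the Paddle API.

import random

import paddle
from paddle.distributed.fleet.recompute import custom_state_manager, recompute


class RandomSource:
    # Illustrative holder for a private RNG whose draws must match between
    # the original forward pass and the recomputed one.
    def __init__(self, seed=11):
        self.rng = random.Random(seed)
        # Register once per process; the manager asserts on duplicate settings.
        custom_state_manager.set_custom_get_state_func(self.rng.getstate)
        custom_state_manager.set_custom_set_state_func(self.rng.setstate)


source = RandomSource()


def block(x):
    # This draw is replayed identically during recomputation, because the RNG
    # state is saved in forward and restored around the recomputed forward.
    return x * source.rng.random()


x = paddle.randn([2, 4])
x.stop_gradient = False
y = recompute(block, x)
y.mean().backward()

Registration is one-shot: set_custom_get_state_func and set_custom_set_state_func assert if called twice, so the manager object should be created once and reused.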

File tree

4 files changed: +246 -7 lines changed


python/paddle/distributed/fleet/recompute/__init__.py

+5 -1

@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .recompute import recompute, recompute_sequential  # noqa: F401
+from .recompute import (  # noqa: F401
+    custom_state_manager,
+    recompute,
+    recompute_sequential,
+)
 from .recompute_hybrid import recompute_hybrid  # noqa: F401
 
 __all__ = []

python/paddle/distributed/fleet/recompute/recompute.py

+90 -5

@@ -115,8 +115,42 @@ def check_recompute_necessary(inputs):
         )
 
 
+class CustomStatesManager:
+    """CustomStatesManager"""
+
+    def __init__(self):
+        """__init__"""
+        self.custom_get_state_func = None
+        self.custom_set_state_func = None
+
+    def set_custom_get_state_func(self, custom_get_state_func):
+        assert_msg = (
+            "The custom_state_manager does not support duplicate settings."
+        )
+        assert self.custom_get_state_func is None, assert_msg
+        self.custom_get_state_func = custom_get_state_func
+
+    def set_custom_set_state_func(self, custom_set_state_func):
+        assert_msg = (
+            "The custom_state_manager does not support duplicate settings."
+        )
+        assert self.custom_set_state_func is None, assert_msg
+        self.custom_set_state_func = custom_set_state_func
+
+
+custom_state_manager = CustomStatesManager()
+
+
 @contextlib.contextmanager
-def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
+def switch_rng_state_tracker(
+    rng_state,
+    tracker,
+    numpy_state,
+    random_state,
+    custom_state=None,
+    custom_get_state_func=None,
+    custom_set_state_func=None,
+):
     orig_rng_state = paddle.get_rng_state()
     orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
     paddle.set_rng_state(rng_state)
@@ -126,6 +160,12 @@ def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
     orig_random_state = random.getstate()
     np.random.set_state(numpy_state)
     random.setstate(random_state)
+
+    if custom_state is not None:
+        assert custom_get_state_func is not None
+        assert custom_set_state_func is not None
+        orig_custom_state = custom_get_state_func()
+        custom_set_state_func(custom_state)
     try:
         yield
     finally:
@@ -134,11 +174,21 @@ def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
         np.random.set_state(orig_numpy_state)
         random.setstate(orig_random_state)
 
+        if custom_state is not None:
+            custom_set_state_func(orig_custom_state)
+
 
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(
-        ctx, run_function, preserve_rng_state, offload_indices, *args, **kwargs
+        ctx,
+        run_function,
+        preserve_rng_state,
+        offload_indices,
+        custom_get_state_func,
+        custom_set_state_func,
+        *args,
+        **kwargs,
     ):
         # store for recomputing
         ctx.run_function = run_function
@@ -159,6 +209,9 @@ def forward(
             )
             ctx.fwd_numpy_state = np.random.get_state()
             ctx.fwd_random_state = random.getstate()
+            ctx.fwd_custom_state = custom_get_state_func()
+            ctx.custom_get_state_func = custom_get_state_func
+            ctx.custom_set_state_func = custom_set_state_func
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
@@ -268,6 +321,9 @@ def backward(ctx, *args):
                     ctx.fwd_rng_state_tracker,
                     ctx.fwd_numpy_state,
                     ctx.fwd_random_state,
+                    ctx.fwd_custom_state,
+                    ctx.custom_get_state_func,
+                    ctx.custom_set_state_func,
                 ):
                     with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
@@ -342,7 +398,12 @@ def backward(ctx, *args):
 
 
 def _recompute_without_reentrant(
-    function, preserve_rng_state=True, *args, **kwargs
+    function,
+    custom_get_state_func,
+    custom_set_state_func,
+    preserve_rng_state=True,
+    *args,
+    **kwargs,
 ):
     """
     recompute without reentrant, that means use hook to implement the recompute function rather than re-entrant autograd.
@@ -370,6 +431,7 @@ def _recompute_without_reentrant(
         )
         fwd_numpy_state = np.random.get_state()
         fwd_random_state = random.getstate()
+        fwd_custom_state = custom_get_state_func()
 
     tracer = framework._dygraph_tracer()
     is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
@@ -445,6 +507,9 @@ def inner_unpack(inner_x):
            fwd_cuda_rng_state_tracker,
            fwd_numpy_state,
            fwd_random_state,
+           fwd_custom_state,
+           custom_get_state_func,
+           custom_set_state_func,
        ):
            with paddle.set_grad_enabled(True):
                with paddle.amp.auto_cast(
@@ -600,6 +665,14 @@ def recompute(function, *args, **kwargs):
     # whether to use reentrant method to implement recompute
     use_reentrant = kwargs.pop('use_reentrant', True)
 
+    if custom_state_manager.custom_get_state_func is None:
+        assert custom_state_manager.custom_set_state_func is None
+        custom_get_state_func = lambda x=None: None
+        custom_set_state_func = lambda x=None: None
+    else:
+        custom_get_state_func = custom_state_manager.custom_get_state_func
+        custom_set_state_func = custom_state_manager.custom_set_state_func
+
     if not in_dynamic_mode():
         from paddle.distributed.auto_parallel.interface import (
             recompute as static_auto_recompute,
@@ -644,10 +717,22 @@ def recompute(function, *args, **kwargs):
                raise ValueError("Unknown parameter kind.")
 
         return RecomputeFunction.apply(
-            function, preserve, offload_indices, *input_args
+            function,
+            preserve,
+            offload_indices,
+            custom_get_state_func,
+            custom_set_state_func,
+            *input_args,
         )
     else:
-        return _recompute_without_reentrant(function, preserve, *args, **kwargs)
+        return _recompute_without_reentrant(
+            function,
+            custom_get_state_func,
+            custom_set_state_func,
+            preserve,
+            *args,
+            **kwargs,
+        )
 
 
 def recompute_sequential(
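
When no functions have been registered, recompute() falls back to no-op lambdas, so the captured custom state is None and switch_rng_state_tracker skips the extra save/restore. A simplified sketch of that default path, reusing names from the hunks above (not the module source verbatim):

# No registration: recompute substitutes do-nothing getter/setter lambdas.
custom_get_state_func = lambda x=None: None  # "get" yields None as the state
custom_set_state_func = lambda x=None: None  # "set" ignores its argument

fwd_custom_state = custom_get_state_func()  # None

# switch_rng_state_tracker only saves/restores when custom_state is not None,
# so the default path adds no extra work.
if fwd_custom_state is not None:
    orig_custom_state = custom_get_state_func()
    custom_set_state_func(fwd_custom_state)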

python/paddle/distributed/fleet/recompute/recompute_hybrid.py

+26 -1

@@ -27,6 +27,7 @@
 from ..meta_parallel.pp_utils import utils
 from .recompute import (
     check_recompute_necessary,
+    custom_state_manager,
     detach_variable,
     switch_rng_state_tracker,
 )
@@ -101,6 +102,8 @@ def forward(
         mp_group,
         offload,
         partition,
+        custom_get_state_func,
+        custom_set_state_func,
         *args,
         **kwargs,
     ):
@@ -114,6 +117,9 @@ def forward(
        ctx.fwd_rng_state_tracker = get_rng_state_tracker().get_states_tracker()
        ctx.fwd_numpy_state = np.random.get_state()
        ctx.fwd_random_state = random.getstate()
+       ctx.fwd_custom_state = custom_get_state_func()
+       ctx.custom_get_state_func = custom_get_state_func
+       ctx.custom_set_state_func = custom_set_state_func
 
        # save config info
        ctx.mp_group = mp_group
@@ -223,6 +229,9 @@ def backward(ctx, *args):
                ctx.fwd_rng_state_tracker,
                ctx.fwd_numpy_state,
                ctx.fwd_random_state,
+               ctx.fwd_custom_state,
+               ctx.custom_get_state_func,
+               ctx.custom_set_state_func,
            ):
                if ctx.is_fw_autocast:
                    with paddle.amp.auto_cast(
@@ -307,9 +316,25 @@ def recompute_hybrid(
     if framework._dygraph_tracer()._has_grad:
         check_recompute_necessary(args)
 
+    if custom_state_manager.custom_get_state_func is None:
+        assert custom_state_manager.custom_set_state_func is None
+        custom_get_state_func = lambda x=None: None
+        custom_set_state_func = lambda x=None: None
+    else:
+        custom_get_state_func = custom_state_manager.custom_get_state_func
+        custom_set_state_func = custom_state_manager.custom_set_state_func
+
     all_outputs = []
     _HPRecomputeFunction.apply(
-        function, all_outputs, mp_group, offload, partition, *args, **kwargs
+        function,
+        all_outputs,
+        mp_group,
+        offload,
+        partition,
+        custom_get_state_func,
+        custom_set_state_func,
+        *args,
+        **kwargs,
     )
 
     if len(all_outputs) == 1:

test/collective/fleet/test_dygraph_recompute_for_eager.py

+125 -0

@@ -413,5 +413,130 @@ def test_recompute_inputs_with_tuple(self):
         self.assertEqual(grad_ref, grad)
 
 
+class RandomManager:
+    def __init__(self):
+        self.global_random = random.Random(11)
+
+        custom_state_manager.set_custom_get_state_func(self.get_states)
+        custom_state_manager.set_custom_set_state_func(self.set_states)
+
+    def get_states(self):
+        return self.global_random.getstate()
+
+    def set_states(self, packed_states):
+        self.global_random.setstate(packed_states)
+
+
+class RandomLayer(paddle.nn.Layer):
+    def __init__(self, random_handle=None):
+        super().__init__()
+        self.random_handle = random_handle
+
+    def forward(self, input):
+        if self.random_handle is not None:
+            random_val = self.random_handle.global_random.random()
+        else:
+            random_val = random.random()
+        return random_val * input
+
+
+from paddle.distributed.fleet.recompute import custom_state_manager
+
+
+class Random_fc_net(paddle.nn.Layer):
+    def __init__(self, input_size, random_handle=None, use_recompute=False):
+        super().__init__()
+        self.use_recompute = use_recompute
+        self.runfunc0 = get_fc_block(0, input_size, is_last=False)
+        self.random0 = RandomLayer(random_handle)
+        self.runfunc1 = get_fc_block(1, input_size, is_last=False)
+        self.random1 = RandomLayer(random_handle)
+        self.runfunc2 = get_fc_block(2, input_size, is_last=False)
+        self.random2 = RandomLayer(random_handle)
+        self.runfunc3 = get_fc_block(3, input_size, is_last=False)
+        self.random3 = RandomLayer(random_handle)
+        self.runfunc4 = get_fc_block(4, input_size, is_last=True)
+
+    def forward_impl(self, input):
+        o = self.runfunc0(input)
+        o = self.random0(o)
+        o = self.runfunc1(o)
+        o = self.random1(o)
+        o = self.runfunc2(o)
+        o = self.random2(o)
+        o = self.runfunc3(o)
+        o = self.random3(o)
+        o = self.runfunc4(o)
+        return o
+
+    def forward(self, input):
+        if self.use_recompute:
+            return recompute(self.forward_impl, input)
+        else:
+            return self.forward_impl(input)
+
+
+class TesRandomStatesInRecompute(unittest.TestCase):
+    def run_model(self, batch_size, input_size, random_handle, use_recompute):
+        gen = paddle.seed(10)
+        gen.manual_seed(10)
+        np.random.seed(10)
+        random.seed(10)
+        model = Random_fc_net(
+            input_size=input_size,
+            random_handle=random_handle,
+            use_recompute=use_recompute,
+        )
+        loss_fn = paddle.nn.MSELoss(reduction='mean')
+        optimizer = paddle.optimizer.SGD(
+            learning_rate=0.01, parameters=model.parameters()
+        )
+        loss_ = []
+        for step in range(10):
+            x_data = np.random.randn(batch_size, input_size).astype(np.float32)
+            x = paddle.to_tensor(x_data)
+            x.stop_gradient = False
+            y_pred = model(x)
+            loss = y_pred.mean()
+            loss_.append(np.asarray(loss).tolist())
+            loss.backward()
+            optimizer.step()
+            optimizer.clear_grad()
+        return loss_
+
+    def test_custom_random_manager(self):
+        input_size = 1024
+        batch_size = 2
+
+        random_handle = RandomManager()
+
+        random_handle.global_random.seed(42)
+        loss_recompute = self.run_model(
+            batch_size, input_size, random_handle, use_recompute=True
+        )
+
+        random_handle.global_random.seed(42)
+        loss_ref = self.run_model(
+            batch_size, input_size, random_handle, use_recompute=False
+        )
+
+        for loss1, loss2 in zip(loss_recompute, loss_ref):
+            self.assertEqual(loss1, loss2)
+
+    def test_normal_random(self):
+        input_size = 1024
+        batch_size = 2
+
+        loss_recompute = self.run_model(
+            batch_size, input_size, random_handle=None, use_recompute=True
+        )
+        loss_ref = self.run_model(
+            batch_size, input_size, random_handle=None, use_recompute=False
+        )
+
+        for loss1, loss2 in zip(loss_recompute, loss_ref):
+            self.assertEqual(loss1, loss2)
+
+
 if __name__ == '__main__':
     unittest.main()
