[Distributed] Support custom state setter and getter for recompute to save custom states #72670

Merged
6 changes: 5 additions & 1 deletion python/paddle/distributed/fleet/recompute/__init__.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .recompute import recompute, recompute_sequential # noqa: F401
from .recompute import ( # noqa: F401
custom_state_manager,
recompute,
recompute_sequential,
)
from .recompute_hybrid import recompute_hybrid # noqa: F401

__all__ = []
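With custom_state_manager exported here, user code can register a getter/setter pair for its own state before training. A minimal registration sketch (user-side code, not part of this diff; it mirrors the RandomManager pattern in the new unit test at the end of this PR):

import random

from paddle.distributed.fleet.recompute import custom_state_manager

# A user-managed RNG whose state should be captured by recompute in forward
# and restored before recomputation in backward.
user_rng = random.Random(2024)

custom_state_manager.set_custom_get_state_func(user_rng.getstate)
custom_state_manager.set_custom_set_state_func(user_rng.setstate)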
95 changes: 90 additions & 5 deletions python/paddle/distributed/fleet/recompute/recompute.py
@@ -115,8 +115,42 @@ def check_recompute_necessary(inputs):
)


class CustomStatesManager:
"""Registry for a user-provided getter/setter pair that lets recompute
save custom (non-framework) states in the forward pass and restore them
before recomputation in the backward pass."""

def __init__(self):
self.custom_get_state_func = None
self.custom_set_state_func = None

def set_custom_get_state_func(self, custom_get_state_func):
assert_msg = (
"The custom_state_manager does not support duplicate settings."
)
assert self.custom_get_state_func is None, assert_msg
self.custom_get_state_func = custom_get_state_func

def set_custom_set_state_func(self, custom_set_state_func):
assert_msg = (
"The custom_state_manager does not support duplicate settings."
)
assert self.custom_set_state_func is None, assert_msg
self.custom_set_state_func = custom_set_state_func


custom_state_manager = CustomStatesManager()
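Registration is one-shot: the assertions above reject a second call, and no reset method is added in this diff. A tiny illustration on a throwaway manager instance (not the global custom_state_manager created above):

manager = CustomStatesManager()
manager.set_custom_get_state_func(lambda: None)
try:
    # A second registration trips the assert above.
    manager.set_custom_get_state_func(lambda: None)
except AssertionError as err:
    print(err)  # The custom_state_manager does not support duplicate settings.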


@contextlib.contextmanager
def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
def switch_rng_state_tracker(
rng_state,
tracker,
numpy_state,
random_state,
custom_state=None,
custom_get_state_func=None,
custom_set_state_func=None,
):
orig_rng_state = paddle.get_rng_state()
orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
paddle.set_rng_state(rng_state)
@@ -126,6 +160,12 @@ def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
orig_random_state = random.getstate()
np.random.set_state(numpy_state)
random.setstate(random_state)

if custom_state is not None:
assert custom_get_state_func is not None
assert custom_set_state_func is not None
orig_custom_state = custom_get_state_func()
custom_set_state_func(custom_state)
try:
yield
finally:
@@ -134,11 +174,21 @@ def switch_rng_state_tracker(rng_state, tracker, numpy_state, random_state):
np.random.set_state(orig_numpy_state)
random.setstate(orig_random_state)

if custom_state is not None:
custom_set_state_func(orig_custom_state)
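The save/switch/restore flow that switch_rng_state_tracker now applies to the custom state, shown as a standalone sketch (illustrative only, not the Paddle implementation):

import contextlib

@contextlib.contextmanager
def switch_custom_state(custom_state, get_state, set_state):
    original = get_state()      # snapshot whatever custom state is live now
    set_state(custom_state)     # switch to the state captured at forward time
    try:
        yield
    finally:
        set_state(original)     # always restore the original state afterwards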


class RecomputeFunction(PyLayer):
@staticmethod
def forward(
ctx, run_function, preserve_rng_state, offload_indices, *args, **kwargs
ctx,
run_function,
preserve_rng_state,
offload_indices,
custom_get_state_func,
custom_set_state_func,
*args,
**kwargs,
):
# store for recomputing
ctx.run_function = run_function
@@ -159,6 +209,9 @@ def forward(
)
ctx.fwd_numpy_state = np.random.get_state()
ctx.fwd_random_state = random.getstate()
ctx.fwd_custom_state = custom_get_state_func()
ctx.custom_get_state_func = custom_get_state_func
ctx.custom_set_state_func = custom_set_state_func

# TODO support AMP
tracer = framework._dygraph_tracer()
@@ -268,6 +321,9 @@ def backward(ctx, *args):
ctx.fwd_rng_state_tracker,
ctx.fwd_numpy_state,
ctx.fwd_random_state,
ctx.fwd_custom_state,
ctx.custom_get_state_func,
ctx.custom_set_state_func,
):
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
@@ -342,7 +398,12 @@ def backward(ctx, *args):


def _recompute_without_reentrant(
function, preserve_rng_state=True, *args, **kwargs
function,
custom_get_state_func,
custom_set_state_func,
preserve_rng_state=True,
*args,
**kwargs,
):
"""
Recompute without re-entrancy, i.e. hooks are used to implement recomputation rather than re-entrant autograd.
@@ -370,6 +431,7 @@ def _recompute_without_reentrant(
)
fwd_numpy_state = np.random.get_state()
fwd_random_state = random.getstate()
fwd_custom_state = custom_get_state_func()

tracer = framework._dygraph_tracer()
is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
@@ -445,6 +507,9 @@ def inner_unpack(inner_x):
fwd_cuda_rng_state_tracker,
fwd_numpy_state,
fwd_random_state,
fwd_custom_state,
custom_get_state_func,
custom_set_state_func,
):
with paddle.set_grad_enabled(True):
with paddle.amp.auto_cast(
@@ -600,6 +665,14 @@ def recompute(function, *args, **kwargs):
# whether to use reentrant method to implement recompute
use_reentrant = kwargs.pop('use_reentrant', True)

if custom_state_manager.custom_get_state_func is None:
assert custom_state_manager.custom_set_state_func is None
custom_get_state_func = lambda x=None: None
custom_set_state_func = lambda x=None: None
else:
custom_get_state_func = custom_state_manager.custom_get_state_func
custom_set_state_func = custom_state_manager.custom_set_state_func

if not in_dynamic_mode():
from paddle.distributed.auto_parallel.interface import (
recompute as static_auto_recompute,
@@ -644,10 +717,22 @@ def recompute(function, *args, **kwargs):
raise ValueError("Unknown parameter kind.")

return RecomputeFunction.apply(
function, preserve, offload_indices, *input_args
function,
preserve,
offload_indices,
custom_get_state_func,
custom_set_state_func,
*input_args,
)
else:
return _recompute_without_reentrant(function, preserve, *args, **kwargs)
return _recompute_without_reentrant(
function,
custom_get_state_func,
custom_set_state_func,
preserve,
*args,
**kwargs,
)


def recompute_sequential(
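Putting the pieces together, a hedged end-to-end sketch of recompute with a registered custom state (assumes nothing has been registered yet in the process; names are illustrative and mirror the new unit test below):

import random

import paddle
from paddle.distributed.fleet.recompute import custom_state_manager, recompute

rng = random.Random(42)  # user-managed RNG consumed inside the recomputed block
custom_state_manager.set_custom_get_state_func(rng.getstate)
custom_state_manager.set_custom_set_state_func(rng.setstate)

linear = paddle.nn.Linear(8, 8)

def block(x):
    # rng is replayed from the state captured in forward when backward recomputes.
    return linear(x) * rng.random()

x = paddle.randn([2, 8])
x.stop_gradient = False
y = recompute(block, x)
y.mean().backward()  # recomputation sees the same rng.random() value as forward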
27 changes: 26 additions & 1 deletion python/paddle/distributed/fleet/recompute/recompute_hybrid.py
@@ -27,6 +27,7 @@
from ..meta_parallel.pp_utils import utils
from .recompute import (
check_recompute_necessary,
custom_state_manager,
detach_variable,
switch_rng_state_tracker,
)
@@ -101,6 +102,8 @@ def forward(
mp_group,
offload,
partition,
custom_get_state_func,
custom_set_state_func,
*args,
**kwargs,
):
@@ -114,6 +117,9 @@ def forward(
ctx.fwd_rng_state_tracker = get_rng_state_tracker().get_states_tracker()
ctx.fwd_numpy_state = np.random.get_state()
ctx.fwd_random_state = random.getstate()
ctx.fwd_custom_state = custom_get_state_func()
ctx.custom_get_state_func = custom_get_state_func
ctx.custom_set_state_func = custom_set_state_func

# save config info
ctx.mp_group = mp_group
@@ -223,6 +229,9 @@ def backward(ctx, *args):
ctx.fwd_rng_state_tracker,
ctx.fwd_numpy_state,
ctx.fwd_random_state,
ctx.fwd_custom_state,
ctx.custom_get_state_func,
ctx.custom_set_state_func,
):
if ctx.is_fw_autocast:
with paddle.amp.auto_cast(
@@ -307,9 +316,25 @@ def recompute_hybrid(
if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args)

if custom_state_manager.custom_get_state_func is None:
assert custom_state_manager.custom_set_state_func is None
custom_get_state_func = lambda x=None: None
custom_set_state_func = lambda x=None: None
else:
custom_get_state_func = custom_state_manager.custom_get_state_func
custom_set_state_func = custom_state_manager.custom_set_state_func

all_outputs = []
_HPRecomputeFunction.apply(
function, all_outputs, mp_group, offload, partition, *args, **kwargs
function,
all_outputs,
mp_group,
offload,
partition,
custom_get_state_func,
custom_set_state_func,
*args,
**kwargs,
)

if len(all_outputs) == 1:
125 changes: 125 additions & 0 deletions test/collective/fleet/test_dygraph_recompute_for_eager.py
@@ -413,5 +413,130 @@ def test_recompute_inputs_with_tuple(self):
self.assertEqual(grad_ref, grad)


class RandomManager:
def __init__(self):
self.global_random = random.Random(11)

custom_state_manager.set_custom_get_state_func(self.get_states)
custom_state_manager.set_custom_set_state_func(self.set_states)

def get_states(self):
return self.global_random.getstate()

def set_states(self, packed_states):
self.global_random.setstate(packed_states)


class RandomLayer(paddle.nn.Layer):
def __init__(self, random_handle=None):
super().__init__()
self.random_handle = random_handle

def forward(self, input):
if self.random_handle is not None:
random_val = self.random_handle.global_random.random()
else:
random_val = random.random()
return random_val * input


from paddle.distributed.fleet.recompute import custom_state_manager


class Random_fc_net(paddle.nn.Layer):
def __init__(self, input_size, random_handle=None, use_recompute=False):
super().__init__()
self.use_recompute = use_recompute
self.runfunc0 = get_fc_block(0, input_size, is_last=False)
self.random0 = RandomLayer(random_handle)
self.runfunc1 = get_fc_block(1, input_size, is_last=False)
self.random1 = RandomLayer(random_handle)
self.runfunc2 = get_fc_block(2, input_size, is_last=False)
self.random2 = RandomLayer(random_handle)
self.runfunc3 = get_fc_block(3, input_size, is_last=False)
self.random3 = RandomLayer(random_handle)
self.runfunc4 = get_fc_block(4, input_size, is_last=True)

def forward_impl(self, input):
o = self.runfunc0(input)
o = self.random0(o)
o = self.runfunc1(o)
o = self.random1(o)
o = self.runfunc2(o)
o = self.random2(o)
o = self.runfunc3(o)
o = self.random3(o)
o = self.runfunc4(o)
return o

def forward(self, input):
if self.use_recompute:
return recompute(self.forward_impl, input)
else:
return self.forward_impl(input)


class TestRandomStatesInRecompute(unittest.TestCase):
def run_model(self, batch_size, input_size, random_handle, use_recompute):
gen = paddle.seed(10)
gen.manual_seed(10)
np.random.seed(10)
random.seed(10)
model = Random_fc_net(
input_size=input_size,
random_handle=random_handle,
use_recompute=use_recompute,
)
loss_fn = paddle.nn.MSELoss(reduction='mean')
optimizer = paddle.optimizer.SGD(
learning_rate=0.01, parameters=model.parameters()
)
loss_ = []
for step in range(10):
x_data = np.random.randn(batch_size, input_size).astype(np.float32)
x = paddle.to_tensor(x_data)
x.stop_gradient = False
y_pred = model(x)
loss = y_pred.mean()
loss_.append(np.asarray(loss).tolist())
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss_

def test_custom_random_manager(self):
input_size = 1024
batch_size = 2

random_handle = RandomManager()

random_handle.global_random.seed(42)
loss_recompute = self.run_model(
batch_size, input_size, random_handle, use_recompute=True
)

random_handle.global_random.seed(42)
loss_ref = self.run_model(
batch_size, input_size, random_handle, use_recompute=False
)

for loss1, loss2 in zip(loss_recompute, loss_ref):
self.assertEqual(loss1, loss2)

def test_normal_random(self):
input_size = 1024
batch_size = 2

loss_recompute = self.run_model(
batch_size, input_size, random_handle=None, use_recompute=True
)
loss_ref = self.run_model(
batch_size, input_size, random_handle=None, use_recompute=False
)

for loss1, loss2 in zip(loss_recompute, loss_ref):
self.assertEqual(loss1, loss2)


if __name__ == '__main__':
unittest.main()