Fix wrong behavior of DDPStrategy option with simple GAN training using DDP #20936
base: master
Conversation
def block(in_feat, out_feat, normalize=True):
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat, 0.8))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers
Let's move it out as a function.
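For illustration, a minimal sketch of what that could look like, with the block helper lifted to module level and shared by the models (the Generator layout below is only an assumption about the example script, not its actual code):

import torch.nn as nn


def block(in_feat, out_feat, normalize=True):
    # Module-level helper: Linear -> optional BatchNorm -> LeakyReLU
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat, 0.8))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers


class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        # Reuse the shared helper instead of redefining it inside __init__
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
        )

    def forward(self, z):
        return self.model(z)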
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://lightning.ai/docs/pytorch/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Discord. Thank you for your contributions.
…dp-implementation test: cover MultiModelDDPStrategy
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
return opt_g, opt_d

# ! TESTING
What does this comment mean?
@SkafteNicki could you pls check too :)
@@ -419,6 +419,39 @@ def teardown(self) -> None:
        super().teardown()


class MultiModelDDPStrategy(DDPStrategy):
    @override
    def _setup_model(self, model: Module) -> Module:
The typing is not very happy here, as the parent class has the following signature:
def _setup_model(self, model: Module) -> DistributedDataParallel:
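For context, a minimal stand-alone reproduction of why the type checker complains (the class names below are placeholders, not the real strategy classes): an override may only narrow the return type, and Module is broader than DistributedDataParallel.

from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel


class _Parent:
    # Simplified stand-in for DDPStrategy's current annotation.
    def _setup_model(self, model: Module) -> DistributedDataParallel: ...


class _Child(_Parent):
    # Return types of overrides must be covariant (the same or narrower),
    # so widening the return type to Module is what the checker rejects here.
    def _setup_model(self, model: Module) -> Module: ...  # type: ignore[override]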
dm = MNISTDataModule()
trainer = Trainer(
    accelerator="auto",
    devices=[0, 1, 2, 3],
    strategy=MultiModelDDPStrategy(),
    max_epochs=100,
)

trainer.fit(model, dm)
Maybe I am reading it wrong, but is this the only core difference between this new _ddp script and the single-GPU example script?
In that case, could we please just incorporate the changes into the other script and add whatever arguments are needed to the ArgumentParser to distinguish between the two cases?
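A minimal sketch of what that merge could look like (the flag name, the defaults, and the import path for MultiModelDDPStrategy are assumptions, not part of this PR):

from argparse import ArgumentParser

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import MultiModelDDPStrategy  # assumed export location

parser = ArgumentParser()
parser.add_argument("--multi_model_ddp", action="store_true",
                    help="Wrap each sub-model in its own DDP instance")
parser.add_argument("--devices", type=int, default=1)
args = parser.parse_args()

trainer = Trainer(
    accelerator="auto",
    devices=args.devices,
    # Fall back to Lightning's default strategy selection for the single-GPU case.
    strategy=MultiModelDDPStrategy() if args.multi_model_ddp else "auto",
    max_epochs=100,
)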
@@ -419,6 +419,39 @@ def teardown(self) -> None:
        super().teardown()


class MultiModelDDPStrategy(DDPStrategy):
Please add a docstring to the class: what is the purpose of the class, and when should it be used compared to the standard DDPStrategy?
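A possible starting point for such a docstring, worded from the PR description below (only a suggestion):

from lightning.pytorch.strategies import DDPStrategy


class MultiModelDDPStrategy(DDPStrategy):
    """DDP strategy for LightningModules that contain several independent sub-models.

    Unlike the standard ``DDPStrategy``, which wraps the whole LightningModule in a
    single ``DistributedDataParallel`` instance, this strategy wraps each child
    ``nn.Module`` (e.g. a GAN's generator and discriminator) separately. Use it when
    sub-models are trained with separate optimizers and only a subset of parameters
    participates in each backward pass, as in typical GAN training.
    """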
with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook:
    strategy._register_ddp_hooks()

register_hook.assert_not_called()
Could we add one or two tests, similar to
https://github.com/samsara-ku/pytorch-lightning/blob/ece7d38abb908e96eb72af363606d2585e8843f2/tests/tests_pytorch/strategies/test_ddp_integration.py#L40-L66
that actually check that the new strategy works with trainer.fit as expected?
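A rough sketch of what such a test could look like (the BoringGAN module, the import path for MultiModelDDPStrategy, and the device setup are assumptions modeled on the linked test file, not code from this PR; a real test would also need the usual multi-process markers, e.g. RunIf(standalone=True)):

import torch
from torch import nn
from torch.utils.data import DataLoader

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.pytorch.strategies import MultiModelDDPStrategy  # assumed export location


class BoringGAN(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization, as in the GAN example
        self.generator = nn.Linear(32, 32)
        self.discriminator = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt_g, opt_d = self.optimizers()
        # Generator step: gradients should only touch generator parameters.
        g_loss = self.discriminator(self.generator(batch)).mean()
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()
        # Discriminator step: gradients should only touch discriminator parameters.
        d_loss = self.discriminator(batch).mean()
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

    def configure_optimizers(self):
        opt_g = torch.optim.SGD(self.generator.parameters(), lr=0.1)
        opt_d = torch.optim.SGD(self.discriminator.parameters(), lr=0.1)
        return opt_g, opt_d

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


def test_multi_model_ddp_strategy_with_trainer_fit(tmp_path):
    model = BoringGAN()
    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="cpu",
        devices=2,
        strategy=MultiModelDDPStrategy(),
        fast_dev_run=True,
    )
    trainer.fit(model)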
What does this PR do?
Fixes #20866 #20328 #18740 #17212
This PR adds a MultiModelDDPStrategy class and a simple example script for multi-GPU GAN training.
Simply speaking: PyTorch Lightning's simple GAN training currently has a problem with the DistributedDataParallel strategy. It tries to wrap the pl.trainer itself, not the individual nn.Module models inside the pl.trainer.
Although we can enable the find_unused_parameters=True option to avoid this issue, that is not the right way; I think it is just a trick.
So the key idea to solve this issue is to assign DistributedDataParallel to each model inside the pl.trainer, unlike the previous DDPStrategy.
I have already tested this on my GPUs, visualizing the results and tracking the gradients of the models at each epoch; it works, and you can see the visualized results in this Google Drive link.
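A condensed sketch of that key idea (illustrative of the description above, not the exact implementation in this PR; real code would also forward the usual DDP kwargs and device ids):

from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel

from lightning.pytorch.strategies import DDPStrategy


class MultiModelDDPStrategy(DDPStrategy):
    def _setup_model(self, model: Module) -> Module:
        # Wrap each direct child (e.g. generator and discriminator) in its own
        # DDP instance instead of wrapping the whole LightningModule, so each
        # optimizer step reduces gradients over exactly the parameters it owns,
        # without resorting to find_unused_parameters=True.
        for name, child in model.named_children():
            setattr(model, name, DistributedDataParallel(child))
        return model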
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
📚 Documentation preview 📚: https://pytorch-lightning--20936.org.readthedocs.build/en/20936/