diff --git a/docs/source-pytorch/advanced/model_parallel/deepspeed.rst b/docs/source-pytorch/advanced/model_parallel/deepspeed.rst index 3a3846500ff35..dbec772ac0d40 100644 --- a/docs/source-pytorch/advanced/model_parallel/deepspeed.rst +++ b/docs/source-pytorch/advanced/model_parallel/deepspeed.rst @@ -408,6 +408,7 @@ Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lig * Treat your GPU/CPU memory as one large pool. In some cases, you may not want to offload certain things (like activations) to provide even more space to offload model parameters * When offloading to the CPU, make sure to bump up the batch size as GPU memory will be freed * We also support sharded checkpointing. By passing ``save_full_weights=False`` to the ``DeepSpeedStrategy``, we'll save shards of the model which allows you to save extremely large models. However to load the model and run test/validation/predict you must use the Trainer object. +* DeepSpeed provides `MiCS support <https://deepspeed.readthedocs.io/en/latest/zero3.html#mics-configurations>`_ which allows you to control how model parameters are sharded across GPUs. For example, with 16 GPUs, ZeRO-3 will shard the model into 16 pieces by default. Instead, with ``mics_shard_size=8``, every 8 GPUs will keep a full copy of the model weights, reducing the communication overhead. You can set ``"zero_optimization": {"stage": 3, "mics_shard_size": (number of shards), ...}`` in a DeepSpeed config file to take advantage of this feature. .. 
_deepspeed-zero-stage-3-single-file: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 48333455240cf..c4fd58dde2c98 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -376,6 +376,19 @@ def module_sharded_context(self) -> AbstractContextManager: import deepspeed assert self._config_initialized + assert self.config is not None + + if ( + "zero_optimization" in self.config + and "mics_shard_size" in self.config["zero_optimization"] + and self.config["zero_optimization"]["mics_shard_size"] > 0 + and self.zero_stage_3 + ): + return deepspeed.zero.MiCS_Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ) return deepspeed.zero.Init( enabled=self.zero_stage_3, remote_device=self.remote_device, diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index dabfde70242b9..000165bb3a401 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -526,12 +526,29 @@ def model_sharded_context(self) -> Generator[None, None, None]: import deepspeed self._init_config_if_needed() - with deepspeed.zero.Init( - enabled=self.zero_stage_3, - remote_device=self.remote_device, - config_dict_or_path=self.config, + assert self.config is not None + # If we detect `'mics_shard_size' > 0` in `config['zero_optimization']`, use `deepspeed.zero.MiCS_Init(...)` instead of `deepspeed.zero.Init(...)` + # https://deepspeed.readthedocs.io/en/latest/zero3.html#mics-configurations + #! 
default deepspeed 0.9.0 is not compatible + if ( + "zero_optimization" in self.config + and "mics_shard_size" in self.config["zero_optimization"] + and self.config["zero_optimization"]["mics_shard_size"] > 0 + and self.zero_stage_3 ): - yield + with deepspeed.zero.MiCS_Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ): + yield + else: + with deepspeed.zero.Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ): + yield def _set_deepspeed_activation_checkpointing(self) -> None: import deepspeed diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 5970b673cee5f..e5b70d5511cef 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -414,3 +414,150 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init): zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY) assert init_mock.call_count == int(not empty_init) assert model.layer.weight.dtype == torch.bfloat16 + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_support(): + """Test to ensure ZeRO Stage 3 MiCS works with a parallel model.""" + strategy = DeepSpeedStrategy(stage=3) + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + + fabric = Fabric( + strategy=strategy, + accelerator="cuda", + devices=2, + precision="16-mixed", + ) + fabric.launch() + + def _make_block(): + return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) + + with fabric.init_module(): + model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3)) + + optimizer = 
torch.optim.Adam(model.parameters(), lr=0.1) + model, optimizer = fabric.setup(model, optimizer) + + x = torch.rand(2, 32, device=fabric.device) + y = torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = model(x) + x = x.float() # Ensure output is in float32 for softmax operation + logits = F.softmax(x, dim=1) + loss = F.cross_entropy(logits, y) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(): + """Test to ensure we can use DeepSpeed with ZeRO Stage param offload 3 MiCS Support.""" + strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu") + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + + fabric = Fabric( + strategy=strategy, + accelerator="cuda", + devices=2, + precision="16-mixed", + ) + fabric.launch() + + def _make_block(): + return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) + + with fabric.init_module(): + model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + model, optimizer = fabric.setup(model, optimizer) + + x = torch.rand(2, 32, device=fabric.device) + y = torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = model(x) + x = x.float() # Ensure output is in float32 for softmax operation + logits = F.softmax(x, dim=1) + loss = F.cross_entropy(logits, y) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(): + """Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support.""" + strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", 
offload_optimizer_device="cpu") + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + + fabric = Fabric( + strategy=strategy, + accelerator="cuda", + devices=2, + precision="16-mixed", + ) + fabric.launch() + + def _make_block(): + return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) + + with fabric.init_module(): + model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + model, optimizer = fabric.setup(model, optimizer) + + x = torch.rand(2, 32, device=fabric.device) + y = torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = model(x) + x = x.float() # Ensure output is in float32 for softmax operation + logits = F.softmax(x, dim=1) + loss = F.cross_entropy(logits, y) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + + +@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(): + """Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' = + True).""" + strategy = DeepSpeedStrategy(stage=3) + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 2 + strategy.config["zero_optimization"]["offload_param"] = {} + strategy.config["zero_optimization"]["offload_optimizer"] = {} + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True + + fabric = Fabric( + strategy=strategy, + accelerator="cuda", + devices=2, + precision="16-mixed", + ) + fabric.launch() + + def _make_block(): + return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) + + with fabric.init_module(): + model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + model, optimizer 
= fabric.setup(model, optimizer) + + x = torch.rand(2, 32, device=fabric.device) + y = torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = model(x) + x = x.float() # Ensure output is in float32 for softmax operation + logits = F.softmax(x, dim=1) + loss = F.cross_entropy(logits, y) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 7e7d2eacd0617..705c2b673d00d 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -1279,3 +1279,126 @@ def test_deepspeed_load_checkpoint_validate_path(tmp_path): checkpoint_path.touch() with pytest.raises(FileNotFoundError, match=f"Try to load using this parent directory instead: {tmp_path}"): strategy.load_checkpoint(checkpoint_path=checkpoint_path) + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_support(tmp_path): + """Test to ensure we can use DeepSpeed with basic ZeRO Stage 3 MiCS Support.""" + model = ModelParallelBoringModel() + strategy = DeepSpeedStrategy(stage=3) + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + + trainer = Trainer( + default_root_dir=tmp_path, + strategy=strategy, + accelerator="gpu", + devices=2, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.test(model) + trainer.fit(model) + + _assert_save_model_is_equal(model, tmp_path, trainer) + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert "zero_optimization" in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False + assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 + assert 
trainer.strategy.config["zero_optimization"]["stage"] == 3 + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(tmp_path): + """Test to ensure we can use DeepSpeed with ZeRO Stage 3 param offload MiCS support. \ + However, in some past practice, offload param + MiCS + torchrun has caused an inner exception in multi-node environments. \ + This exception is probably caused by torchrun, not DeepSpeed. """ + model = ModelParallelBoringModel() + strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu") + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + trainer = Trainer( + default_root_dir=tmp_path, + strategy=strategy, + accelerator="gpu", + devices=2, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.test(model) + trainer.fit(model) + + _assert_save_model_is_equal(model, tmp_path, trainer) + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert "zero_optimization" in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False + assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 + assert trainer.strategy.config["zero_optimization"]["stage"] == 3 + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(tmp_path): + """Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support.""" + model = ModelParallelBoringModel() + strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", offload_optimizer_device="cpu") + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + 
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + trainer = Trainer( + default_root_dir=tmp_path, + strategy=strategy, + accelerator="gpu", + devices=2, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.test(model) + trainer.fit(model) + + _assert_save_model_is_equal(model, tmp_path, trainer) + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert "zero_optimization" in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False + assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 + assert trainer.strategy.config["zero_optimization"]["stage"] == 3 + + +@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(tmp_path): + """Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' = + True).""" + model = ModelParallelBoringModel() + strategy = DeepSpeedStrategy(stage=3) + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 2 + strategy.config["zero_optimization"]["offload_param"] = {} + strategy.config["zero_optimization"]["offload_optimizer"] = {} + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True + # Forming a 2 x 2 hierarchy + trainer = Trainer( + default_root_dir=tmp_path, + strategy=strategy, + accelerator="gpu", + devices=4, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.test(model) + trainer.fit(model) + + _assert_save_model_is_equal(model, tmp_path, trainer) + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert "zero_optimization" in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is True + assert 
trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 2 + assert trainer.strategy.config["zero_optimization"]["stage"] == 3