Skip to content

Commit 96399c3

Browse files
SunMarc and sayakpaul authored
Fix sharding when no device_map is passed (#8531)
* Fix sharding when no device_map is passed
* style
* add tests
* align
* add docstring
* format

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 10d3220 commit 96399c3

File tree

3 files changed

+63
-4
lines changed

3 files changed

+63
-4
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
462462
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
463463
A map that specifies where each submodule should go. It doesn't need to be defined for each
464464
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
465-
same device.
465+
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
466466
467467
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
468468
more information about each option see [designing a device
@@ -774,7 +774,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
774774
else: # else let accelerate handle loading and dispatching.
775775
# Load weights and dispatch according to the device_map
776776
# by default the device_map is None and the weights are loaded on the CPU
777+
force_hook = True
777778
device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
779+
if device_map is None and is_sharded:
780+
# we load the parameters on the cpu
781+
device_map = {"": "cpu"}
782+
force_hook = False
778783
try:
779784
accelerate.load_checkpoint_and_dispatch(
780785
model,
@@ -784,7 +789,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
784789
offload_folder=offload_folder,
785790
offload_state_dict=offload_state_dict,
786791
dtype=torch_dtype,
787-
force_hooks=True,
792+
force_hooks=force_hook,
788793
strict=True,
789794
)
790795
except AttributeError as e:
@@ -808,12 +813,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
808813
model._temp_convert_self_to_deprecated_attention_blocks()
809814
accelerate.load_checkpoint_and_dispatch(
810815
model,
811-
model_file,
816+
model_file if not is_sharded else sharded_ckpt_cached_folder,
812817
device_map,
813818
max_memory=max_memory,
814819
offload_folder=offload_folder,
815820
offload_state_dict=offload_state_dict,
816821
dtype=torch_dtype,
822+
force_hook=force_hook,
823+
strict=True,
817824
)
818825
model._undo_temp_convert_self_to_deprecated_attention_blocks()
819826
else:

tests/models/test_modeling_common.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,39 @@ def test_model_parallelism(self):
872872

873873
@require_torch_gpu
874874
def test_sharded_checkpoints(self):
875+
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
876+
model = self.model_class(**config).eval()
877+
model = model.to(torch_device)
878+
879+
torch.manual_seed(0)
880+
base_output = model(**inputs_dict)
881+
882+
model_size = compute_module_sizes(model)[""]
883+
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
884+
with tempfile.TemporaryDirectory() as tmp_dir:
885+
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
886+
self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
887+
888+
# Now check if the right number of shards exists. First, let's get the number of shards.
889+
# Since this number can be dependent on the model being tested, it's important that we calculate it
890+
# instead of hardcoding it.
891+
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
892+
weight_map_dict = json.load(f)["weight_map"]
893+
first_key = list(weight_map_dict.keys())[0]
894+
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
895+
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
896+
897+
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
898+
self.assertTrue(actual_num_shards == expected_num_shards)
899+
900+
new_model = self.model_class.from_pretrained(tmp_dir)
901+
902+
torch.manual_seed(0)
903+
new_output = new_model(**inputs_dict)
904+
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
905+
906+
@require_torch_gpu
907+
def test_sharded_checkpoints_device_map(self):
875908
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
876909
model = self.model_class(**config).eval()
877910
if model._no_split_modules is None:

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,14 +1038,33 @@ def test_ip_adapter_plus(self):
10381038
@require_torch_gpu
10391039
def test_load_sharded_checkpoint_from_hub(self):
10401040
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1041-
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto")
1041+
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy")
10421042
new_output = loaded_model(**inputs_dict)
10431043

10441044
assert loaded_model
10451045
assert new_output.sample.shape == (4, 4, 16, 16)
10461046

10471047
@require_torch_gpu
10481048
def test_load_sharded_checkpoint_from_hub_local(self):
1049+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1050+
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
1051+
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
1052+
new_output = loaded_model(**inputs_dict)
1053+
1054+
assert loaded_model
1055+
assert new_output.sample.shape == (4, 4, 16, 16)
1056+
1057+
@require_torch_gpu
1058+
def test_load_sharded_checkpoint_device_map_from_hub(self):
1059+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1060+
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto")
1061+
new_output = loaded_model(**inputs_dict)
1062+
1063+
assert loaded_model
1064+
assert new_output.sample.shape == (4, 4, 16, 16)
1065+
1066+
@require_torch_gpu
1067+
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
10491068
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
10501069
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
10511070
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")

0 commit comments

Comments (0)