Skip to content

Commit 9e9ed35

Browse files
yiyixuxusayakpaul
andauthored
fix loading sharded checkpoints from subfolder (#8798)
* fix load sharded checkpoints from subfolder{ * style * os.path.join * add a small test --------- Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
1 parent 7833ed9 commit 9e9ed35

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def _fetch_index_file(
221221
local_files_only=local_files_only,
222222
token=token,
223223
revision=revision,
224-
subfolder=subfolder,
224+
subfolder=None,
225225
user_agent=user_agent,
226226
commit_hash=commit_hash,
227227
)

src/diffusers/utils/hub_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,10 +455,13 @@ def _get_checkpoint_shard_files(
455455

456456
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
457457
allow_patterns = original_shard_filenames
458+
if subfolder is not None:
459+
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
460+
458461
ignore_patterns = ["*.json", "*.md"]
459462
if not local_files_only:
460463
# `model_info` call must guarded with the above condition.
461-
model_files_info = model_info(pretrained_model_name_or_path)
464+
model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
462465
for shard_file in original_shard_filenames:
463466
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
464467
if not shard_file_present:
@@ -481,6 +484,8 @@ def _get_checkpoint_shard_files(
481484
ignore_patterns=ignore_patterns,
482485
user_agent=user_agent,
483486
)
487+
if subfolder is not None:
488+
cached_folder = os.path.join(cached_folder, subfolder)
484489

485490
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
486491
# we don't have to catch them here. We have also dealt with EntryNotFoundError.

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,18 @@ def test_load_sharded_checkpoint_from_hub(self):
10451045
assert loaded_model
10461046
assert new_output.sample.shape == (4, 4, 16, 16)
10471047

1048+
@require_torch_gpu
1049+
def test_load_sharded_checkpoint_from_hub_subfolder(self):
1050+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1051+
loaded_model = self.model_class.from_pretrained(
1052+
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet"
1053+
)
1054+
loaded_model = loaded_model.to(torch_device)
1055+
new_output = loaded_model(**inputs_dict)
1056+
1057+
assert loaded_model
1058+
assert new_output.sample.shape == (4, 4, 16, 16)
1059+
10481060
@require_torch_gpu
10491061
def test_load_sharded_checkpoint_from_hub_local(self):
10501062
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()

0 commit comments

Comments
 (0)