Skip to content

Commit d27b654

Browse files
committed
add more docstrings + experimental marks
1 parent cb9dca5 commit d27b654

File tree

5 files changed

+416
-160
lines changed

5 files changed

+416
-160
lines changed

src/diffusers/modular_pipelines/components_manager.py

Lines changed: 175 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def custom_offload_with_hook(
142142
user_hook.attach()
143143
return user_hook
144144

145-
145+
# This is the class that users can customize to implement their own offload strategy
146146
class AutoOffloadStrategy:
147147
"""
148148
Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
@@ -213,7 +213,101 @@ def search_best_candidate(module_sizes, min_memory_offload):
213213
return hooks_to_offload
214214

215215

216+
# Utilities for displaying component info in a readable format
217+
# TODO: move to a different file
218+
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize a dictionary by collapsing keys that share the same value.

    Keys are treated as dot-separated paths. All keys that map to an equal value are grouped together and
    replaced by the longest dot-separated prefix they have in common. A value that is held by a single key keeps
    that key unchanged.

    For a dictionary like: {
        'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
        'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
        'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
    }

    the result is: {
        'down_blocks.1.attentions.1.transformer_blocks': [0.6],
        'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
    }

    Args:
        d (Dict[str, Any]): Mapping from dot-separated keys to values. Values must be hashable, or lists of
            hashable items (lists are converted to tuples for grouping).

    Returns:
        Dict[str, Any]: Mapping from common prefix to the shared value. Groups whose keys have no common leading
        part are stored under the empty-string key `""`. If several groups resolve to the same prefix, the
        group encountered last wins.
    """
    # Group keys by value; lists are converted to tuples so they can be used as dictionary keys.
    value_to_keys: Dict[Any, List[str]] = {}
    for key, value in d.items():
        hashable_value = tuple(value) if isinstance(value, list) else value
        value_to_keys.setdefault(hashable_value, []).append(key)

    def find_common_prefix(keys: List[str]) -> str:
        """Return the longest dot-separated prefix shared by all `keys` ("" if there is none)."""
        if not keys:
            return ""
        if len(keys) == 1:
            return keys[0]

        key_parts = [k.split(".") for k in keys]

        # Walk the parts position by position and stop at the first position where the keys disagree.
        common_length = 0
        for parts in zip(*key_parts):
            if len(set(parts)) != 1:
                break
            common_length += 1

        return ".".join(key_parts[0][:common_length])

    summary: Dict[str, Any] = {}
    for hashable_value, keys in value_to_keys.items():
        # Restore the original container type. This is computed for every group, before branching on the
        # prefix, so the no-common-prefix case stores this group's value and not a stale one.
        group_value = list(hashable_value) if isinstance(d[keys[0]], list) else hashable_value
        # `find_common_prefix` returns "" when nothing is shared, which doubles as the fallback key.
        summary[find_common_prefix(keys)] = group_value

    return summary
275+
276+
216277
class ComponentsManager:
278+
"""
279+
A central registry and management system for model components across multiple pipelines.
280+
281+
[`ComponentsManager`] provides a unified way to register, track, and reuse model components
282+
(like UNet, VAE, text encoders, etc.) across different modular pipelines. It includes
283+
features for duplicate detection, memory management, and component organization.
284+
285+
<Tip warning={true}>
286+
287+
This is an experimental feature and is likely to change in the future.
288+
289+
</Tip>
290+
291+
Example:
292+
```python
293+
from diffusers import ComponentsManager
294+
295+
# Create a components manager
296+
cm = ComponentsManager()
297+
298+
# Add components
299+
cm.add("unet", unet_model, collection="sdxl")
300+
cm.add("vae", vae_model, collection="sdxl")
301+
302+
# Enable auto offloading
303+
cm.enable_auto_cpu_offload(device="cuda")
304+
305+
# Retrieve components
306+
unet = cm.get_one(name="unet", collection="sdxl")
307+
```
308+
"""
309+
310+
217311
_available_info_fields = [
218312
"model_id",
219313
"added_time",
@@ -278,7 +372,19 @@ def _lookup_ids(
278372
def _id_to_name(component_id: str):
279373
return "_".join(component_id.split("_")[:-1])
280374

281-
def add(self, name, component, collection: Optional[str] = None):
375+
def add(self, name: str, component: Any, collection: Optional[str] = None):
376+
"""
377+
Add a component to the ComponentsManager.
378+
379+
Args:
380+
name (str): The name of the component
381+
component (Any): The component to add
382+
collection (Optional[str]): The collection to add the component to
383+
384+
Returns:
385+
str: The unique component ID, which is generated as "{name}_{id(component)}" where
386+
id(component) is Python's built-in unique identifier for the object
387+
"""
282388
component_id = f"{name}_{id(component)}"
283389

284390
# check for duplicated components
@@ -334,6 +440,12 @@ def add(self, name, component, collection: Optional[str] = None):
334440
return component_id
335441

336442
def remove(self, component_id: str = None):
443+
"""
444+
Remove a component from the ComponentsManager.
445+
446+
Args:
447+
component_id (str): The ID of the component to remove
448+
"""
337449
if component_id not in self.components:
338450
logger.warning(f"Component '{component_id}' not found in ComponentsManager")
339451
return
@@ -545,6 +657,22 @@ def matches_pattern(component_id, pattern, exact_match=False):
545657
return get_return_dict(matches, return_dict_with_names)
546658

547659
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
660+
"""
661+
Enable automatic CPU offloading for all components.
662+
663+
The algorithm works as follows:
664+
1. All models start on CPU by default
665+
2. When a model's forward pass is called, it's moved to the execution device
666+
3. If there's insufficient memory, other models on the device are moved back to CPU
667+
4. The system tries to offload the smallest combination of models that frees enough memory
668+
5. Models stay on the execution device until another model needs memory and forces them off
669+
670+
Args:
671+
device (Union[str, int, torch.device]): The execution device where models are moved for forward passes
672+
memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
673+
memory to keep free on the device to avoid running out of memory during
674+
model execution (e.g., for intermediate activations, gradients, etc.)
675+
"""
548676
if not is_accelerate_available():
549677
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
550678

@@ -574,6 +702,9 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda"
574702
self._auto_offload_device = device
575703

576704
def disable_auto_cpu_offload(self):
705+
"""
706+
Disable automatic CPU offloading for all components.
707+
"""
577708
if self.model_hooks is None:
578709
self._auto_offload_enabled = False
579710
return
@@ -595,13 +726,12 @@ def get_model_info(
595726
"""Get comprehensive information about a component.
596727
597728
Args:
598-
component_id: Name of the component to get info for
599-
fields: Optional field(s) to return. Can be a string for single field or list of fields.
729+
component_id (str): Name of the component to get info for
730+
fields (Optional[Union[str, List[str]]]): Field(s) to return. Can be a string for single field or list of fields.
600731
If None, uses the available_info_fields setting.
601732
602733
Returns:
603-
Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a
604-
single field is requested as string, returns just that field's value.
734+
Dictionary containing requested component metadata. If fields is specified, returns only those fields. Otherwise, returns all fields.
605735
"""
606736
if component_id not in self.components:
607737
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
@@ -808,15 +938,16 @@ def get_one(
808938
load_id: Optional[str] = None,
809939
) -> Any:
810940
"""
811-
Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in
812-
a component_id Raises an error if multiple components match or none are found. support pattern matching for
813-
name
941+
Get a single component by either:
942+
- searching name (pattern matching), collection, or load_id.
943+
- passing in a component_id
944+
Raises an error if multiple components match or none are found.
814945
815946
Args:
816-
component_id: Optional component ID to get
817-
name: Component name or pattern
818-
collection: Optional collection to filter by
819-
load_id: Optional load_id to filter by
947+
component_id (Optional[str]): Optional component ID to get
948+
name (Optional[str]): Component name or pattern
949+
collection (Optional[str]): Optional collection to filter by
950+
load_id (Optional[str]): Optional load_id to filter by
820951
821952
Returns:
822953
A single component
@@ -847,6 +978,13 @@ def get_one(
847978
def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
848979
"""
849980
Get component IDs by a list of names, optionally filtered by collection.
981+
982+
Args:
983+
names (Union[str, List[str]]): List of component names
984+
collection (Optional[str]): Optional collection to filter by
985+
986+
Returns:
987+
List[str]: List of component IDs
850988
"""
851989
ids = set()
852990
if not isinstance(names, list):
@@ -858,6 +996,20 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str]
858996
def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
859997
"""
860998
Get components by a list of IDs.
999+
1000+
Args:
1001+
ids (List[str]):
1002+
List of component IDs
1003+
return_dict_with_names (Optional[bool]):
1004+
Whether to return a dictionary with component names as keys:
1005+
1006+
Returns:
1007+
Dict[str, Any]: Dictionary of components.
1008+
- If return_dict_with_names=True, keys are component names.
1009+
- If return_dict_with_names=False, keys are component IDs.
1010+
1011+
Raises:
1012+
ValueError: If duplicate component names are found in the search results when return_dict_with_names=True
8611013
"""
8621014
components = {id: self.components[id] for id in ids}
8631015

@@ -877,65 +1029,17 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional
8771029
def get_components_by_names(self, names: List[str], collection: Optional[str] = None):
8781030
"""
8791031
Get components by a list of names, optionally filtered by collection.
880-
"""
881-
ids = self.get_ids(names, collection)
882-
return self.get_components_by_ids(ids)
883-
884-
885-
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
886-
"""Summarizes a dictionary by finding common prefixes that share the same value.
887-
888-
For a dictionary with dot-separated keys like: {
889-
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
890-
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
891-
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
892-
}
893-
894-
Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
895-
'down_blocks': [0.6], 'up_blocks': [0.3]
896-
}
897-
"""
898-
# First group by values - convert lists to tuples to make them hashable
899-
value_to_keys = {}
900-
for key, value in d.items():
901-
value_tuple = tuple(value) if isinstance(value, list) else value
902-
if value_tuple not in value_to_keys:
903-
value_to_keys[value_tuple] = []
904-
value_to_keys[value_tuple].append(key)
9051032
906-
def find_common_prefix(keys: List[str]) -> str:
907-
"""Find the shortest common prefix among a list of dot-separated keys."""
908-
if not keys:
909-
return ""
910-
if len(keys) == 1:
911-
return keys[0]
912-
913-
# Split all keys into parts
914-
key_parts = [k.split(".") for k in keys]
915-
916-
# Find how many initial parts are common
917-
common_length = 0
918-
for parts in zip(*key_parts):
919-
if len(set(parts)) == 1: # All parts at this position are the same
920-
common_length += 1
921-
else:
922-
break
923-
924-
if common_length == 0:
925-
return ""
1033+
Args:
1034+
names (List[str]): List of component names
1035+
collection (Optional[str]): Optional collection to filter by
9261036
927-
# Return the common prefix
928-
return ".".join(key_parts[0][:common_length])
1037+
Returns:
1038+
Dict[str, Any]: Dictionary of components with component names as keys
9291039
930-
# Create summary by finding common prefixes for each value group
931-
summary = {}
932-
for value_tuple, keys in value_to_keys.items():
933-
prefix = find_common_prefix(keys)
934-
if prefix: # Only add if we found a common prefix
935-
# Convert tuple back to list if it was originally a list
936-
value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
937-
summary[prefix] = value
938-
else:
939-
summary[""] = value # Use empty string if no common prefix
1040+
Raises:
1041+
ValueError: If duplicate component names are found in the search results
1042+
"""
1043+
ids = self.get_ids(names, collection)
1044+
return self.get_components_by_ids(ids)
9401045

941-
return summary

0 commit comments

Comments
 (0)