@@ -142,7 +142,7 @@ def custom_offload_with_hook(
142
142
user_hook .attach ()
143
143
return user_hook
144
144
145
-
145
+ # this is the class that user can customize to implement their own offload strategy
146
146
class AutoOffloadStrategy :
147
147
"""
148
148
Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
@@ -213,7 +213,101 @@ def search_best_candidate(module_sizes, min_memory_offload):
213
213
return hooks_to_offload
214
214
215
215
216
+ # utils for display component info in a readable format
217
+ # TODO: move to a different file
218
+ def summarize_dict_by_value_and_parts (d : Dict [str , Any ]) -> Dict [str , Any ]:
219
+ """Summarizes a dictionary by finding common prefixes that share the same value.
220
+
221
+ For a dictionary with dot-separated keys like: {
222
+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
223
+ 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
224
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
225
+ }
226
+
227
+ Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
228
+ 'down_blocks': [0.6], 'up_blocks': [0.3]
229
+ }
230
+ """
231
+ # First group by values - convert lists to tuples to make them hashable
232
+ value_to_keys = {}
233
+ for key , value in d .items ():
234
+ value_tuple = tuple (value ) if isinstance (value , list ) else value
235
+ if value_tuple not in value_to_keys :
236
+ value_to_keys [value_tuple ] = []
237
+ value_to_keys [value_tuple ].append (key )
238
+
239
+ def find_common_prefix (keys : List [str ]) -> str :
240
+ """Find the shortest common prefix among a list of dot-separated keys."""
241
+ if not keys :
242
+ return ""
243
+ if len (keys ) == 1 :
244
+ return keys [0 ]
245
+
246
+ # Split all keys into parts
247
+ key_parts = [k .split ("." ) for k in keys ]
248
+
249
+ # Find how many initial parts are common
250
+ common_length = 0
251
+ for parts in zip (* key_parts ):
252
+ if len (set (parts )) == 1 : # All parts at this position are the same
253
+ common_length += 1
254
+ else :
255
+ break
256
+
257
+ if common_length == 0 :
258
+ return ""
259
+
260
+ # Return the common prefix
261
+ return "." .join (key_parts [0 ][:common_length ])
262
+
263
+ # Create summary by finding common prefixes for each value group
264
+ summary = {}
265
+ for value_tuple , keys in value_to_keys .items ():
266
+ prefix = find_common_prefix (keys )
267
+ if prefix : # Only add if we found a common prefix
268
+ # Convert tuple back to list if it was originally a list
269
+ value = list (value_tuple ) if isinstance (d [keys [0 ]], list ) else value_tuple
270
+ summary [prefix ] = value
271
+ else :
272
+ summary ["" ] = value # Use empty string if no common prefix
273
+
274
+ return summary
275
+
276
+
216
277
class ComponentsManager :
278
+ """
279
+ A central registry and management system for model components across multiple pipelines.
280
+
281
+ [`ComponentsManager`] provides a unified way to register, track, and reuse model components
282
+ (like UNet, VAE, text encoders, etc.) across different modular pipelines. It includes
283
+ features for duplicate detection, memory management, and component organization.
284
+
285
+ <Tip warning={true}>
286
+
287
+ This is an experimental feature and is likely to change in the future.
288
+
289
+ </Tip>
290
+
291
+ Example:
292
+ ```python
293
+ from diffusers import ComponentsManager
294
+
295
+ # Create a components manager
296
+ cm = ComponentsManager()
297
+
298
+ # Add components
299
+ cm.add("unet", unet_model, collection="sdxl")
300
+ cm.add("vae", vae_model, collection="sdxl")
301
+
302
+ # Enable auto offloading
303
+ cm.enable_auto_cpu_offload(device="cuda")
304
+
305
+ # Retrieve components
306
+ unet = cm.get_one(name="unet", collection="sdxl")
307
+ ```
308
+ """
309
+
310
+
217
311
_available_info_fields = [
218
312
"model_id" ,
219
313
"added_time" ,
@@ -278,7 +372,19 @@ def _lookup_ids(
278
372
def _id_to_name (component_id : str ):
279
373
return "_" .join (component_id .split ("_" )[:- 1 ])
280
374
281
- def add (self , name , component , collection : Optional [str ] = None ):
375
+ def add (self , name : str , component : Any , collection : Optional [str ] = None ):
376
+ """
377
+ Add a component to the ComponentsManager.
378
+
379
+ Args:
380
+ name (str): The name of the component
381
+ component (Any): The component to add
382
+ collection (Optional[str]): The collection to add the component to
383
+
384
+ Returns:
385
+ str: The unique component ID, which is generated as "{name}_{id(component)}" where
386
+ id(component) is Python's built-in unique identifier for the object
387
+ """
282
388
component_id = f"{ name } _{ id (component )} "
283
389
284
390
# check for duplicated components
@@ -334,6 +440,12 @@ def add(self, name, component, collection: Optional[str] = None):
334
440
return component_id
335
441
336
442
def remove (self , component_id : str = None ):
443
+ """
444
+ Remove a component from the ComponentsManager.
445
+
446
+ Args:
447
+ component_id (str): The ID of the component to remove
448
+ """
337
449
if component_id not in self .components :
338
450
logger .warning (f"Component '{ component_id } ' not found in ComponentsManager" )
339
451
return
@@ -545,6 +657,22 @@ def matches_pattern(component_id, pattern, exact_match=False):
545
657
return get_return_dict (matches , return_dict_with_names )
546
658
547
659
def enable_auto_cpu_offload (self , device : Union [str , int , torch .device ] = "cuda" , memory_reserve_margin = "3GB" ):
660
+ """
661
+ Enable automatic CPU offloading for all components.
662
+
663
+ The algorithm works as follows:
664
+ 1. All models start on CPU by default
665
+ 2. When a model's forward pass is called, it's moved to the execution device
666
+ 3. If there's insufficient memory, other models on the device are moved back to CPU
667
+ 4. The system tries to offload the smallest combination of models that frees enough memory
668
+ 5. Models stay on the execution device until another model needs memory and forces them off
669
+
670
+ Args:
671
+ device (Union[str, int, torch.device]): The execution device where models are moved for forward passes
672
+ memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
673
+ memory to keep free on the device to avoid running out of memory during
674
+ model execution (e.g., for intermediate activations, gradients, etc.)
675
+ """
548
676
if not is_accelerate_available ():
549
677
raise ImportError ("Make sure to install accelerate to use auto_cpu_offload" )
550
678
@@ -574,6 +702,9 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda"
574
702
self ._auto_offload_device = device
575
703
576
704
def disable_auto_cpu_offload (self ):
705
+ """
706
+ Disable automatic CPU offloading for all components.
707
+ """
577
708
if self .model_hooks is None :
578
709
self ._auto_offload_enabled = False
579
710
return
@@ -595,13 +726,12 @@ def get_model_info(
595
726
"""Get comprehensive information about a component.
596
727
597
728
Args:
598
- component_id: Name of the component to get info for
599
- fields: Optional field (s) to return. Can be a string for single field or list of fields.
729
+ component_id (str) : Name of the component to get info for
730
+ fields ( Optional[Union[str, List[str]]]): Field (s) to return. Can be a string for single field or list of fields.
600
731
If None, uses the available_info_fields setting.
601
732
602
733
Returns:
603
- Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a
604
- single field is requested as string, returns just that field's value.
734
+ Dictionary containing requested component metadata. If fields is specified, returns only those fields. Otherwise, returns all fields.
605
735
"""
606
736
if component_id not in self .components :
607
737
raise ValueError (f"Component '{ component_id } ' not found in ComponentsManager" )
@@ -808,15 +938,16 @@ def get_one(
808
938
load_id : Optional [str ] = None ,
809
939
) -> Any :
810
940
"""
811
- Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in
812
- a component_id Raises an error if multiple components match or none are found. support pattern matching for
813
- name
941
+ Get a single component by either:
942
+ - searching name (pattern matching), collection, or load_id.
943
+ - passing in a component_id
944
+ Raises an error if multiple components match or none are found.
814
945
815
946
Args:
816
- component_id: Optional component ID to get
817
- name: Component name or pattern
818
- collection: Optional collection to filter by
819
- load_id: Optional load_id to filter by
947
+ component_id (Optional[str]) : Optional component ID to get
948
+ name (Optional[str]) : Component name or pattern
949
+ collection (Optional[str]) : Optional collection to filter by
950
+ load_id (Optional[str]) : Optional load_id to filter by
820
951
821
952
Returns:
822
953
A single component
@@ -847,6 +978,13 @@ def get_one(
847
978
def get_ids (self , names : Union [str , List [str ]] = None , collection : Optional [str ] = None ):
848
979
"""
849
980
Get component IDs by a list of names, optionally filtered by collection.
981
+
982
+ Args:
983
+ names (Union[str, List[str]]): List of component names
984
+ collection (Optional[str]): Optional collection to filter by
985
+
986
+ Returns:
987
+ List[str]: List of component IDs
850
988
"""
851
989
ids = set ()
852
990
if not isinstance (names , list ):
@@ -858,6 +996,20 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str]
858
996
def get_components_by_ids (self , ids : List [str ], return_dict_with_names : Optional [bool ] = True ):
859
997
"""
860
998
Get components by a list of IDs.
999
+
1000
+ Args:
1001
+ ids (List[str]):
1002
+ List of component IDs
1003
+ return_dict_with_names (Optional[bool]):
1004
+ Whether to return a dictionary with component names as keys:
1005
+
1006
+ Returns:
1007
+ Dict[str, Any]: Dictionary of components.
1008
+ - If return_dict_with_names=True, keys are component names.
1009
+ - If return_dict_with_names=False, keys are component IDs.
1010
+
1011
+ Raises:
1012
+ ValueError: If duplicate component names are found in the search results when return_dict_with_names=True
861
1013
"""
862
1014
components = {id : self .components [id ] for id in ids }
863
1015
@@ -877,65 +1029,17 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional
877
1029
def get_components_by_names (self , names : List [str ], collection : Optional [str ] = None ):
878
1030
"""
879
1031
Get components by a list of names, optionally filtered by collection.
880
- """
881
- ids = self .get_ids (names , collection )
882
- return self .get_components_by_ids (ids )
883
-
884
-
885
- def summarize_dict_by_value_and_parts (d : Dict [str , Any ]) -> Dict [str , Any ]:
886
- """Summarizes a dictionary by finding common prefixes that share the same value.
887
-
888
- For a dictionary with dot-separated keys like: {
889
- 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
890
- 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
891
- 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
892
- }
893
-
894
- Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
895
- 'down_blocks': [0.6], 'up_blocks': [0.3]
896
- }
897
- """
898
- # First group by values - convert lists to tuples to make them hashable
899
- value_to_keys = {}
900
- for key , value in d .items ():
901
- value_tuple = tuple (value ) if isinstance (value , list ) else value
902
- if value_tuple not in value_to_keys :
903
- value_to_keys [value_tuple ] = []
904
- value_to_keys [value_tuple ].append (key )
905
1032
906
- def find_common_prefix (keys : List [str ]) -> str :
907
- """Find the shortest common prefix among a list of dot-separated keys."""
908
- if not keys :
909
- return ""
910
- if len (keys ) == 1 :
911
- return keys [0 ]
912
-
913
- # Split all keys into parts
914
- key_parts = [k .split ("." ) for k in keys ]
915
-
916
- # Find how many initial parts are common
917
- common_length = 0
918
- for parts in zip (* key_parts ):
919
- if len (set (parts )) == 1 : # All parts at this position are the same
920
- common_length += 1
921
- else :
922
- break
923
-
924
- if common_length == 0 :
925
- return ""
1033
+ Args:
1034
+ names (List[str]): List of component names
1035
+ collection (Optional[str]): Optional collection to filter by
926
1036
927
- # Return the common prefix
928
- return "." . join ( key_parts [ 0 ][: common_length ])
1037
+ Returns:
1038
+ Dict[str, Any]: Dictionary of components with component names as keys
929
1039
930
- # Create summary by finding common prefixes for each value group
931
- summary = {}
932
- for value_tuple , keys in value_to_keys .items ():
933
- prefix = find_common_prefix (keys )
934
- if prefix : # Only add if we found a common prefix
935
- # Convert tuple back to list if it was originally a list
936
- value = list (value_tuple ) if isinstance (d [keys [0 ]], list ) else value_tuple
937
- summary [prefix ] = value
938
- else :
939
- summary ["" ] = value # Use empty string if no common prefix
1040
+ Raises:
1041
+ ValueError: If duplicate component names are found in the search results
1042
+ """
1043
+ ids = self .get_ids (names , collection )
1044
+ return self .get_components_by_ids (ids )
940
1045
941
- return summary
0 commit comments