Skip to content

Commit 595581d

Browse files
committed
style
1 parent d27b654 commit 595581d

File tree

3 files changed

+68
-52
lines changed

3 files changed

+68
-52
lines changed

src/diffusers/modular_pipelines/components_manager.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def custom_offload_with_hook(
142142
user_hook.attach()
143143
return user_hook
144144

145+
145146
# this is the class that user can customize to implement their own offload strategy
146147
class AutoOffloadStrategy:
147148
"""
@@ -277,37 +278,36 @@ def find_common_prefix(keys: List[str]) -> str:
277278
class ComponentsManager:
278279
"""
279280
A central registry and management system for model components across multiple pipelines.
280-
281-
[`ComponentsManager`] provides a unified way to register, track, and reuse model components
282-
(like UNet, VAE, text encoders, etc.) across different modular pipelines. It includes
283-
features for duplicate detection, memory management, and component organization.
284-
281+
282+
[`ComponentsManager`] provides a unified way to register, track, and reuse model components (like UNet, VAE, text
283+
encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
284+
management, and component organization.
285+
285286
<Tip warning={true}>
286287
287288
This is an experimental feature and is likely to change in the future.
288289
289290
</Tip>
290-
291+
291292
Example:
292293
```python
293294
from diffusers import ComponentsManager
294-
295+
295296
# Create a components manager
296297
cm = ComponentsManager()
297-
298+
298299
# Add components
299300
cm.add("unet", unet_model, collection="sdxl")
300301
cm.add("vae", vae_model, collection="sdxl")
301-
302+
302303
# Enable auto offloading
303304
cm.enable_auto_cpu_offload(device="cuda")
304-
305+
305306
# Retrieve components
306307
unet = cm.get_one(name="unet", collection="sdxl")
307308
```
308309
"""
309310

310-
311311
_available_info_fields = [
312312
"model_id",
313313
"added_time",
@@ -382,7 +382,7 @@ def add(self, name: str, component: Any, collection: Optional[str] = None):
382382
collection (Optional[str]): The collection to add the component to
383383
384384
Returns:
385-
str: The unique component ID, which is generated as "{name}_{id(component)}" where
385+
str: The unique component ID, which is generated as "{name}_{id(component)}" where
386386
id(component) is Python's built-in unique identifier for the object
387387
"""
388388
component_id = f"{name}_{id(component)}"
@@ -669,9 +669,9 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda"
669669
670670
Args:
671671
device (Union[str, int, torch.device]): The execution device where models are moved for forward passes
672-
memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
673-
memory to keep free on the device to avoid running out of memory during
674-
model execution (e.g., for intermediate activations, gradients, etc.)
672+
memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
673+
memory to keep free on the device to avoid running out of memory during model
674+
execution (e.g., for intermediate activations, gradients, etc.)
675675
"""
676676
if not is_accelerate_available():
677677
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
@@ -727,11 +727,13 @@ def get_model_info(
727727
728728
Args:
729729
component_id (str): Name of the component to get info for
730-
fields (Optional[Union[str, List[str]]]): Field(s) to return. Can be a string for single field or list of fields.
731-
If None, uses the available_info_fields setting.
730+
fields (Optional[Union[str, List[str]]]):
731+
Field(s) to return. Can be a string for single field or list of fields. If None, uses the
732+
available_info_fields setting.
732733
733734
Returns:
734-
Dictionary containing requested component metadata. If fields is specified, returns only those fields. Otherwise, returns all fields.
735+
Dictionary containing requested component metadata. If fields is specified, returns only those fields.
736+
Otherwise, returns all fields.
735737
"""
736738
if component_id not in self.components:
737739
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
@@ -938,10 +940,10 @@ def get_one(
938940
load_id: Optional[str] = None,
939941
) -> Any:
940942
"""
941-
Get a single component by either:
942-
- searching name (pattern matching), collection, or load_id.
943+
Get a single component by either:
944+
- searching name (pattern matching), collection, or load_id.
943945
- passing in a component_id
944-
Raises an error if multiple components match or none are found.
946+
Raises an error if multiple components match or none are found.
945947
946948
Args:
947949
component_id (Optional[str]): Optional component ID to get
@@ -998,13 +1000,13 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional
9981000
Get components by a list of IDs.
9991001
10001002
Args:
1001-
ids (List[str]):
1003+
ids (List[str]):
10021004
List of component IDs
10031005
return_dict_with_names (Optional[bool]):
10041006
Whether to return a dictionary with component names as keys:
10051007
10061008
Returns:
1007-
Dict[str, Any]: Dictionary of components.
1009+
Dict[str, Any]: Dictionary of components.
10081010
- If return_dict_with_names=True, keys are component names.
10091011
- If return_dict_with_names=False, keys are component IDs.
10101012
@@ -1042,4 +1044,3 @@ def get_components_by_names(self, names: List[str], collection: Optional[str] =
10421044
"""
10431045
ids = self.get_ids(names, collection)
10441046
return self.get_components_by_ids(ids)
1045-

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,9 @@ def set_input(self, key: str, value: Any, kwargs_type: str = None):
8585
"""
8686
Add an input to the immutable pipeline state, i.e, pipeline_state.inputs.
8787
88-
The kwargs_type parameter allows you to associate inputs with specific input types.
89-
For example, if you call set_input(prompt_embeds=..., kwargs_type="guider_kwargs"),
90-
this input will be automatically fetched when a pipeline block has "guider_kwargs"
91-
in its expected_inputs list.
88+
The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call
89+
set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a
90+
pipeline block has "guider_kwargs" in its expected_inputs list.
9291
9392
Args:
9493
key (str): The key for the input
@@ -106,10 +105,9 @@ def set_intermediate(self, key: str, value: Any, kwargs_type: str = None):
106105
"""
107106
Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates.
108107
109-
The kwargs_type parameter allows you to associate intermediate values with specific input types.
110-
For example, if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"),
111-
this intermediate value will be automatically fetched when a pipeline block has "latents_kwargs"
112-
in its expected_intermediate_inputs list.
108+
The kwargs_type parameter allows you to associate intermediate values with specific input types. For example,
109+
if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be
110+
automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list.
113111
114112
Args:
115113
key (str): The key for the intermediate value
@@ -414,13 +412,13 @@ def init_pipeline(
414412
collection=collection,
415413
)
416414
return modular_pipeline
417-
415+
418416
@staticmethod
419417
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
420418
"""
421-
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
422-
default value is None and new default value is not None. Warns if multiple non-None default values exist for the
423-
same input.
419+
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
420+
current default value is None and new default value is not None. Warns if multiple non-None default values
421+
exist for the same input.
424422
425423
Args:
426424
named_input_lists: List of tuples containing (block_name, input_param_list) pairs
@@ -482,8 +480,6 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
482480
return list(combined_dict.values())
483481

484482

485-
486-
487483
class PipelineBlock(ModularPipelineBlocks):
488484
"""
489485
A Pipeline Block is the basic building block of a Modular Pipeline.
@@ -499,15 +495,33 @@ class PipelineBlock(ModularPipelineBlocks):
499495
500496
Args:
501497
description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
502-
expected_components (List[ComponentSpec], optional): A list of components that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
503-
expected_configs (List[ConfigSpec], optional): A list of configs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
504-
inputs (List[InputParam], optional): A list of inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
505-
intermediate_inputs (List[InputParam], optional): A list of intermediate inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
506-
intermediate_outputs (List[OutputParam], optional): A list of intermediate outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
507-
outputs (List[OutputParam], optional): A list of outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
508-
required_inputs (List[str], optional): A list of required inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
509-
required_intermediate_inputs (List[str], optional): A list of required intermediate inputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
510-
required_intermediate_outputs (List[str], optional): A list of required intermediate outputs that are expected to be used in the block, defaults to []. To override, define as a property in subclasses.
498+
expected_components (List[ComponentSpec], optional):
499+
A list of components that are expected to be used in the block, defaults to []. To override, define as a
500+
property in subclasses.
501+
expected_configs (List[ConfigSpec], optional):
502+
A list of configs that are expected to be used in the block, defaults to []. To override, define as a
503+
property in subclasses.
504+
inputs (List[InputParam], optional):
505+
A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
506+
property in subclasses.
507+
intermediate_inputs (List[InputParam], optional):
508+
A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
509+
define as a property in subclasses.
510+
intermediate_outputs (List[OutputParam], optional):
511+
A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
512+
define as a property in subclasses.
513+
outputs (List[OutputParam], optional):
514+
A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
515+
property in subclasses.
516+
required_inputs (List[str], optional):
517+
A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
518+
a property in subclasses.
519+
required_intermediate_inputs (List[str], optional):
520+
A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
521+
override, define as a property in subclasses.
522+
required_intermediate_outputs (List[str], optional):
523+
A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
524+
override, define as a property in subclasses.
511525
"""
512526

513527
model_name = None
@@ -997,7 +1011,8 @@ def doc(self):
9971011

9981012
class SequentialPipelineBlocks(ModularPipelineBlocks):
9991013
"""
1000-
A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
1014+
A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
1015+
sequence.
10011016
10021017
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
10031018
library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -1373,8 +1388,8 @@ def doc(self):
13731388

13741389
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
13751390
"""
1376-
A Pipeline blocks that combines multiple pipeline block classes into a For Loop. When called, it will call each block in
1377-
sequence.
1391+
A Pipeline blocks that combines multiple pipeline block classes into a For Loop. When called, it will call each
1392+
block in sequence.
13781393
13791394
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
13801395
library implements for all the pipeline blocks (such as loading or saving etc.)

src/diffusers/modular_pipelines/node_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,8 @@ def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
348348

349349
class ModularNode(ConfigMixin):
350350
"""
351-
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon.
352-
It is a wrapper around a ModularPipelineBlocks object.
351+
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
352+
around a ModularPipelineBlocks object.
353353
354354
<Tip warning={true}>
355355

0 commit comments

Comments
 (0)