
Commit 203fa04

feat(nodes): support bottleneck flag for nodes
1 parent 954fce3 commit 203fa04

2 files changed: 20 additions & 0 deletions


invokeai/app/invocations/baseinvocation.py

Lines changed: 18 additions & 0 deletions
@@ -72,6 +72,17 @@ class Classification(str, Enum, metaclass=MetaEnum):
     Special = "special"
 
 
+class Bottleneck(str, Enum, metaclass=MetaEnum):
+    """
+    The bottleneck of an invocation.
+    - `Network`: The invocation's execution is network-bound.
+    - `GPU`: The invocation's execution is GPU-bound.
+    """
+
+    Network = "network"
+    GPU = "gpu"
+
+
 class UIConfigBase(BaseModel):
     """
     Provides additional node configuration to the UI.
@@ -241,6 +252,8 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
         json_schema_extra={"field_kind": FieldKind.NodeAttribute},
     )
 
+    bottleneck: ClassVar[Bottleneck]
+
     UIConfig: ClassVar[UIConfigBase]
 
     model_config = ConfigDict(
@@ -399,6 +412,7 @@ def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | N
     "use_cache",
     "type",
     "workflow",
+    "bottleneck",
 }
 
 RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
@@ -483,6 +497,7 @@ def invocation(
     version: Optional[str] = None,
     use_cache: Optional[bool] = True,
     classification: Classification = Classification.Stable,
+    bottleneck: Bottleneck = Bottleneck.GPU,
 ) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
     """
     Registers an invocation.
@@ -494,6 +509,7 @@ def invocation(
     :param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
     :param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
     :param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
+    :param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound.
     """
 
     def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@@ -530,6 +546,8 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
         if use_cache is not None:
             cls.model_fields["use_cache"].default = use_cache
 
+        cls.bottleneck = bottleneck
+
         # Add the invocation type to the model.
 
         # You'd be tempted to just add the type field and rebuild the model, like this:
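
For illustration, a minimal sketch of how a node author might opt into the new flag once this commit lands. Only `invocation`, `Bottleneck`, and `BaseInvocation` are confirmed by this diff; `InvocationContext`, `StringOutput`, and the invocation-type positional argument are assumed to be part of the public invocation API, and the node itself is hypothetical.

# Sketch only: "remote_caption" and the node class are hypothetical, and
# InvocationContext / StringOutput are assumed exports of
# invokeai.invocation_api (they are not part of this commit).
from invokeai.invocation_api import (
    BaseInvocation,
    Bottleneck,
    InvocationContext,
    StringOutput,
    invocation,
)


@invocation(
    "remote_caption",
    version="1.0.0",
    bottleneck=Bottleneck.Network,  # waits on a remote service rather than the GPU
)
class RemoteCaptionInvocation(BaseInvocation):
    """Calls an external captioning service, so its execution is network-bound."""

    def invoke(self, context: InvocationContext) -> StringOutput:
        # ...call the remote service here...
        return StringOutput(value="a caption")

Nodes that omit the argument keep the default `Bottleneck.GPU`, matching the decorator signature in the diff above.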

invokeai/invocation_api/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    Bottleneck,
     Classification,
     invocation,
     invocation_output,
@@ -86,6 +87,7 @@
     # invokeai.app.invocations.baseinvocation
     "BaseInvocation",
     "BaseInvocationOutput",
+    "Bottleneck",
     "Classification",
     "invocation",
     "invocation_output",

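With the re-export above, downstream node packs can reference the enum from the public API surface. A quick sanity check, assuming only what the diff shows:

from invokeai.invocation_api import Bottleneck

# Enum values match the members added in baseinvocation.py above.
assert Bottleneck.Network.value == "network"
assert Bottleneck.GPU.value == "gpu"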