Skip to content

Commit 98ff01e

Browse files
Display progress and result URL directly on API nodes (#8102)
* [Luma] Print download URL of successful task result directly on nodes (#177) [Veo] Print download URL of successful task result directly on nodes (#184) [Recraft] Print download URL of successful task result directly on nodes (#183) [Pixverse] Print download URL of successful task result directly on nodes (#182) [Kling] Print download URL of successful task result directly on nodes (#181) [MiniMax] Print progress text and download URL of successful task result directly on nodes (#179) [Docs] Link to docs in `API_NODE` class property type annotation comment (#178) [Ideogram] Print download URL of successful task result directly on nodes (#176) [Kling] Print download URL of successful task result directly on nodes (#181) [Veo] Print download URL of successful task result directly on nodes (#184) [Recraft] Print download URL of successful task result directly on nodes (#183) [Pixverse] Print download URL of successful task result directly on nodes (#182) [MiniMax] Print progress text and download URL of successful task result directly on nodes (#179) [Docs] Link to docs in `API_NODE` class property type annotation comment (#178) [Luma] Print download URL of successful task result directly on nodes (#177) [Ideogram] Print download URL of successful task result directly on nodes (#176) Show output URL and progress text on Pika nodes (#168) [BFL] Print download URL of successful task result directly on nodes (#175) [OpenAI ] Print download URL of successful task result directly on nodes (#174) * fix ruff errors * fix 3.10 syntax error
1 parent bab836d commit 98ff01e

13 files changed

+474
-92
lines changed

comfy/comfy_types/node_typing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
235235
DEPRECATED: bool
236236
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
237237
API_NODE: Optional[bool]
238-
"""Flags a node as an API node."""
238+
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
239239

240240
@classmethod
241241
@abstractmethod

comfy_api_nodes/apinode_utils.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22
import io
33
import logging
4-
from typing import Optional
4+
from typing import Optional, Union
55
from comfy.utils import common_upscale
66
from comfy_api.input_impl import VideoFromFile
77
from comfy_api.util import VideoContainer, VideoCodec
@@ -15,6 +15,7 @@
1515
UploadRequest,
1616
UploadResponse,
1717
)
18+
from server import PromptServer
1819

1920

2021
import numpy as np
@@ -60,7 +61,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
6061
return s
6162

6263

63-
def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
64+
def validate_and_cast_response(
65+
response, timeout: int = None, node_id: Union[str, None] = None
66+
) -> torch.Tensor:
6467
"""Validates and casts a response to a torch.Tensor.
6568
6669
Args:
@@ -94,6 +97,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
9497
img = Image.open(io.BytesIO(img_data))
9598

9699
elif image_url:
100+
if node_id:
101+
PromptServer.instance.send_progress_text(
102+
f"Result URL: {image_url}", node_id
103+
)
97104
img_response = requests.get(image_url, timeout=timeout)
98105
if img_response.status_code != 200:
99106
raise ValueError("Failed to download the image")

comfy_api_nodes/apis/client.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
from pydantic import BaseModel, Field
104104
import uuid # For generating unique operation IDs
105105

106+
from server import PromptServer
106107
from comfy.cli_args import args
107108
from comfy import utils
108109
from . import request_logger
@@ -900,6 +901,7 @@ def __init__(
900901
failed_statuses: list,
901902
status_extractor: Callable[[R], str],
902903
progress_extractor: Callable[[R], float] = None,
904+
result_url_extractor: Callable[[R], str] = None,
903905
request: Optional[T] = None,
904906
api_base: str | None = None,
905907
auth_token: Optional[str] = None,
@@ -910,6 +912,8 @@ def __init__(
910912
max_retries: int = 3, # Max retries per individual API call
911913
retry_delay: float = 1.0,
912914
retry_backoff_factor: float = 2.0,
915+
estimated_duration: Optional[float] = None,
916+
node_id: Optional[str] = None,
913917
):
914918
self.poll_endpoint = poll_endpoint
915919
self.request = request
@@ -924,12 +928,15 @@ def __init__(
924928
self.max_retries = max_retries
925929
self.retry_delay = retry_delay
926930
self.retry_backoff_factor = retry_backoff_factor
931+
self.estimated_duration = estimated_duration
927932

928933
# Polling configuration
929934
self.status_extractor = status_extractor or (
930935
lambda x: getattr(x, "status", None)
931936
)
932937
self.progress_extractor = progress_extractor
938+
self.result_url_extractor = result_url_extractor
939+
self.node_id = node_id
933940
self.completed_statuses = completed_statuses
934941
self.failed_statuses = failed_statuses
935942

@@ -965,6 +972,26 @@ def execute(self, client: Optional[ApiClient] = None) -> R:
965972
except Exception as e:
966973
raise Exception(f"Error during polling: {str(e)}")
967974

975+
def _display_text_on_node(self, text: str):
976+
"""Sends text to the client which will be displayed on the node in the UI"""
977+
if not self.node_id:
978+
return
979+
980+
PromptServer.instance.send_progress_text(text, self.node_id)
981+
982+
def _display_time_progress_on_node(self, time_completed: int):
983+
if not self.node_id:
984+
return
985+
986+
if self.estimated_duration is not None:
987+
estimated_time_remaining = max(
988+
0, int(self.estimated_duration) - int(time_completed)
989+
)
990+
message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
991+
else:
992+
message = f"Task in progress: {time_completed:.0f}s"
993+
self._display_text_on_node(message)
994+
968995
def _check_task_status(self, response: R) -> TaskStatus:
969996
"""Check task status using the status extractor function"""
970997
try:
@@ -1031,7 +1058,15 @@ def _poll_until_complete(self, client: ApiClient) -> R:
10311058
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
10321059

10331060
if status == TaskStatus.COMPLETED:
1034-
logging.debug("[DEBUG] Task completed successfully")
1061+
message = "Task completed successfully"
1062+
if self.result_url_extractor:
1063+
result_url = self.result_url_extractor(response_obj)
1064+
if result_url:
1065+
message = f"Result URL: {result_url}"
1066+
else:
1067+
message = "Task completed successfully!"
1068+
logging.debug(f"[DEBUG] {message}")
1069+
self._display_text_on_node(message)
10351070
self.final_response = response_obj
10361071
if self.progress_extractor:
10371072
progress.update(100)
@@ -1047,7 +1082,10 @@ def _poll_until_complete(self, client: ApiClient) -> R:
10471082
logging.debug(
10481083
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
10491084
)
1050-
time.sleep(self.poll_interval)
1085+
for i in range(int(self.poll_interval)):
1086+
time_completed = (poll_count * self.poll_interval) + i
1087+
self._display_time_progress_on_node(time_completed)
1088+
time.sleep(1)
10511089

10521090
except (LocalNetworkError, ApiServerError) as e:
10531091
# For network-related errors, increment error count and potentially abort

comfy_api_nodes/nodes_bfl.py

+44-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
from inspect import cleandoc
3+
from typing import Union
34
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
45
from comfy_api_nodes.apis.bfl_api import (
56
BFLStatus,
@@ -30,6 +31,7 @@
3031
import torch
3132
import base64
3233
import time
34+
from server import PromptServer
3335

3436

3537
def convert_mask_to_image(mask: torch.Tensor):
@@ -42,14 +44,19 @@ def convert_mask_to_image(mask: torch.Tensor):
4244

4345

4446
def handle_bfl_synchronous_operation(
45-
operation: SynchronousOperation, timeout_bfl_calls=360
47+
operation: SynchronousOperation,
48+
timeout_bfl_calls=360,
49+
node_id: Union[str, None] = None,
4650
):
4751
response_api: BFLFluxProGenerateResponse = operation.execute()
4852
return _poll_until_generated(
49-
response_api.polling_url, timeout=timeout_bfl_calls
53+
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
5054
)
5155

52-
def _poll_until_generated(polling_url: str, timeout=360):
56+
57+
def _poll_until_generated(
58+
polling_url: str, timeout=360, node_id: Union[str, None] = None
59+
):
5360
# used bfl-comfy-nodes to verify code implementation:
5461
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
5562
start_time = time.time()
@@ -61,11 +68,21 @@ def _poll_until_generated(polling_url: str, timeout=360):
6168
request = requests.Request(method=HttpMethod.GET, url=polling_url)
6269
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
6370
while True:
71+
if node_id:
72+
time_elapsed = time.time() - start_time
73+
PromptServer.instance.send_progress_text(
74+
f"Generating ({time_elapsed:.0f}s)", node_id
75+
)
76+
6477
response = requests.Session().send(request.prepare())
6578
if response.status_code == 200:
6679
result = response.json()
6780
if result["status"] == BFLStatus.ready:
6881
img_url = result["result"]["sample"]
82+
if node_id:
83+
PromptServer.instance.send_progress_text(
84+
f"Result URL: {img_url}", node_id
85+
)
6986
img_response = requests.get(img_url)
7087
return process_image_response(img_response)
7188
elif result["status"] in [
@@ -180,6 +197,7 @@ def INPUT_TYPES(s):
180197
"hidden": {
181198
"auth_token": "AUTH_TOKEN_COMFY_ORG",
182199
"comfy_api_key": "API_KEY_COMFY_ORG",
200+
"unique_id": "UNIQUE_ID",
183201
},
184202
}
185203

@@ -212,6 +230,7 @@ def api_call(
212230
seed=0,
213231
image_prompt=None,
214232
image_prompt_strength=0.1,
233+
unique_id: Union[str, None] = None,
215234
**kwargs,
216235
):
217236
if image_prompt is None:
@@ -246,7 +265,7 @@ def api_call(
246265
),
247266
auth_kwargs=kwargs,
248267
)
249-
output_image = handle_bfl_synchronous_operation(operation)
268+
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
250269
return (output_image,)
251270

252271

@@ -320,6 +339,7 @@ def INPUT_TYPES(s):
320339
"hidden": {
321340
"auth_token": "AUTH_TOKEN_COMFY_ORG",
322341
"comfy_api_key": "API_KEY_COMFY_ORG",
342+
"unique_id": "UNIQUE_ID",
323343
},
324344
}
325345

@@ -338,6 +358,7 @@ def api_call(
338358
seed=0,
339359
image_prompt=None,
340360
# image_prompt_strength=0.1,
361+
unique_id: Union[str, None] = None,
341362
**kwargs,
342363
):
343364
image_prompt = (
@@ -363,7 +384,7 @@ def api_call(
363384
),
364385
auth_kwargs=kwargs,
365386
)
366-
output_image = handle_bfl_synchronous_operation(operation)
387+
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
367388
return (output_image,)
368389

369390

@@ -457,11 +478,11 @@ def INPUT_TYPES(s):
457478
},
458479
),
459480
},
460-
"optional": {
461-
},
481+
"optional": {},
462482
"hidden": {
463483
"auth_token": "AUTH_TOKEN_COMFY_ORG",
464484
"comfy_api_key": "API_KEY_COMFY_ORG",
485+
"unique_id": "UNIQUE_ID",
465486
},
466487
}
467488

@@ -483,6 +504,7 @@ def api_call(
483504
steps: int,
484505
guidance: float,
485506
seed=0,
507+
unique_id: Union[str, None] = None,
486508
**kwargs,
487509
):
488510
image = convert_image_to_base64(image)
@@ -508,7 +530,7 @@ def api_call(
508530
),
509531
auth_kwargs=kwargs,
510532
)
511-
output_image = handle_bfl_synchronous_operation(operation)
533+
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
512534
return (output_image,)
513535

514536

@@ -568,11 +590,11 @@ def INPUT_TYPES(s):
568590
},
569591
),
570592
},
571-
"optional": {
572-
},
593+
"optional": {},
573594
"hidden": {
574595
"auth_token": "AUTH_TOKEN_COMFY_ORG",
575596
"comfy_api_key": "API_KEY_COMFY_ORG",
597+
"unique_id": "UNIQUE_ID",
576598
},
577599
}
578600

@@ -591,13 +613,14 @@ def api_call(
591613
steps: int,
592614
guidance: float,
593615
seed=0,
616+
unique_id: Union[str, None] = None,
594617
**kwargs,
595618
):
596619
# prepare mask
597620
mask = resize_mask_to_image(mask, image)
598621
mask = convert_image_to_base64(convert_mask_to_image(mask))
599622
# make sure image will have alpha channel removed
600-
image = convert_image_to_base64(image[:,:,:,:3])
623+
image = convert_image_to_base64(image[:, :, :, :3])
601624

602625
operation = SynchronousOperation(
603626
endpoint=ApiEndpoint(
@@ -617,7 +640,7 @@ def api_call(
617640
),
618641
auth_kwargs=kwargs,
619642
)
620-
output_image = handle_bfl_synchronous_operation(operation)
643+
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
621644
return (output_image,)
622645

623646

@@ -702,11 +725,11 @@ def INPUT_TYPES(s):
702725
},
703726
),
704727
},
705-
"optional": {
706-
},
728+
"optional": {},
707729
"hidden": {
708730
"auth_token": "AUTH_TOKEN_COMFY_ORG",
709731
"comfy_api_key": "API_KEY_COMFY_ORG",
732+
"unique_id": "UNIQUE_ID",
710733
},
711734
}
712735

@@ -727,9 +750,10 @@ def api_call(
727750
steps: int,
728751
guidance: float,
729752
seed=0,
753+
unique_id: Union[str, None] = None,
730754
**kwargs,
731755
):
732-
control_image = convert_image_to_base64(control_image[:,:,:,:3])
756+
control_image = convert_image_to_base64(control_image[:, :, :, :3])
733757
preprocessed_image = None
734758

735759
# scale canny threshold between 0-500, to match BFL's API
@@ -765,7 +789,7 @@ def scale_value(value: float, min_val=0, max_val=500):
765789
),
766790
auth_kwargs=kwargs,
767791
)
768-
output_image = handle_bfl_synchronous_operation(operation)
792+
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
769793
return (output_image,)
770794

771795

@@ -830,11 +854,11 @@ def INPUT_TYPES(s):
830854
},
831855
),
832856
},
833-
"optional": {
834-
},
857+
"optional": {},
835858
"hidden": {
836859
"auth_token": "AUTH_TOKEN_COMFY_ORG",
837860
"comfy_api_key": "API_KEY_COMFY_ORG",
861+
"unique_id": "UNIQUE_ID",
838862
},
839863
}
840864

@@ -853,6 +877,7 @@ def api_call(
853877
steps: int,
854878
guidance: float,
855879
seed=0,
880+
unique_id: Union[str, None] = None,
856881
**kwargs,
857882
):
858883
control_image = convert_image_to_base64(control_image[:,:,:,:3])
@@ -880,7 +905,7 @@ def api_call(
880905
),
881906
auth_kwargs=kwargs,
882907
)
883-
output_image = handle_bfl_synchronous_operation(operation)
908+
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
884909
return (output_image,)
885910

886911

0 commit comments

Comments
 (0)