Skip to content

Commit c6fc860

Browse files
committed
fix(api): make panorama work with prompt alternatives
1 parent b8b73d8 commit c6fc860

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

api/onnx_web/diffusers/pipelines/panorama.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import inspect
16+
from logging import getLogger
1617
from typing import Callable, List, Optional, Union
1718

1819
import PIL
@@ -299,7 +300,7 @@ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
299300
views.append((h_start, h_end, w_start, w_end))
300301
return views
301302

302-
def __call__(
303+
def text2img(
303304
self,
304305
prompt: Union[str, List[str]] = None,
305306
height: Optional[int] = 512,
@@ -635,11 +636,11 @@ def img2img(
635636
# prep image
636637
image = preprocess(image).cpu().numpy()
637638
image = image.astype(latents_dtype)
639+
638640
# encode the init image into latents and scale the latents
639641
latents = self.vae_encoder(sample=image)[0]
640642
latents = 0.18215 * latents
641-
642-
latents = latents * np.float64(self.scheduler.init_noise_sigma)
643+
# latents = latents * np.float64(self.scheduler.init_noise_sigma)
643644

644645
# get the original timestep using init_timestep
645646
offset = self.scheduler.config.get("steps_offset", 0)
@@ -746,3 +747,15 @@ def img2img(
746747
return (image, has_nsfw_concept)
747748

748749
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
750+
751+
def __call__(
752+
self,
753+
*args,
754+
**kwargs,
755+
):
756+
if len(args) > 0 and (isinstance(args[0], np.ndarray) or isinstance(args[0], PIL.Image.Image)):
757+
logger.debug("running img2img panorama pipeline")
758+
return self.img2img(*args, **kwargs)
759+
else:
760+
logger.debug("running txt2img panorama pipeline")
761+
return self.text2img(*args, **kwargs)

api/onnx_web/diffusers/run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def run_loopback(
6161
)
6262

6363
def loopback_iteration(source: Image.Image):
64-
if pipe_type in ["lpw", "panorama"]:
64+
if pipe_type == "lpw":
6565
rng = torch.manual_seed(params.seed)
6666
result = pipe.img2img(
6767
source,
@@ -174,7 +174,7 @@ def highres_tile(tile: Image.Image, dims):
174174
callback=highres_progress,
175175
)
176176

177-
if pipe_type in ["lpw", "panorama"]:
177+
if pipe_type == "lpw":
178178
rng = torch.manual_seed(params.seed)
179179
result = highres_pipe.img2img(
180180
tile,
@@ -250,7 +250,7 @@ def run_txt2img_pipeline(
250250
)
251251
progress = job.get_progress_callback()
252252

253-
if pipe_type in ["lpw", "panorama"]:
253+
if pipe_type == "lpw":
254254
rng = torch.manual_seed(params.seed)
255255
result = pipe.text2img(
256256
params.prompt,
@@ -369,7 +369,7 @@ def run_img2img_pipeline(
369369
pipe_params["image_guidance_scale"] = strength
370370

371371
progress = job.get_progress_callback()
372-
if pipe_type in ["lpw", "panorama"]:
372+
if pipe_type == "lpw":
373373
logger.debug("using LPW pipeline for img2img")
374374
rng = torch.manual_seed(params.seed)
375375
result = pipe.img2img(

0 commit comments

Comments
 (0)