|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import inspect |
| 16 | +from logging import getLogger |
16 | 17 | from typing import Callable, List, Optional, Union |
17 | 18 |
|
18 | 19 | import PIL |
@@ -299,7 +300,7 @@ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): |
299 | 300 | views.append((h_start, h_end, w_start, w_end)) |
300 | 301 | return views |
301 | 302 |
|
302 | | - def __call__( |
| 303 | + def text2img( |
303 | 304 | self, |
304 | 305 | prompt: Union[str, List[str]] = None, |
305 | 306 | height: Optional[int] = 512, |
@@ -635,11 +636,11 @@ def img2img( |
635 | 636 | # prep image |
636 | 637 | image = preprocess(image).cpu().numpy() |
637 | 638 | image = image.astype(latents_dtype) |
| 639 | + |
638 | 640 | # encode the init image into latents and scale the latents |
639 | 641 | latents = self.vae_encoder(sample=image)[0] |
640 | 642 | latents = 0.18215 * latents |
641 | | - |
642 | | - latents = latents * np.float64(self.scheduler.init_noise_sigma) |
| 643 | + # latents = latents * np.float64(self.scheduler.init_noise_sigma) |
643 | 644 |
|
644 | 645 | # get the original timestep using init_timestep |
645 | 646 | offset = self.scheduler.config.get("steps_offset", 0) |
@@ -746,3 +747,15 @@ def img2img( |
746 | 747 | return (image, has_nsfw_concept) |
747 | 748 |
|
748 | 749 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
| 750 | + |
| 751 | + def __call__( |
| 752 | + self, |
| 753 | + *args, |
| 754 | + **kwargs, |
| 755 | + ): |
| 756 | + if len(args) > 0 and (isinstance(args[0], np.ndarray) or isinstance(args[0], PIL.Image.Image)): |
| 757 | + logger.debug("running img2img panorama pipeline") |
| 758 | + return self.img2img(*args, **kwargs) |
| 759 | + else: |
| 760 | + logger.debug("running txt2img panorama pipeline") |
| 761 | + return self.text2img(*args, **kwargs) |
0 commit comments