@@ -95,7 +95,7 @@ def __init__(
9595 self ,
9696 host : str = "grpc.stability.ai:443" ,
9797 key : str = "" ,
98- engine : str = "stable-diffusion-xl-beta-v2-2-2 " ,
98+ engine : str = "stable-diffusion-xl-1024-v1-0 " ,
9999 upscale_engine : str = "esrgan-v1-x2plus" ,
100100 enhance_engine : str = "face-enhance-v1" ,
101101 verbose : bool = False ,
@@ -157,8 +157,8 @@ def generate(
157157 prompt : Union [str , List [str ], generation .Prompt , List [generation .Prompt ]],
158158 init_image : Optional [Image .Image ] = None ,
159159 mask_image : Optional [Image .Image ] = None ,
160- height : int = 512 ,
161- width : int = 512 ,
160+ height : int = 1024 ,
161+ width : int = 1024 ,
162162 start_schedule : float = 1.0 ,
163163 end_schedule : float = 0.01 ,
164164 cfg_scale : float = 7.0 ,
@@ -175,6 +175,10 @@ def generate(
175175 guidance_models : List [str ] = None ,
176176 upscale : Union [bool , Dict [str , Any ]] = False ,
177177 enhance : Union [bool , Dict [str , Any ]] = False ,
178+ adapter_type : generation .T2IAdapter = None ,
179+ adapter_strength : float = 0.4 ,
180+ adapter_init_type : generation .T2IAdapterInit = generation .T2IADAPTERINIT_IMAGE ,
181+ style_preset : Optional [str ] = None
178182 ) -> Generator [generation .Answer , None , None ]:
179183 """
180184 Generate images from a prompt.
@@ -200,6 +204,11 @@ def generate(
200204 :param guidance_models: Models to use for guidance.
201205 :param upscale: Whether to upscale the generated images. Can also pass a dictionary of upscale arguments. See client.upscale for supported values.
202206 :param enhance: Whether to enhance the generated images. Can also pass a dictionary of enhance arguments. See client._make_enhance_request for supported values.
207+ :param adapter_type: T2I adapter type, if any.
208+ :param adapter_strength: Float between 0, 1 representing the proportion of unet passes into which we inject adapter weights
209+ :param adapter_init_type: If T2IADAPTERINIT_IMAGE then init_image is converted into an initialising image corresponding to the adapter_type. i.e.
210+ a sketch/depthmap/canny edge. If T2IADAPTERINIT_ADAPTER_IMAGE, then the init_image is treated as already a a sketch/depthmap/canny edge.
211+ :param style_preset: Style preset name to use (see https://platform.stability.ai/docs/api-reference#tag/v1generation)
203212 :return: Generator of Answer objects.
204213 """
205214
@@ -220,7 +229,11 @@ def generate(
220229 guidance_cuts = guidance_cuts ,
221230 guidance_strength = guidance_strength ,
222231 guidance_prompt = guidance_prompt ,
223- guidance_models = guidance_models
232+ guidance_models = guidance_models ,
233+ adapter_type = adapter_type ,
234+ adapter_strength = adapter_strength ,
235+ adapter_init_type = adapter_init_type ,
236+ style_preset = style_preset
224237 )
225238
226239 if not upscale and not enhance :
@@ -378,6 +391,10 @@ def _make_generate_request(
378391 guidance_strength : Optional [float ] = None ,
379392 guidance_prompt : Union [str , generation .Prompt ] = None ,
380393 guidance_models : List [str ] = None ,
394+ adapter_type : generation .T2IAdapter = None ,
395+ adapter_strength : float = 0.4 ,
396+ adapter_init_type : generation .T2IAdapterInit = generation .T2IADAPTERINIT_IMAGE ,
397+ style_preset : Optional [str ] = None
381398 ):
382399 """
383400 Create a generate request
@@ -435,8 +452,7 @@ def _make_generate_request(
435452 raise ValueError ("guidance_prompt must be a string or Prompt object" )
436453 if guidance_strength == 0.0 :
437454 guidance_strength = None
438-
439-
455+
440456 # Build our CLIP parameters
441457 if guidance_preset is not generation .GUIDANCE_PRESET_NONE :
442458 # to do: make it so user can override this
@@ -464,6 +480,12 @@ def _make_generate_request(
464480 ],
465481 )
466482
483+ adapter_parameters = generation .T2IAdapterParameter (
484+ adapter_type = adapter_type ,
485+ adapter_strength = adapter_strength ,
486+ adapter_init_type = adapter_init_type ,
487+ )
488+
467489 transform = None
468490 if sampler :
469491 transform = generation .TransformType (diffusion = sampler )
@@ -475,17 +497,25 @@ def _make_generate_request(
475497 seed = seed ,
476498 steps = steps ,
477499 samples = samples ,
500+ adapter = adapter_parameters ,
478501 parameters = [generation .StepParameter (** step_parameters )],
479502 )
480503
481504 request_id = str (uuid .uuid4 ())
482505 engine_id = self .engine
483506
507+ if style_preset and style_preset .lower () != 'none' :
508+ extras = Struct ()
509+ extras .update ({ '$IPC' : { "preset" : style_preset } })
510+ else :
511+ extras = None
512+
484513 rq = generation .Request (
485514 engine_id = engine_id ,
486515 request_id = request_id ,
487516 prompt = prompts ,
488517 image = image_parameters ,
518+ extras = extras
489519 )
490520
491521 return rq
@@ -641,9 +671,10 @@ def run_request(self,
641671 yield answer
642672 start = time .time ()
643673
644- def process_cli (logger : logging .Logger = None ,
645- warn_client_call_deprecated : bool = True ,
646- ):
674+ def process_cli (
675+ logger : logging .Logger = None ,
676+ warn_client_call_deprecated : bool = True ,
677+ ):
647678 if not logger :
648679 logger = logging .getLogger (__name__ )
649680 logger .setLevel (level = logging .INFO )
@@ -742,10 +773,10 @@ def process_cli(logger: logging.Logger = None,
742773
743774 parser_generate = subparsers .add_parser ('generate' )
744775 parser_generate .add_argument (
745- "--height" , "-H" , type = int , default = 512 , help = "[512 ] height of image"
776+ "--height" , "-H" , type = int , default = 1024 , help = "[1024 ] height of image"
746777 )
747778 parser_generate .add_argument (
748- "--width" , "-W" , type = int , default = 512 , help = "[512 ] width of image"
779+ "--width" , "-W" , type = int , default = 1024 , help = "[1024 ] width of image"
749780 )
750781 parser_generate .add_argument (
751782 "--start_schedule" ,
@@ -773,6 +804,7 @@ def process_cli(logger: logging.Logger = None,
773804 )
774805 parser_generate .add_argument (
775806 "--seed" , "-S" , type = int , default = 0 , help = "random seed to use" )
807+ parser_generate .add_argument ("--style_preset" , type = str , help = "style preset name" )
776808 parser_generate .add_argument (
777809 "--prefix" ,
778810 "-p" ,
@@ -799,7 +831,7 @@ def process_cli(logger: logging.Logger = None,
799831 "-e" ,
800832 type = str ,
801833 help = "engine to use for inference" ,
802- default = "stable-diffusion-xl-beta-v2-2-2 " ,
834+ default = "stable-diffusion-xl-1024-v1-0 " ,
803835 )
804836 parser_generate .add_argument (
805837 "--init_image" ,
@@ -908,6 +940,7 @@ def process_cli(logger: logging.Logger = None,
908940 "mask_image" : args .mask_image ,
909941 "upscale" : upscale ,
910942 "enhance" : enhance ,
943+ "style_preset" : args .style_preset ,
911944 }
912945
913946 if args .sampler :
0 commit comments