Skip to content

Commit 8af38e2

Browse files
committed
v prediction for sd2
1 parent 4f507cc commit 8af38e2

File tree

3 files changed

+31
-10
lines changed

3 files changed

+31
-10
lines changed

backends/stable_diffusion/schedulers/scheduling_ddim.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
clip_sample: bool = True,
8989
set_alpha_to_one: bool = True,
9090
tensor_format: str = "pt",
91+
prediction_type: str = "epsilon"
9192
):
9293
if trained_betas is not None:
9394
self.betas = np.asarray(trained_betas)
@@ -115,6 +116,7 @@ def __init__(
115116
self.clip_sample = clip_sample
116117
self.set_alpha_to_one = set_alpha_to_one
117118
self.tensor_format = tensor_format
119+
self.prediction_type = prediction_type
118120

119121
# At every step in ddim, we are looking into the previous alphas_cumprod
120122
# For the final step, there is no previous alphas_cumprod because we are already at 0
@@ -217,8 +219,14 @@ def step(
217219

218220
# 3. compute predicted original sample from predicted noise also called
219221
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
220-
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
221-
222+
if self.config.prediction_type == "epsilon":
223+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
224+
pred_epsilon = model_output
225+
elif self.config.prediction_type == "v_prediction":
226+
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
227+
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
228+
else:
229+
raise ValueError("Unknown prediction_type")
222230
# 4. Clip "predicted x_0"
223231
if self.config.clip_sample:
224232
pred_original_sample = self.clip(pred_original_sample, -1, 1)
@@ -230,10 +238,10 @@ def step(
230238

231239
if use_clipped_model_output:
232240
# the model_output is always re-derived from the clipped x_0 in Glide
233-
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
241+
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
234242

235243
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
236-
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
244+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
237245

238246
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
239247
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

backends/stable_diffusion/stable_diffusion.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,21 @@ def get_scheduler(name):
111111
set_alpha_to_one=False,
112112
# steps_offset= 1,
113113
trained_betas= None,
114-
tensor_format="np"
114+
tensor_format="np",
115+
)
116+
117+
if name == "ddim_v":
118+
return DDIMScheduler(
119+
beta_start=0.00085,
120+
beta_end=0.012,
121+
beta_schedule="scaled_linear",
122+
clip_sample= False,
123+
num_train_timesteps= 1000,
124+
set_alpha_to_one=False,
125+
# steps_offset= 1,
126+
trained_betas= None,
127+
tensor_format="np",
128+
prediction_type="v_prediction"
115129
)
116130

117131
if name == "lmsd":

backends/stable_diffusion/tests.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,15 +287,14 @@ def test_sd2_2():
287287
def test_sd2_4():
288288

289289
img = sd.generate(
290-
prompt="A Tree" ,
290+
prompt="a tree" ,
291291
img_height=512,
292292
img_width=512,
293-
seed=1,
294-
num_steps=10,
293+
seed=13,
294+
num_steps=30,
295295
tdict_path="/Volumes/ext_drive_1/sd_data_models/v2-1_768-nonema-pruned.tdict",
296-
batch_size=1,
297296
dtype="float32",
298-
scheduler='pndm',
297+
scheduler='ddim_v',
299298
mode="txt2img" )
300299

301300
# sd2_a_cat_111_test_sd2_3.png

0 commit comments

Comments
 (0)