@@ -88,6 +88,7 @@ def __init__(
88
88
clip_sample : bool = True ,
89
89
set_alpha_to_one : bool = True ,
90
90
tensor_format : str = "pt" ,
91
+ prediction_type : str = "epsilon"
91
92
):
92
93
if trained_betas is not None :
93
94
self .betas = np .asarray (trained_betas )
@@ -115,6 +116,7 @@ def __init__(
115
116
self .clip_sample = clip_sample
116
117
self .set_alpha_to_one = set_alpha_to_one
117
118
self .tensor_format = tensor_format
119
+ self .prediction_type = prediction_type
118
120
119
121
# At every step in ddim, we are looking into the previous alphas_cumprod
120
122
# For the final step, there is no previous alphas_cumprod because we are already at 0
@@ -217,8 +219,14 @@ def step(
217
219
218
220
# 3. compute predicted original sample from the model output (predicted noise or v-prediction), also called
219
221
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
220
- pred_original_sample = (sample - beta_prod_t ** (0.5 ) * model_output ) / alpha_prod_t ** (0.5 )
221
-
222
+ if self .config .prediction_type == "epsilon" :
223
+ pred_original_sample = (sample - beta_prod_t ** (0.5 ) * model_output ) / alpha_prod_t ** (0.5 )
224
+ pred_epsilon = model_output
225
+ elif self .config .prediction_type == "v_prediction" :
226
+ pred_original_sample = (alpha_prod_t ** 0.5 ) * sample - (beta_prod_t ** 0.5 ) * model_output
227
+ pred_epsilon = (alpha_prod_t ** 0.5 ) * model_output + (beta_prod_t ** 0.5 ) * sample
228
+ else :
229
+ raise ValueError ("Unknown prediction_type" )
222
230
# 4. Clip "predicted x_0"
223
231
if self .config .clip_sample :
224
232
pred_original_sample = self .clip (pred_original_sample , - 1 , 1 )
@@ -230,10 +238,10 @@ def step(
230
238
231
239
if use_clipped_model_output :
232
240
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
233
- model_output = (sample - alpha_prod_t ** (0.5 ) * pred_original_sample ) / beta_prod_t ** (0.5 )
241
+ pred_epsilon = (sample - alpha_prod_t ** (0.5 ) * pred_original_sample ) / beta_prod_t ** (0.5 )
234
242
235
243
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
236
- pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2 ) ** (0.5 ) * model_output
244
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2 ) ** (0.5 ) * pred_epsilon
237
245
238
246
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
239
247
prev_sample = alpha_prod_t_prev ** (0.5 ) * pred_original_sample + pred_sample_direction
0 commit comments