From c457adfb0179eea82e32a88d434f8568bfd33ca8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?=
Date: Wed, 16 Jul 2025 14:50:15 +0200
Subject: [PATCH 1/6] Support Flex-2

---
 flux.hpp             | 28 ++++++++++++--
 ggml_extend.hpp      |  2 +
 model.cpp            |  3 ++
 model.h              |  5 ++-
 stable-diffusion.cpp | 91 +++++++++++++++++++++++++++++++++++++++-----
 vae.hpp              |  1 +
 6 files changed, 114 insertions(+), 16 deletions(-)

diff --git a/flux.hpp b/flux.hpp
index 11045918f..303dc0123 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -984,7 +984,8 @@
                             struct ggml_tensor* pe,
                             struct ggml_tensor* mod_index_arange  = NULL,
                             std::vector<ggml_tensor*> ref_latents = {},
-                            std::vector<int> skip_layers          = {}) {
+                            std::vector<int> skip_layers          = {},
+                            SDVersion version = VERSION_FLUX) {
            // Forward pass of DiT.
            // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
            // timestep: (N,) tensor of diffusion timesteps
@@ -1007,7 +1008,8 @@ namespace Flux {
            auto img            = process_img(ctx, x);
            uint64_t img_tokens = img->ne[1];

-            if (c_concat != NULL) {
+            if (version == VERSION_FLUX_FILL) {
+                GGML_ASSERT(c_concat != NULL);
                ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
                ggml_tensor* mask   = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
@@ -1015,6 +1017,21 @@ namespace Flux {
                masked = process_img(ctx, masked);
                mask   = process_img(ctx, mask);

                img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
+            } else if (version == VERSION_FLEX_2) {
+                GGML_ASSERT(c_concat != NULL);
+                ggml_tensor* masked  = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
+                ggml_tensor* mask    = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
+                ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
+
+                masked  = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+                mask    = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+                control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
+
+                masked  = patchify(ctx, masked, patch_size);
+                mask    = patchify(ctx, mask, patch_size);
+                control = patchify(ctx, control, patch_size);
+
+                img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
            }

            if (ref_latents.size() > 0) {
@@ -1055,13 +1072,15 @@ namespace Flux {
                   SDVersion version = VERSION_FLUX,
                   bool flash_attn   = false,
                   bool use_mask     = false)
-            : GGMLRunner(backend), use_mask(use_mask) {
+            : GGMLRunner(backend), version(version), use_mask(use_mask) {
            flux_params.flash_attn          = flash_attn;
            flux_params.guidance_embed      = false;
            flux_params.depth               = 0;
            flux_params.depth_single_blocks = 0;
            if (version == VERSION_FLUX_FILL) {
                flux_params.in_channels = 384;
+            } else if (version == VERSION_FLEX_2) {
+                flux_params.in_channels = 196;
            }
            for (auto pair : tensor_types) {
                std::string tensor_name = pair.first;
@@ -1171,7 +1190,8 @@ namespace Flux {
                                       pe,
                                       mod_index_arange,
                                       ref_latents,
-                                       skip_layers);
+                                       skip_layers,
+                                       version);

            ggml_build_forward_expand(gf, out);
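Where these in_channels values come from: with Flux's 16-channel VAE latent (C = 16) and 2x2 patchify, every token carries C*2*2 = 64 features; Flux Fill appends a masked image (64) plus the 8x8-flattened mask (256), and Flex-2 appends a masked image (64), a 1-channel mask (4), and a control image (64). A small self-contained check of that arithmetic (our sketch, not part of the patch; the constants are assumptions read off the views above):

    #include <cstdio>

    int main() {
        const int C    = 16;          // Flux VAE latent channels (assumed)
        const int p    = 2;           // patch_size used by patchify() (assumed)
        const int base = C * p * p;   // 64 features per token for plain Flux

        int fill  = base + base + (8 * 8) * p * p;   // x + masked + flattened mask
        int flex2 = base + base + 1 * p * p + base;  // x + masked + mask + control

        printf("VERSION_FLUX_FILL in_channels = %d\n", fill);   // 384
        printf("VERSION_FLEX_2    in_channels = %d\n", flex2);  // 196
        return 0;
    }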
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 9f6a4fef6..4a2bab9d1 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -384,6 +384,8 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
    int64_t width    = output->ne[0];
    int64_t height   = output->ne[1];
    int64_t channels = output->ne[2];
+    float rescale_mx = (float)mask->ne[0] / output->ne[0];
+    float rescale_my = (float)mask->ne[1] / output->ne[1];
    GGML_ASSERT(output->type == GGML_TYPE_F32);
    for (int ix = 0; ix < width; ix++) {
        for (int iy = 0; iy < height; iy++) {

diff --git a/model.cpp b/model.cpp
index 2e40e004a..c4d82c224 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1689,6 +1689,9 @@ SDVersion ModelLoader::get_sd_version() {
        if (is_inpaint) {
            return VERSION_FLUX_FILL;
        }
+        if (input_block_weight.ne[0] == 196) {
+            return VERSION_FLEX_2;
+        }
        return VERSION_FLUX;
    }

diff --git a/model.h b/model.h
index a6266039a..fdcf319e8 100644
--- a/model.h
+++ b/model.h
@@ -31,11 +31,12 @@ enum SDVersion {
    VERSION_SD3,
    VERSION_FLUX,
    VERSION_FLUX_FILL,
+    VERSION_FLEX_2,
    VERSION_COUNT,
 };

 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
        return true;
    }
    return false;
@@ -70,7 +71,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
 }

 static inline bool sd_version_is_inpaint(SDVersion version) {
-    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
        return true;
    }
    return false;
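The rescale_mx/rescale_my factors introduced in the ggml_extend.hpp hunk above are only computed in this patch; patch 3 puts them to use when sd_apply_mask starts sampling the mask at a different resolution than the output. A standalone model of that nearest-neighbor lookup (our code, with made-up buffers; note the division must happen in floating point, which is why the hunk casts before dividing):

    #include <cmath>
    #include <cstdio>

    // Sample a mask of size mask_w x mask_h at output pixel (ix, iy).
    float mask_at(const float* mask, int mask_w, int mask_h,
                  int ix, int iy, int out_w, int out_h) {
        float sx = (float)mask_w / out_w;  // integer division would truncate
        float sy = (float)mask_h / out_h;  // (e.g. to 0 when mask < output)
        int mx = (int)(ix * sx);
        int my = (int)(iy * sy);
        return std::round(mask[my * mask_w + mx]);  // inpaint masks are binary
    }

    int main() {
        const float mask[4] = {0, 1, 1, 0};                 // 2x2 mask
        printf("%.0f\n", mask_at(mask, 2, 2, 3, 0, 4, 4));  // 4x4 output -> 1
        return 0;
    }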
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 402585f1c..4b776a744 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -95,7 +95,7 @@ class StableDiffusionGGML {
    std::shared_ptr<DiffusionModel> diffusion_model;
    std::shared_ptr<AutoEncoderKL> first_stage_model;
    std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
    std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
    std::shared_ptr<LoraModel> pmid_lora;
    std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -297,6 +297,11 @@ class StableDiffusionGGML {
            // TODO: shift_factor
        }

+        if (version == VERSION_FLEX_2) {
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
+
        bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;

        if (version == VERSION_SVD) {
@@ -933,7 +938,7 @@ class StableDiffusionGGML {

        std::vector<struct ggml_tensor*> controls;

-        if (control_hint != NULL) {
+        if (control_hint != NULL && control_net != NULL) {
            control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
            controls = control_net->controls;
            // print_ggml_tensor(controls[12]);
@@ -972,7 +977,7 @@ class StableDiffusionGGML {
            float* negative_data = NULL;
            if (has_unconditioned) {
                // uncond
-                if (control_hint != NULL) {
+                if (control_hint != NULL && control_net != NULL) {
                    control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                    controls = control_net->controls;
                }
@@ -1721,6 +1726,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
        int64_t mask_channels = 1;
        if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
            mask_channels = 8 * 8;  // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
        }
        auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
        // no mask, set the whole image as masked
@@ -1734,6 +1741,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                    for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                        ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                    }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                        // 0x16,1x1,0x16
+                        ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                    }
                } else {
                    ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                    for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@@ -1742,12 +1754,42 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                        ggml_tensor_set_f32(empty_latent, 0, x, y, c);
                    }
                }
            }
        }
-        if (concat_latent == NULL) {
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            struct ggml_tensor* control_latents = NULL;
+            if (!sd_ctx->sd->use_tiny_autoencoder) {
+                struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+                control_latents                     = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+            } else {
+                control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            }
+            for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
+                            // 0x16,1x1,0x16
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latents->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latents, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
+                    }
+                }
+            }
+            // Disable controlnet
+            image_hint = NULL;
+        } else if (concat_latent == NULL) {
            concat_latent = empty_latent;
        }
        cond.c_concat   = concat_latent;
        uncond.c_concat = empty_latent;
        denoise_mask    = NULL;
    } else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
        auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
        ggml_set_f32(empty_latent, 0);
        uncond.c_concat = empty_latent;
        if (concat_latent == NULL) {
            concat_latent = empty_latent;
        }
        cond.c_concat = ref_latents[0];
    }
@@ -1935,10 +1977,19 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
        sd_mask_to_tensor(sd_img_gen_params->mask_image.data, mask_img);
        sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img);

+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+            init_latent          = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+        } else {
+            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+        }
+
        if (sd_version_is_inpaint(sd_ctx->sd->version)) {
            int64_t mask_channels = 1;
            if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
                mask_channels = 8 * 8;  // flatten the whole mask
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                mask_channels = 1 + init_latent->ne[2];
            }
            ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
            sd_apply_mask(init_img, mask_img, masked_img);
@@ -1973,6 +2024,32 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                        ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                    }
                }
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                float m = ggml_tensor_get_f32(mask_img, mx, my);
+                // masked image
+                for (int k = 0; k < masked_latent->ne[2]; k++) {
+                    float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                }
+                // downsampled mask
+                ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                // control (todo: support this)
+                for (int k = 0; k < masked_latent->ne[2]; k++) {
+                    ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                }
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                float m = ggml_tensor_get_f32(mask_img, mx, my);
+                // masked image
+                for (int k = 0; k < masked_latent->ne[2]; k++) {
+                    float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                }
+                // downsampled mask
+                ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                // control (todo: support this)
+                for (int k = 0; k < masked_latent->ne[2]; k++) {
+                    ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                }
            } else {
                float m = ggml_tensor_get_f32(mask_img, mx, my);
                ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
                for (int k = 0; k < masked_latent->ne[2]; k++) {
                    float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
                }
            }
@@ -1998,12 +2075,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
            }
        }

-        if (!sd_ctx->sd->use_tiny_autoencoder) {
-            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-            init_latent          = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
-        } else {
-            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-        }
    } else {
        LOG_INFO("TXT2IMG");
        if (sd_version_is_inpaint(sd_ctx->sd->version)) {

diff --git a/vae.hpp b/vae.hpp
index 4add881f6..7ad0a9c3d 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -559,6 +559,7 @@ struct AutoEncoderKL : public GGMLRunner {
                 bool decode_graph,
                 struct ggml_tensor** output,
                 struct ggml_context* output_ctx = NULL) {
+        GGML_ASSERT(!decode_only || decode_graph);
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(z, decode_graph);
        };
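To summarize the conditioning layout patch 1 builds for Flex-2: every latent pixel of c_concat carries mask_channels + 16 = 33 channels, laid out as [0,16) masked-image latent, [16] binary mask, [17,33) control latent. The fragment below re-renders that layout on a plain array (hypothetical values; the 16-channel latent is an assumption):

    #include <cstdio>
    #include <vector>

    int main() {
        const int latent_c = 16;  // init_latent->ne[2] for Flux-family models
        std::vector<float> pixel(latent_c + 1 + latent_c, 0.0f);  // 33 channels

        // "0x16,1x1,0x16" from the comments above: zero masked-image latent,
        // mask = 1 (whole image masked), zero control latent.
        pixel[latent_c] = 1.0f;

        printf("channels per c_concat pixel: %zu\n", pixel.size());  // 33
        return 0;
    }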
From 8d5d16a1433a889798ff3fc5515c992c780cd4d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?=
Date: Wed, 16 Jul 2025 14:53:28 +0200
Subject: [PATCH 2/6] support for flux controls

---
 examples/cli/main.cpp |   3 +-
 flux.hpp              |  10 ++++
 model.cpp             |   6 ++-
 model.h               |   9 +++-
 stable-diffusion.cpp  | 103 ++++++++++++++++++++----------------------
 5 files changed, 72 insertions(+), 59 deletions(-)

diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index b3ae569e6..68f90dd49 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -905,7 +905,8 @@ int main(int argc, const char* argv[]) {
                                      input_image_buffer};

        sd_image_t* control_image = NULL;
-        if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
+        if (params.control_image_path.size() > 0) {
+            printf("load image from '%s'\n", params.control_image_path.c_str());
            int c                = 0;
            control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
            if (control_image_buffer == NULL) {

diff --git a/flux.hpp b/flux.hpp
index 303dc0123..81b1c59ae 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -1032,6 +1032,14 @@ namespace Flux {
                control = patchify(ctx, control, patch_size);

                img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
+            } else if (version == VERSION_FLUX_CONTROLS) {
+                GGML_ASSERT(c_concat != NULL);
+
+                ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
+
+                control = patchify(ctx, control, patch_size);
+
+                img = ggml_concat(ctx, img, control, 0);
            }

            if (ref_latents.size() > 0) {
@@ -1079,6 +1087,8 @@ namespace Flux {
            flux_params.depth_single_blocks = 0;
            if (version == VERSION_FLUX_FILL) {
                flux_params.in_channels = 384;
+            } else if (version == VERSION_FLUX_CONTROLS) {
+                flux_params.in_channels = 128;
            } else if (version == VERSION_FLEX_2) {
                flux_params.in_channels = 196;
            }
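The 128 follows the same accounting as in patch 1: the control latent is concatenated along the feature axis only, so each token carries 64 image features plus 64 control features, with no mask channels. A one-line check under the same assumptions (16 latent channels, 2x2 patches):

    #include <cassert>

    int main() {
        const int C = 16, p = 2;               // assumed latent channels / patch size
        assert(C * p * p + C * p * p == 128);  // image + control = in_channels
        return 0;
    }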
diff --git a/model.cpp b/model.cpp
index c4d82c224..022be4d2c 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1685,10 +1685,12 @@ SDVersion ModelLoader::get_sd_version() {
    }

    if (is_flux) {
-        is_inpaint = input_block_weight.ne[0] == 384;
-        if (is_inpaint) {
+        if (input_block_weight.ne[0] == 384) {
            return VERSION_FLUX_FILL;
        }
+        if (input_block_weight.ne[0] == 128) {
+            return VERSION_FLUX_CONTROLS;
+        }
        if (input_block_weight.ne[0] == 196) {
            return VERSION_FLEX_2;
        }
        return VERSION_FLUX;
    }

diff --git a/model.h b/model.h
index fdcf319e8..409258a79 100644
--- a/model.h
+++ b/model.h
@@ -31,12 +31,13 @@ enum SDVersion {
    VERSION_SD3,
    VERSION_FLUX,
    VERSION_FLUX_FILL,
+    VERSION_FLUX_CONTROLS,
    VERSION_FLEX_2,
    VERSION_COUNT,
 };

 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
        return true;
    }
    return false;
@@ -88,8 +89,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
    return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
 }

+static inline bool sd_version_is_control(SDVersion version) {
+    return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
+}
+
 static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
-    return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
+    return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version) || sd_version_is_control(version);
 }
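Taken together, get_sd_version() now separates the Flux family purely by the width of the img_in projection weight. A hedged summary of the mapping (the helper name and enum are ours, for illustration only):

    #include <cstdint>

    enum class FluxVariant { Flux, Fill, Controls, Flex2 };

    // input_block_weight.ne[0] is the DiT's per-token input width.
    FluxVariant classify_flux(int64_t img_in_dim0) {
        switch (img_in_dim0) {
            case 384: return FluxVariant::Fill;      // 64 + 64 + 256
            case 128: return FluxVariant::Controls;  // 64 + 64
            case 196: return FluxVariant::Flex2;     // 64 + 64 + 4 + 64
            default:  return FluxVariant::Flux;      // 64
        }
    }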
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 4b776a744..40213bb21 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -297,7 +297,7 @@ class StableDiffusionGGML {
            // TODO: shift_factor
        }

-        if (version == VERSION_FLEX_2) {
+        if (sd_version_is_control(version)) {
            // Might need vae encode for control cond
            vae_decode_only = false;
        }
@@ -1722,6 +1722,17 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
    int W = width / 8;
    int H = height / 8;
    LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
+
+    struct ggml_tensor* control_latent = NULL;
+    if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            control_latent                      = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+        } else {
+            control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+        }
+    }
+
    if (sd_version_is_inpaint(sd_ctx->sd->version)) {
        int64_t mask_channels = 1;
        if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
@@ -1754,50 +1765,53 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                }
            }
        }
-        if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
+
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
            bool no_inpaint = concat_latent == NULL;
            if (no_inpaint) {
                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
            }
            // fill in the control image here
-            struct ggml_tensor* control_latents = NULL;
-            if (!sd_ctx->sd->use_tiny_autoencoder) {
-                struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
-                control_latents                     = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
-            } else {
-                control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
-            }
-            for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
-                for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
+            for (int64_t x = 0; x < control_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < control_latent->ne[1]; y++) {
                    if (no_inpaint) {
-                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
                            // 0x16,1x1,0x16
                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
                        }
                    }
-                    for (int64_t c = 0; c < control_latents->ne[2]; c++) {
-                        float v = ggml_tensor_get_f32(control_latents, x, y, c);
-                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
+                    for (int64_t c = 0; c < control_latent->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latent, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
                    }
                }
            }
-            // Disable controlnet
-            image_hint = NULL;
        } else if (concat_latent == NULL) {
            concat_latent = empty_latent;
        }
        cond.c_concat   = concat_latent;
        uncond.c_concat = empty_latent;
        denoise_mask    = NULL;
    } else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
        auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
        ggml_set_f32(empty_latent, 0);
        uncond.c_concat = empty_latent;
-        if (concat_latent == NULL) {
-            concat_latent = empty_latent;
+        cond.c_concat = ref_latents[0];
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
+        }
+    } else if (sd_version_is_control(sd_ctx->sd->version)) {
+        LOG_DEBUG("HERE");
+        auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
+        ggml_set_f32(empty_latent, 0);
+        uncond.c_concat = empty_latent;
+        if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
+            cond.c_concat = control_latent;
        }
-        cond.c_concat = ref_latents[0];
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
+        }
+        LOG_DEBUG("HERE");
    }
    SDCondition img_cond;
    if (uncond.c_crossattn != NULL &&
@@ -1956,6 +1970,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
    size_t t0 = ggml_time_ms();

    ggml_tensor* init_latent   = NULL;
+    ggml_tensor* init_moments  = NULL;
    ggml_tensor* concat_latent = NULL;
    ggml_tensor* denoise_mask  = NULL;
    std::vector<float> sigmas  = sd_ctx->sd->denoiser->get_sigmas(sd_img_gen_params->sample_steps);
@@ -1978,8 +1993,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
        sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img);

        if (!sd_ctx->sd->use_tiny_autoencoder) {
-            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-            init_latent          = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+            init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+            init_latent  = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
        } else {
            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
        }
@@ -1988,8 +2003,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
            int64_t mask_channels = 1;
            if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
                mask_channels = 8 * 8;  // flatten the whole mask
-            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
-                mask_channels = 1 + init_latent->ne[2];
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                mask_channels = 1 + init_latent->ne[2];
            }
            ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
            sd_apply_mask(init_img, mask_img, masked_img);
@@ -2024,38 +2039,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                        ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                    }
                }
-            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
-                float m = ggml_tensor_get_f32(mask_img, mx, my);
-                // masked image
-                for (int k = 0; k < masked_latent->ne[2]; k++) {
-                    float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
-                }
-                // downsampled mask
-                ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
-                // control (todo: support this)
-                for (int k = 0; k < masked_latent->ne[2]; k++) {
-                    ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
-                }
-            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
-                float m = ggml_tensor_get_f32(mask_img, mx, my);
-                // masked image
-                for (int k = 0; k < masked_latent->ne[2]; k++) {
-                    float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
-                }
-                // downsampled mask
-                ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
-                // control (todo: support this)
-                for (int k = 0; k < masked_latent->ne[2]; k++) {
-                    ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
-                }
-            } else {
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
                float m = ggml_tensor_get_f32(mask_img, mx, my);
-                ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
+                // masked image
                for (int k = 0; k < masked_latent->ne[2]; k++) {
                    float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
+                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                }
+                // downsampled mask
+                ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                // control (todo: support this)
+                for (int k = 0; k < masked_latent->ne[2]; k++) {
+                    ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
                }
            }
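A condensed reading of the conditioning branch this patch settles on: built-in control models get the encoded hint on the conditional side only, while the unconditional side always sees a zero latent, so CFG ends up contrasting "with control" against "without control". Sketch with stand-in types (not the repo's API):

    #include <cstdio>

    struct Tensor { const char* name; };  // stand-in for ggml_tensor

    Tensor* pick_cond_concat(Tensor* control_latent, Tensor* zeros, bool external_cn) {
        if (!external_cn && control_latent != nullptr)
            return control_latent;  // built-in control path
        return zeros;               // txt2img or external ControlNet
    }

    int main() {
        Tensor hint{"control"}, zeros{"zeros"};
        printf("cond:   %s\n", pick_cond_concat(&hint, &zeros, false)->name);
        printf("uncond: %s\n", zeros.name);
        return 0;
    }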
From 2b6d9b162065c09e38a934f3066e86ee240d016d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?=
Date: Wed, 16 Jul 2025 14:56:45 +0200
Subject: [PATCH 3/6] Fix Flex 2 inpaint

---
 ggml_extend.hpp      | 12 ++++++++----
 stable-diffusion.cpp | 19 +++++++++++++------
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 4a2bab9d1..da691f3eb 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -380,7 +380,8 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,

 __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
                                     struct ggml_tensor* mask,
-                                     struct ggml_tensor* output) {
+                                     struct ggml_tensor* output,
+                                     float masked_value = 0.5f) {
    int64_t width    = output->ne[0];
    int64_t height   = output->ne[1];
    int64_t channels = output->ne[2];
@@ -389,11 +390,14 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
    GGML_ASSERT(output->type == GGML_TYPE_F32);
    for (int ix = 0; ix < width; ix++) {
        for (int iy = 0; iy < height; iy++) {
-            float m = ggml_tensor_get_f32(mask, ix, iy);
+            int mx  = (int)(ix * rescale_mx);
+            int my  = (int)(iy * rescale_my);
+            float m = ggml_tensor_get_f32(mask, mx, my);
            m = round(m);  // inpaint models need binary masks
-            ggml_tensor_set_f32(mask, m, ix, iy);
+            ggml_tensor_set_f32(mask, m, mx, my);
            for (int k = 0; k < channels; k++) {
-                float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
+                float value = ggml_tensor_get_f32(image_data, ix, iy, k);
+                value       = (1 - m) * (value - masked_value) + masked_value;
                ggml_tensor_set_f32(output, value, ix, iy, k);
            }
        }

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 40213bb21..57d05af50 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -2006,14 +2006,21 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
                mask_channels = 1 + init_latent->ne[2];
            }
-            ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
-            sd_apply_mask(init_img, mask_img, masked_img);
            ggml_tensor* masked_latent = NULL;
-            if (!sd_ctx->sd->use_tiny_autoencoder) {
-                ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
-                masked_latent        = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+            if (sd_ctx->sd->version != VERSION_FLEX_2) {
+                // most inpaint models mask before vae
+                ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+                sd_apply_mask(init_img, mask_img, masked_img);
+                if (!sd_ctx->sd->use_tiny_autoencoder) {
+                    ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                    masked_latent        = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+                } else {
+                    masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                }
            } else {
-                masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                // mask after vae
+                masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
+                sd_apply_mask(init_latent, mask_img, masked_latent, 0.);
            }
            concat_latent = ggml_new_tensor_4d(work_ctx,
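The masked_value parameter makes sd_apply_mask usable on both sides of the VAE: pixel-space images blank to mid-gray (0.5), while Flex-2's latent-space masking blanks to 0. The blend itself reduces to one line; a worked check of the three cases (our restatement of the hunk above):

    #include <cstdio>

    // Keep src where m == 0, replace with masked_value where m == 1.
    float apply_mask(float src, float m, float masked_value) {
        return (1.0f - m) * (src - masked_value) + masked_value;
    }

    int main() {
        printf("%.2f\n", apply_mask(0.80f, 0.0f, 0.5f));  // unmasked pixel: 0.80
        printf("%.2f\n", apply_mask(0.80f, 1.0f, 0.5f));  // masked pixel:   0.50
        printf("%.2f\n", apply_mask(1.30f, 1.0f, 0.0f));  // masked latent:  0.00
        return 0;
    }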
From f9ba12cd2f0d95a54fbe3280d4952f413c4d2795 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?=
Date: Wed, 16 Jul 2025 15:57:24 +0200
Subject: [PATCH 4/6] Fix model_version_to_str

---
 stable-diffusion.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 57d05af50..1586ab1e9 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -36,7 +36,10 @@ const char* model_version_to_str[] = {
    "SVD",
    "SD3.x",
    "Flux",
-    "Flux Fill"};
+    "Flux Fill",
+    "Flux Control",
+    "Flex.2",
+};

 const char* sampling_methods_str[] = {
    "Euler A",

From cc15e39278432d2129f50a027fd9daaa7da5852d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?=
Date: Wed, 16 Jul 2025 15:58:11 +0200
Subject: [PATCH 5/6] Support control strength for builtin control models

---
 stable-diffusion.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 1586ab1e9..82b67f8fd 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -1734,6 +1734,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
        } else {
            control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
        }
+        ggml_tensor_scale(control_latent, control_strength);
    }

    if (sd_version_is_inpaint(sd_ctx->sd->version)) {
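With no ControlNet weights to scale, "control strength" for the built-in variants becomes a plain linear scale of the encoded hint before it is concatenated: 0 degenerates to the all-zero (unconditioned) concat and 1 keeps the hint as encoded. On a flat float buffer the whole operation is (illustration only, mirroring ggml_tensor_scale above):

    #include <vector>

    void scale_latent(std::vector<float>& latent, float strength) {
        for (float& v : latent)
            v *= strength;
    }

    int main() {
        std::vector<float> latent = {0.5f, -1.0f, 2.0f};
        scale_latent(latent, 0.8f);  // -> {0.4, -0.8, 1.6}
        return 0;
    }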
From 2c9e1a2d6af917e5a2476dfe4a9bf55c701a76d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?=
Date: Wed, 16 Jul 2025 16:01:16 +0200
Subject: [PATCH 6/6] small cleanup

---
 stable-diffusion.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 82b67f8fd..28022e621 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -1805,17 +1805,15 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
            cond.c_concat = empty_latent;
        }
    } else if (sd_version_is_control(sd_ctx->sd->version)) {
-        LOG_DEBUG("HERE");
        auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
        ggml_set_f32(empty_latent, 0);
        uncond.c_concat = empty_latent;
-        if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
+        if (sd_ctx->sd->control_net == NULL) {
            cond.c_concat = control_latent;
        }
        if (cond.c_concat == NULL) {
            cond.c_concat = empty_latent;
        }
-        LOG_DEBUG("HERE");
    }
    SDCondition img_cond;
    if (uncond.c_crossattn != NULL &&