Skip to content

Commit b2c4087

Browse files
committed
support for flux controls
1 parent 712b807 commit b2c4087

File tree

4 files changed: 58 additions, 35 deletions

flux.hpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -845,6 +845,14 @@ namespace Flux {
845845
control = patchify(ctx, control, patch_size);
846846

847847
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
848+
} else if (version == VERSION_FLUX_CONTROLS) {
849+
GGML_ASSERT(c_concat != NULL);
850+
851+
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
852+
853+
control = patchify(ctx, control, patch_size);
854+
855+
img = ggml_concat(ctx, img, control, 0);
848856
}
849857

850858
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
@@ -877,6 +885,8 @@ namespace Flux {
877885
flux_params.depth_single_blocks = 0;
878886
if (version == VERSION_FLUX_FILL) {
879887
flux_params.in_channels = 384;
888+
} else if (version == VERSION_FLUX_CONTROLS) {
889+
flux_params.in_channels = 128;
880890
} else if (version == VERSION_FLEX_2) {
881891
flux_params.in_channels = 196;
882892
}

model.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1551,10 +1551,12 @@ SDVersion ModelLoader::get_sd_version() {
15511551
}
15521552

15531553
if (is_flux) {
1554-
is_inpaint = input_block_weight.ne[0] == 384;
1555-
if (is_inpaint) {
1554+
if (input_block_weight.ne[0] == 384) {
15561555
return VERSION_FLUX_FILL;
15571556
}
1557+
if (input_block_weight.ne[0] == 128) {
1558+
return VERSION_FLUX_CONTROLS;
1559+
}
15581560
if(input_block_weight.ne[0] == 196){
15591561
return VERSION_FLEX_2;
15601562
}

model.h

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,13 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLUX_CONTROLS,
3435
VERSION_FLEX_2,
3536
VERSION_COUNT,
3637
};
3738

3839
static inline bool sd_version_is_flux(SDVersion version) {
39-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
40+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 ) {
4041
return true;
4142
}
4243
return false;
@@ -70,15 +71,16 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
7071
return false;
7172
}
7273

73-
static inline bool sd_version_is_inpaint(SDVersion version) {
74-
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
74+
75+
static inline bool sd_version_is_dit(SDVersion version) {  // Predicate: true when `version` is accepted by either the Flux or the SD3 family check.
76+
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {  // delegates entirely to the per-family predicates; NOTE(review): "dit" presumably means Diffusion Transformer — confirm
7577
return true;
7678
}
7779
return false;  // every other SDVersion value
7880
}
7981

80-
static inline bool sd_version_is_dit(SDVersion version) {
81-
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
82+
static inline bool sd_version_is_inpaint(SDVersion version) {
83+
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
8284
return true;
8385
}
8486
return false;
@@ -88,8 +90,12 @@ static inline bool sd_version_is_edit(SDVersion version) {
8890
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
8991
}
9092

93+
static inline bool sd_version_is_control(SDVersion version) {  // Predicate: true only for VERSION_FLUX_CONTROLS and VERSION_FLEX_2.
94+
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;  // VERSION_FLEX_2 is also listed by sd_version_is_inpaint, so the two predicates overlap for that version
95+
}
96+
9197
static bool sd_version_use_concat(SDVersion version) {
92-
return sd_version_is_edit(version) || sd_version_is_inpaint(version);
98+
return sd_version_is_edit(version) || sd_version_is_inpaint(version)|| sd_version_is_control(version);
9399
}
94100

95101
enum PMVersion {

stable-diffusion.cpp

Lines changed: 32 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -301,7 +301,7 @@ class StableDiffusionGGML {
301301
// TODO: shift_factor
302302
}
303303

304-
if(version == VERSION_FLEX_2){
304+
if (sd_version_is_control(version)) {
305305
// Might need vae encode for control cond
306306
vae_decode_only = false;
307307
}
@@ -815,15 +815,15 @@ class StableDiffusionGGML {
815815
const std::vector<float>& sigmas,
816816
int start_merge_step,
817817
SDCondition id_cond,
818-
ggml_tensor* denoise_mask = NULL) {
818+
ggml_tensor* denoise_mask = NULL) {
819819
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
820820

821821
// TODO (Pix2Pix): separate image guidance params (right now it's reusing distilled guidance)
822822

823-
float cfg_scale = guidance.txt_cfg;
823+
float cfg_scale = guidance.txt_cfg;
824824
float img_cfg_scale = guidance.img_cfg;
825-
float slg_scale = guidance.slg.scale;
826-
825+
float slg_scale = guidance.slg.scale;
826+
827827
float min_cfg = guidance.min_cfg;
828828

829829
if (img_cfg_scale != cfg_scale && !sd_version_use_concat(version)) {
@@ -1475,6 +1475,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14751475
int W = width / 8;
14761476
int H = height / 8;
14771477
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
1478+
1479+
struct ggml_tensor* control_latent = NULL;
1480+
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
1481+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1482+
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1483+
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1484+
} else {
1485+
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1486+
}
1487+
}
1488+
14781489
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
14791490
int64_t mask_channels = 1;
14801491
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
@@ -1507,50 +1518,44 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15071518
}
15081519
}
15091520
}
1510-
if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
1521+
1522+
if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
15111523
bool no_inpaint = concat_latent == NULL;
15121524
if (no_inpaint) {
15131525
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
15141526
}
15151527
// fill in the control image here
1516-
struct ggml_tensor* control_latents = NULL;
1517-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1518-
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1519-
control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1520-
} else {
1521-
control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1522-
}
1523-
for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
1524-
for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
1528+
for (int64_t x = 0; x < control_latent->ne[0]; x++) {
1529+
for (int64_t y = 0; y < control_latent->ne[1]; y++) {
15251530
if (no_inpaint) {
1526-
for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
1531+
for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
15271532
// 0x16,1x1,0x16
15281533
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
15291534
}
15301535
}
1531-
for (int64_t c = 0; c < control_latents->ne[2]; c++) {
1532-
float v = ggml_tensor_get_f32(control_latents, x, y, c);
1533-
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
1536+
for (int64_t c = 0; c < control_latent->ne[2]; c++) {
1537+
float v = ggml_tensor_get_f32(control_latent, x, y, c);
1538+
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
15341539
}
15351540
}
15361541
}
1537-
// Disable controlnet
1538-
image_hint = NULL;
15391542
} else if (concat_latent == NULL) {
15401543
concat_latent = empty_latent;
15411544
}
15421545
cond.c_concat = concat_latent;
15431546
uncond.c_concat = empty_latent;
15441547
denoise_mask = NULL;
1545-
} else if (sd_version_is_edit(sd_ctx->sd->version)) {
1548+
} else if (sd_version_is_edit(sd_ctx->sd->version) || sd_version_is_control(sd_ctx->sd->version)) {
15461549
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], init_latent->ne[3]);
15471550
ggml_set_f32(empty_latent, 0);
15481551
uncond.c_concat = empty_latent;
1552+
if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
1553+
concat_latent = control_latent;
1554+
}
15491555
if (concat_latent == NULL) {
15501556
concat_latent = empty_latent;
15511557
}
1552-
cond.c_concat = concat_latent;
1553-
1558+
cond.c_concat = concat_latent;
15541559
}
15551560
for (int b = 0; b < batch_count; b++) {
15561561
int64_t sampling_start = ggml_time_ms();
@@ -1823,7 +1828,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18231828
ggml_tensor* masked_latent = NULL;
18241829
if (!sd_ctx->sd->use_tiny_autoencoder) {
18251830
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1826-
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1831+
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
18271832
} else {
18281833
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
18291834
}
@@ -1894,8 +1899,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18941899
} else {
18951900
concat_latent = init_latent;
18961901
}
1897-
}
1898-
1902+
}
1903+
18991904
{
19001905
// LOG_WARN("Inpainting with a base model is not great");
19011906
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);

0 commit comments

Comments
 (0)