
Commit 712b807

Support Flex-2
1 parent 76eee90 commit 712b807

File tree: 6 files changed, +106 -11 lines changed


examples/cli/main.cpp

Lines changed: 1 addition & 1 deletion
@@ -933,7 +933,7 @@ int main(int argc, const char* argv[]) {
     }
 
     sd_image_t* control_image = NULL;
-    if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
+    if (params.control_image_path.size() > 0) {
         int c = 0;
         control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
         if (control_image_buffer == NULL) {

flux.hpp

Lines changed: 25 additions & 4 deletions
@@ -793,7 +793,8 @@ namespace Flux {
                              struct ggml_tensor* y,
                              struct ggml_tensor* guidance,
                              struct ggml_tensor* pe,
-                             std::vector<int> skip_layers = std::vector<int>()) {
+                             std::vector<int> skip_layers = std::vector<int>(),
+                             SDVersion version = VERSION_FLUX) {
         // Forward pass of DiT.
         // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
         // timestep: (N,) tensor of diffusion timesteps
@@ -817,7 +818,8 @@ namespace Flux {
         // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
 
-        if (c_concat != NULL) {
+        if (version == VERSION_FLUX_FILL) {
+            GGML_ASSERT(c_concat != NULL);
             ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
             ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
 
@@ -828,6 +830,21 @@ namespace Flux {
             mask = patchify(ctx, mask, patch_size);
 
             img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
+        } else if (version == VERSION_FLEX_2) {
+            GGML_ASSERT(c_concat != NULL);
+            ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
+            ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
+            ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
+
+            masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+            mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+            control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
+
+            masked = patchify(ctx, masked, patch_size);
+            mask = patchify(ctx, mask, patch_size);
+            control = patchify(ctx, control, patch_size);
+
+            img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
         }
 
         auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
@@ -846,19 +863,22 @@ namespace Flux {
         FluxParams flux_params;
         Flux flux;
         std::vector<float> pe_vec; // for cache
+        SDVersion version;
 
         FluxRunner(ggml_backend_t backend,
                    std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
                    const std::string prefix = "",
                    SDVersion version = VERSION_FLUX,
                    bool flash_attn = false)
-            : GGMLRunner(backend) {
+            : GGMLRunner(backend), version(version) {
             flux_params.flash_attn = flash_attn;
             flux_params.guidance_embed = false;
             flux_params.depth = 0;
             flux_params.depth_single_blocks = 0;
             if (version == VERSION_FLUX_FILL) {
                 flux_params.in_channels = 384;
+            } else if (version == VERSION_FLEX_2) {
+                flux_params.in_channels = 196;
             }
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
@@ -941,7 +961,8 @@ namespace Flux {
                                           y,
                                           guidance,
                                           pe,
-                                          skip_layers);
+                                          skip_layers,
+                                          version);
 
         ggml_build_forward_expand(gf, out);
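
With this change, a Flex-2 c_concat packs three conditioning blocks along the channel axis (masked latent, a single mask channel, and a control latent), each padded and patchified like the image latent before being appended to the token features. A minimal standalone sketch of where the img_in width comes from, assuming the usual Flux latent depth C = 16 and patch_size = 2 (assumptions, not stated in this diff):

// Illustrative only (not part of the commit): where flux_params.in_channels = 196 comes from.
#include <cstdio>

int main() {
    const int C          = 16; // latent channels, assumed Flux default
    const int patch_size = 2;  // DiT patch size, assumed

    // image latent + masked latent + 1-channel mask + control latent, all patchified together
    const int concat_channels = C + C + 1 + C;                             // 49
    const int in_channels     = concat_channels * patch_size * patch_size; // 49 * 4

    printf("Flex-2 img_in width: %d\n", in_channels); // prints 196
    return 0;
}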

model.cpp

Lines changed: 3 additions & 0 deletions
@@ -1555,6 +1555,9 @@ SDVersion ModelLoader::get_sd_version() {
         if (is_inpaint) {
             return VERSION_FLUX_FILL;
         }
+        if(input_block_weight.ne[0] == 196){
+            return VERSION_FLEX_2;
+        }
         return VERSION_FLUX;
     }
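
For reference, get_sd_version() tells the Flux variants apart by the width of the patchified input projection. A hedged summary of that mapping (the helper below is made up for illustration; the 384 and 64 values follow from the channel math in flux.hpp, not from this hunk):

#include <cstdint>
#include "model.h" // for SDVersion; assumed to be on the include path

// Hypothetical helper, not part of the commit.
static SDVersion flux_version_from_img_in_width(int64_t ne0) {
    if (ne0 == 384) return VERSION_FLUX_FILL; // (16 + 16 + 8*8) * 4: fill/inpaint weights
    if (ne0 == 196) return VERSION_FLEX_2;    // (16 + 16 + 1 + 16) * 4: Flex-2
    return VERSION_FLUX;                      // 64 = 16 * 4: base Flux
}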

model.h

Lines changed: 3 additions & 2 deletions
@@ -31,11 +31,12 @@ enum SDVersion {
     VERSION_SD3,
     VERSION_FLUX,
     VERSION_FLUX_FILL,
+    VERSION_FLEX_2,
     VERSION_COUNT,
 };
 
 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
         return true;
     }
     return false;
@@ -70,7 +71,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
 }
 
 static inline bool sd_version_is_inpaint(SDVersion version) {
-    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
         return true;
     }
     return false;
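
A quick sanity check of the updated predicates (standalone sketch, not part of the commit): Flex-2 is treated both as a Flux-family model and as an inpaint-capable one, so it picks up both the Flux code paths and the mask/concat conditioning.

#include <cassert>
#include "model.h"

int main() {
    assert(sd_version_is_flux(VERSION_FLEX_2));    // routed through the Flux runner
    assert(sd_version_is_inpaint(VERSION_FLEX_2)); // gets the concat_latent / mask handling
    assert(!sd_version_is_inpaint(VERSION_FLUX));  // base Flux is unchanged
    return 0;
}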

stable-diffusion.cpp

Lines changed: 73 additions & 4 deletions
@@ -95,7 +95,7 @@ class StableDiffusionGGML {
     std::shared_ptr<DiffusionModel> diffusion_model;
     std::shared_ptr<AutoEncoderKL> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
     std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
     std::shared_ptr<LoraModel> pmid_lora;
     std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -301,6 +301,11 @@ class StableDiffusionGGML {
             // TODO: shift_factor
         }
 
+        if(version == VERSION_FLEX_2){
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
+
         if (version == VERSION_SVD) {
             clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types);
             clip_vision->alloc_params_buffer();
@@ -897,7 +902,7 @@ class StableDiffusionGGML {
 
         std::vector<struct ggml_tensor*> controls;
 
-        if (control_hint != NULL) {
+        if (control_hint != NULL && control_net != NULL) {
            control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
            controls = control_net->controls;
            // print_ggml_tensor(controls[12]);
@@ -934,7 +939,7 @@ class StableDiffusionGGML {
        float* negative_data = NULL;
        if (has_unconditioned) {
            // uncond
-           if (control_hint != NULL) {
+           if (control_hint != NULL && control_net != NULL) {
                control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                controls = control_net->controls;
            }
@@ -1474,6 +1479,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8; // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
         // no mask, set the whole image as masked
@@ -1487,6 +1494,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                     for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                         ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                        // 0x16,1x1,0x16
+                        ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                    }
                 } else {
                     ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                     for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@@ -1495,7 +1507,36 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                 }
             }
         }
-        if (concat_latent == NULL) {
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            struct ggml_tensor* control_latents = NULL;
+            if (!sd_ctx->sd->use_tiny_autoencoder) {
+                struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+                control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+            } else {
+                control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            }
+            for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
+                            // 0x16,1x1,0x16
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latents->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latents, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
+                    }
+                }
+            }
+            // Disable controlnet
+            image_hint = NULL;
+        } else if (concat_latent == NULL) {
             concat_latent = empty_latent;
         }
         cond.c_concat = concat_latent;
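
In short: when a Flex-2 model gets a control image and no ControlNet is loaded, the hint is VAE-encoded and written into the trailing channels of c_concat, while the mask channel keeps its "whole image masked" default if no inpaint mask was supplied; image_hint is then cleared so the ControlNet path stays off. A per-pixel sketch of the layout this builds, assuming 16 latent channels (the helper name is invented for illustration):

#include <vector>

// Illustrative only: per-pixel channel layout of the Flex-2 concat latent in txt2img,
// [masked/init latent (16) | mask (1) | VAE-encoded control latent (16)] = 33 channels.
// Together with the 16-channel image latent this patchifies to (16 + 33) * 4 = 196 model inputs.
static std::vector<float> flex2_concat_pixel(const std::vector<float>& masked_latent_px, // 16 values
                                             float mask_value,                           // 1.0f = fully masked
                                             const std::vector<float>& control_latent_px /* 16 values */) {
    std::vector<float> out;
    out.reserve(masked_latent_px.size() + 1 + control_latent_px.size());
    out.insert(out.end(), masked_latent_px.begin(), masked_latent_px.end());
    out.push_back(mask_value);
    out.insert(out.end(), control_latent_px.begin(), control_latent_px.end());
    return out;
}
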
@@ -1772,6 +1813,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8; // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
         // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
@@ -1803,6 +1846,32 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
                             ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                         }
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    // masked image
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                    }
+                    // downsampled mask
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                    // control (todo: support this)
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                    }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    // masked image
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                    }
+                    // downsampled mask
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                    // control (todo: support this)
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                    }
                 } else {
                     float m = ggml_tensor_get_f32(mask_img, mx, my);
                     ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
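
The img2img path writes the same 33-channel layout, but leaves the control block zeroed for now (the "control (todo: support this)" comment). Note that, as committed, the second identical VERSION_FLEX_2 branch above is unreachable because the first one already matches. For orientation, the per-pixel write order used here, with C = 16 assumed:

// Illustrative summary (not part of the commit) of the concat channel order above:
//   channels [0, C)       masked-image latent values        (C = 16 assumed)
//   channel  C            down-sampled mask value m
//   channels [C+1, 2C+1)  control latent, zero-filled here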

vae.hpp

Lines changed: 1 addition & 0 deletions
@@ -559,6 +559,7 @@ struct AutoEncoderKL : public GGMLRunner {
                  bool decode_graph,
                  struct ggml_tensor** output,
                  struct ggml_context* output_ctx = NULL) {
+        GGML_ASSERT(!decode_only || decode_graph);
         auto get_graph = [&]() -> struct ggml_cgraph* {
             return build_graph(z, decode_graph);
         };
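
The new assert encodes the implication "decode-only VAE implies decode graph only": an encode graph may be requested only when the VAE was loaded with decode_only == false, which the Flex-2 path guarantees by forcing vae_decode_only = false before encoding the control image. A tiny truth-table sketch (illustrative only):

#include <cassert>

// Same condition as the assert above: !decode_only || decode_graph.
static bool compute_allowed(bool decode_only, bool decode_graph) {
    return !decode_only || decode_graph;
}

int main() {
    assert(compute_allowed(false, false)); // full VAE, encode graph: OK (Flex-2 control encode)
    assert(compute_allowed(false, true));  // full VAE, decode graph: OK
    assert(compute_allowed(true, true));   // decode-only VAE, decode graph: OK
    assert(!compute_allowed(true, false)); // decode-only VAE, encode graph: trips the assert
    return 0;
}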
