@@ -95,7 +95,7 @@ class StableDiffusionGGML {
95
95
std::shared_ptr<DiffusionModel> diffusion_model;
96
96
std::shared_ptr<AutoEncoderKL> first_stage_model;
97
97
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
98
- std::shared_ptr<ControlNet> control_net;
98
+ std::shared_ptr<ControlNet> control_net = NULL ;
99
99
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
100
100
std::shared_ptr<LoraModel> pmid_lora;
101
101
std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -301,6 +301,11 @@ class StableDiffusionGGML {
301
301
// TODO: shift_factor
302
302
}
303
303
304
+ if (version == VERSION_FLEX_2){
305
+ // Might need vae encode for control cond
306
+ vae_decode_only = false ;
307
+ }
308
+
304
309
if (version == VERSION_SVD) {
305
310
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types );
306
311
clip_vision->alloc_params_buffer ();
@@ -897,7 +902,7 @@ class StableDiffusionGGML {
897
902
898
903
std::vector<struct ggml_tensor *> controls;
899
904
900
- if (control_hint != NULL ) {
905
+ if (control_hint != NULL && control_net != NULL ) {
901
906
control_net->compute (n_threads, noised_input, control_hint, timesteps, cond.c_crossattn , cond.c_vector );
902
907
controls = control_net->controls ;
903
908
// print_ggml_tensor(controls[12]);
@@ -934,7 +939,7 @@ class StableDiffusionGGML {
934
939
float * negative_data = NULL ;
935
940
if (has_unconditioned) {
936
941
// uncond
937
- if (control_hint != NULL ) {
942
+ if (control_hint != NULL && control_net != NULL ) {
938
943
control_net->compute (n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn , uncond.c_vector );
939
944
controls = control_net->controls ;
940
945
}
@@ -1474,6 +1479,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1474
1479
int64_t mask_channels = 1 ;
1475
1480
if (sd_ctx->sd ->version == VERSION_FLUX_FILL) {
1476
1481
mask_channels = 8 * 8 ; // flatten the whole mask
1482
+ } else if (sd_ctx->sd ->version == VERSION_FLEX_2) {
1483
+ mask_channels = 1 + init_latent->ne [2 ];
1477
1484
}
1478
1485
auto empty_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], mask_channels + init_latent->ne [2 ], 1 );
1479
1486
// no mask, set the whole image as masked
@@ -1487,6 +1494,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1487
1494
for (int64_t c = init_latent->ne [2 ]; c < empty_latent->ne [2 ]; c++) {
1488
1495
ggml_tensor_set_f32 (empty_latent, 1 , x, y, c);
1489
1496
}
1497
+ } else if (sd_ctx->sd ->version == VERSION_FLEX_2) {
1498
+ for (int64_t c = 0 ; c < empty_latent->ne [2 ]; c++) {
1499
+ // 0x16,1x1,0x16
1500
+ ggml_tensor_set_f32 (empty_latent, c == init_latent->ne [2 ], x, y, c);
1501
+ }
1490
1502
} else {
1491
1503
ggml_tensor_set_f32 (empty_latent, 1 , x, y, 0 );
1492
1504
for (int64_t c = 1 ; c < empty_latent->ne [2 ]; c++) {
@@ -1495,7 +1507,36 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1495
1507
}
1496
1508
}
1497
1509
}
1498
- if (concat_latent == NULL ) {
1510
+ if (sd_ctx->sd ->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd ->control_net == NULL ) {
1511
+ bool no_inpaint = concat_latent == NULL ;
1512
+ if (no_inpaint) {
1513
+ concat_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], mask_channels + init_latent->ne [2 ], 1 );
1514
+ }
1515
+ // fill in the control image here
1516
+ struct ggml_tensor * control_latents = NULL ;
1517
+ if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1518
+ struct ggml_tensor * control_moments = sd_ctx->sd ->encode_first_stage (work_ctx, image_hint);
1519
+ control_latents = sd_ctx->sd ->get_first_stage_encoding (work_ctx, control_moments);
1520
+ } else {
1521
+ control_latents = sd_ctx->sd ->encode_first_stage (work_ctx, image_hint);
1522
+ }
1523
+ for (int64_t x = 0 ; x < concat_latent->ne [0 ]; x++) {
1524
+ for (int64_t y = 0 ; y < concat_latent->ne [1 ]; y++) {
1525
+ if (no_inpaint) {
1526
+ for (int64_t c = 0 ; c < concat_latent->ne [2 ] - control_latents->ne [2 ]; c++) {
1527
+ // 0x16,1x1,0x16
1528
+ ggml_tensor_set_f32 (concat_latent, c == init_latent->ne [2 ], x, y, c);
1529
+ }
1530
+ }
1531
+ for (int64_t c = 0 ; c < control_latents->ne [2 ]; c++) {
1532
+ float v = ggml_tensor_get_f32 (control_latents, x, y, c);
1533
+ ggml_tensor_set_f32 (concat_latent, v, x, y, concat_latent->ne [2 ] - control_latents->ne [2 ] + c);
1534
+ }
1535
+ }
1536
+ }
1537
+ // Disable controlnet
1538
+ image_hint = NULL ;
1539
+ } else if (concat_latent == NULL ) {
1499
1540
concat_latent = empty_latent;
1500
1541
}
1501
1542
cond.c_concat = concat_latent;
@@ -1772,6 +1813,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
1772
1813
int64_t mask_channels = 1 ;
1773
1814
if (sd_ctx->sd ->version == VERSION_FLUX_FILL) {
1774
1815
mask_channels = 8 * 8 ; // flatten the whole mask
1816
+ } else if (sd_ctx->sd ->version == VERSION_FLEX_2) {
1817
+ mask_channels = 1 + init_latent->ne [2 ];
1775
1818
}
1776
1819
ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, width, height, 3 , 1 );
1777
1820
// Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
@@ -1803,6 +1846,32 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
1803
1846
ggml_tensor_set_f32 (concat_latent, m, ix, iy, masked_latent->ne [2 ] + x * 8 + y);
1804
1847
}
1805
1848
}
1849
+ } else if (sd_ctx->sd ->version == VERSION_FLEX_2) {
1850
+ float m = ggml_tensor_get_f32 (mask_img, mx, my);
1851
+ // masked image
1852
+ for (int k = 0 ; k < masked_latent->ne [2 ]; k++) {
1853
+ float v = ggml_tensor_get_f32 (masked_latent, ix, iy, k);
1854
+ ggml_tensor_set_f32 (concat_latent, v, ix, iy, k);
1855
+ }
1856
+ // downsampled mask
1857
+ ggml_tensor_set_f32 (concat_latent, m, ix, iy, masked_latent->ne [2 ]);
1858
+ // control (todo: support this)
1859
+ for (int k = 0 ; k < masked_latent->ne [2 ]; k++) {
1860
+ ggml_tensor_set_f32 (concat_latent, 0 , ix, iy, masked_latent->ne [2 ] + 1 + k);
1861
+ }
1862
+ } else if (sd_ctx->sd ->version == VERSION_FLEX_2) {
1863
+ float m = ggml_tensor_get_f32 (mask_img, mx, my);
1864
+ // masked image
1865
+ for (int k = 0 ; k < masked_latent->ne [2 ]; k++) {
1866
+ float v = ggml_tensor_get_f32 (masked_latent, ix, iy, k);
1867
+ ggml_tensor_set_f32 (concat_latent, v, ix, iy, k);
1868
+ }
1869
+ // downsampled mask
1870
+ ggml_tensor_set_f32 (concat_latent, m, ix, iy, masked_latent->ne [2 ]);
1871
+ // control (todo: support this)
1872
+ for (int k = 0 ; k < masked_latent->ne [2 ]; k++) {
1873
+ ggml_tensor_set_f32 (concat_latent, 0 , ix, iy, masked_latent->ne [2 ] + 1 + k);
1874
+ }
1806
1875
} else {
1807
1876
float m = ggml_tensor_get_f32 (mask_img, mx, my);
1808
1877
ggml_tensor_set_f32 (concat_latent, m, ix, iy, 0 );
0 commit comments