@@ -338,6 +338,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
338
338
{" to_v" , " v" },
339
339
{" to_out_0" , " proj_out" },
340
340
{" group_norm" , " norm" },
341
+ {" key" , " k" },
342
+ {" query" , " q" },
343
+ {" value" , " v" },
344
+ {" proj_attn" , " proj_out" },
341
345
},
342
346
},
343
347
{
@@ -362,6 +366,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
362
366
{" to_v" , " v" },
363
367
{" to_out.0" , " proj_out" },
364
368
{" group_norm" , " norm" },
369
+ {" key" , " k" },
370
+ {" query" , " q" },
371
+ {" value" , " v" },
372
+ {" proj_attn" , " proj_out" },
365
373
},
366
374
},
367
375
{
@@ -433,6 +441,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
433
441
return format (" model%cdiffusion_model%ctime_embed%c" , seq, seq, seq) + std::to_string (std::stoi (m[0 ]) * 2 - 2 ) + m[1 ];
434
442
}
435
443
444
+ if (match (m, std::regex (format (" unet%cadd_embedding%clinear_(\\ d+)(.*)" , seq, seq)), key)) {
445
+ return format (" model%cdiffusion_model%clabel_emb%c0%c" , seq, seq, seq, seq) + std::to_string (std::stoi (m[0 ]) * 2 - 2 ) + m[1 ];
446
+ }
447
+
436
448
if (match (m, std::regex (format (" unet%cdown_blocks%c(\\ d+)%c(attentions|resnets)%c(\\ d+)%c(.+)" , seq, seq, seq, seq, seq)), key)) {
437
449
std::string suffix = get_converted_suffix (m[1 ], m[3 ]);
438
450
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
@@ -470,6 +482,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
470
482
return format (" cond_stage_model%ctransformer%ctext_model" , seq, seq) + m[0 ];
471
483
}
472
484
485
+ // clip-g
486
+ if (match (m, std::regex (format (" te%c1%ctext_model%cencoder%clayers%c(\\ d+)%c(.+)" , seq, seq, seq, seq, seq, seq)), key)) {
487
+ return format (" cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c" , seq, seq, seq, seq, seq, seq) + m[0 ] + seq + m[1 ];
488
+ }
489
+
490
+ if (match (m, std::regex (format (" te%c1%ctext_model(.*)" , seq, seq)), key)) {
491
+ return format (" cond_stage_model%c1%ctransformer%ctext_model" , seq, seq, seq) + m[0 ];
492
+ }
493
+
494
+ if (match (m, std::regex (format (" te%c1%ctext_projection" , seq, seq)), key)) {
495
+ return format (" cond_stage_model%c1%ctransformer%ctext_model%ctext_projection" , seq, seq, seq, seq);
496
+ }
497
+
473
498
// vae
474
499
if (match (m, std::regex (format (" vae%c(.*)%cconv_norm_out(.*)" , seq, seq)), key)) {
475
500
return format (" first_stage_model%c%s%cnorm_out%s" , seq, m[0 ].c_str (), seq, m[1 ].c_str ());
@@ -606,6 +631,8 @@ std::string convert_tensor_name(std::string name) {
606
631
std::string new_key = convert_diffusers_name_to_compvis (name_without_network_parts, ' .' );
607
632
if (new_key.empty ()) {
608
633
new_name = name;
634
+ } else if (new_key == " cond_stage_model.1.transformer.text_model.text_projection" ) {
635
+ new_name = new_key;
609
636
} else {
610
637
new_name = new_key + " ." + network_part;
611
638
}
@@ -1029,10 +1056,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
1029
1056
ttype = GGML_TYPE_F32;
1030
1057
} else if (dtype == " F32" ) {
1031
1058
ttype = GGML_TYPE_F32;
1059
+ } else if (dtype == " F64" ) {
1060
+ ttype = GGML_TYPE_F64;
1032
1061
} else if (dtype == " F8_E4M3" ) {
1033
1062
ttype = GGML_TYPE_F16;
1034
1063
} else if (dtype == " F8_E5M2" ) {
1035
1064
ttype = GGML_TYPE_F16;
1065
+ } else if (dtype == " I64" ) {
1066
+ ttype = GGML_TYPE_I64;
1036
1067
}
1037
1068
return ttype;
1038
1069
}
@@ -1045,6 +1076,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
1045
1076
std::ifstream file (file_path, std::ios::binary);
1046
1077
if (!file.is_open ()) {
1047
1078
LOG_ERROR (" failed to open '%s'" , file_path.c_str ());
1079
+ file_paths_.pop_back ();
1048
1080
return false ;
1049
1081
}
1050
1082
@@ -1056,6 +1088,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
1056
1088
// read header size
1057
1089
if (file_size_ <= ST_HEADER_SIZE_LEN) {
1058
1090
LOG_ERROR (" invalid safetensor file '%s'" , file_path.c_str ());
1091
+ file_paths_.pop_back ();
1059
1092
return false ;
1060
1093
}
1061
1094
@@ -1069,6 +1102,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
1069
1102
size_t header_size_ = read_u64 (header_size_buf);
1070
1103
if (header_size_ >= file_size_) {
1071
1104
LOG_ERROR (" invalid safetensor file '%s'" , file_path.c_str ());
1105
+ file_paths_.pop_back ();
1072
1106
return false ;
1073
1107
}
1074
1108
@@ -1079,6 +1113,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
1079
1113
file.read (header_buf.data (), header_size_);
1080
1114
if (!file) {
1081
1115
LOG_ERROR (" read safetensors header failed: '%s'" , file_path.c_str ());
1116
+ file_paths_.pop_back ();
1082
1117
return false ;
1083
1118
}
1084
1119
@@ -1134,6 +1169,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
1134
1169
n_dims = 1 ;
1135
1170
}
1136
1171
1172
+
1137
1173
TensorStorage tensor_storage (prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
1138
1174
tensor_storage.reverse_ne ();
1139
1175
@@ -1166,18 +1202,45 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
1166
1202
/* ================================================= DiffusersModelLoader ==================================================*/
1167
1203
1168
1204
bool ModelLoader::init_from_diffusers_file (const std::string& file_path, const std::string& prefix) {
1169
- std::string unet_path = path_join (file_path, " unet/diffusion_pytorch_model.safetensors" );
1170
- std::string vae_path = path_join (file_path, " vae/diffusion_pytorch_model.safetensors" );
1171
- std::string clip_path = path_join (file_path, " text_encoder/model.safetensors" );
1205
+ std::string unet_path = path_join (file_path, " unet/diffusion_pytorch_model.safetensors" );
1206
+ std::string vae_path = path_join (file_path, " vae/diffusion_pytorch_model.safetensors" );
1207
+ std::string clip_path = path_join (file_path, " text_encoder/model.safetensors" );
1208
+ std::string clip_g_path = path_join (file_path, " text_encoder_2/model.safetensors" );
1172
1209
1173
1210
if (!init_from_safetensors_file (unet_path, " unet." )) {
1174
1211
return false ;
1175
1212
}
1213
+ for (auto ts : tensor_storages) {
1214
+ if (ts.name .find (" add_embedding" ) != std::string::npos || ts.name .find (" label_emb" ) != std::string::npos) {
1215
+ // probably SDXL
1216
+ LOG_DEBUG (" Fixing name for SDXL output blocks.2.2" );
1217
+ for (auto & tensor_storage : tensor_storages) {
1218
+ int len = 34 ;
1219
+ auto pos = tensor_storage.name .find (" unet.up_blocks.0.upsamplers.0.conv" );
1220
+ if (pos == std::string::npos) {
1221
+ len = 44 ;
1222
+ pos = tensor_storage.name .find (" model.diffusion_model.output_blocks.2.1.conv" );
1223
+ }
1224
+ if (pos != std::string::npos) {
1225
+ tensor_storage.name = " model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name .substr (len);
1226
+ LOG_DEBUG (" NEW NAME: %s" , tensor_storage.name .c_str ());
1227
+ add_preprocess_tensor_storage_types (tensor_storages_types, tensor_storage.name , tensor_storage.type );
1228
+ }
1229
+ }
1230
+ break ;
1231
+ }
1232
+ }
1233
+
1176
1234
if (!init_from_safetensors_file (vae_path, " vae." )) {
1177
- return false ;
1235
+ LOG_WARN (" Couldn't find working VAE in %s" , file_path.c_str ());
1236
+ // return false;
1178
1237
}
1179
1238
if (!init_from_safetensors_file (clip_path, " te." )) {
1180
- return false ;
1239
+ LOG_WARN (" Couldn't find working text encoder in %s" , file_path.c_str ());
1240
+ // return false;
1241
+ }
1242
+ if (!init_from_safetensors_file (clip_g_path, " te.1." )) {
1243
+ LOG_DEBUG (" Couldn't find working second text encoder in %s" , file_path.c_str ());
1181
1244
}
1182
1245
return true ;
1183
1246
}
@@ -1571,7 +1634,7 @@ SDVersion ModelLoader::get_sd_version() {
1571
1634
if (tensor_storage.name .find (" model.diffusion_model.joint_blocks." ) != std::string::npos) {
1572
1635
return VERSION_SD3;
1573
1636
}
1574
- if (tensor_storage.name .find (" model.diffusion_model.input_blocks." ) != std::string::npos) {
1637
+ if (tensor_storage.name .find (" model.diffusion_model.input_blocks." ) != std::string::npos || tensor_storage. name . find ( " unet.down_blocks. " ) != std::string::npos ) {
1575
1638
is_unet = true ;
1576
1639
if (has_multiple_encoders) {
1577
1640
is_xl = true ;
@@ -1580,7 +1643,7 @@ SDVersion ModelLoader::get_sd_version() {
1580
1643
}
1581
1644
}
1582
1645
}
1583
- if (tensor_storage.name .find (" conditioner.embedders.1" ) != std::string::npos || tensor_storage.name .find (" cond_stage_model.1" ) != std::string::npos) {
1646
+ if (tensor_storage.name .find (" conditioner.embedders.1" ) != std::string::npos || tensor_storage.name .find (" cond_stage_model.1" ) != std::string::npos || tensor_storage. name . find ( " te.1 " ) != std::string::npos ) {
1584
1647
has_multiple_encoders = true ;
1585
1648
if (is_unet) {
1586
1649
is_xl = true ;
@@ -1602,7 +1665,7 @@ SDVersion ModelLoader::get_sd_version() {
1602
1665
token_embedding_weight = tensor_storage;
1603
1666
// break;
1604
1667
}
1605
- if (tensor_storage.name == " model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == " model.diffusion_model.img_in.weight" ) {
1668
+ if (tensor_storage.name == " model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == " model.diffusion_model.img_in.weight" || tensor_storage. name == " unet.conv_in.weight " ) {
1606
1669
input_block_weight = tensor_storage;
1607
1670
input_block_checked = true ;
1608
1671
if (found_family) {
@@ -1687,7 +1750,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
1687
1750
continue ;
1688
1751
}
1689
1752
1690
- if (tensor_storage.name .find (" model.diffusion_model." ) == std::string::npos) {
1753
+ if (tensor_storage.name .find (" model.diffusion_model." ) == std::string::npos && tensor_storage. name . find ( " unet. " ) == std::string::npos ) {
1691
1754
continue ;
1692
1755
}
1693
1756
0 commit comments