
Commit dafc32d

feat: add support for f64/i64 and clip_g diffusers model (#681)
1 parent 225162f commit dafc32d

File tree

1 file changed: 72 additions & 9 deletions

1 file changed

+72
-9
lines changed

model.cpp

Lines changed: 72 additions & 9 deletions
@@ -338,6 +338,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
             {"to_v", "v"},
             {"to_out_0", "proj_out"},
             {"group_norm", "norm"},
+            {"key", "k"},
+            {"query", "q"},
+            {"value", "v"},
+            {"proj_attn", "proj_out"},
         },
     },
     {
@@ -362,6 +366,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
             {"to_v", "v"},
             {"to_out.0", "proj_out"},
             {"group_norm", "norm"},
+            {"key", "k"},
+            {"query", "q"},
+            {"value", "v"},
+            {"proj_attn", "proj_out"},
         },
     },
     {
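The four new entries in both suffix tables (the underscore and dot variants) cover checkpoints exported with the legacy diffusers attention naming, where VAE attention projections are called query/key/value/proj_attn instead of to_q/to_k/to_v/to_out. A minimal sketch of the lookup these tables feed, with a simplified map; the names here are illustrative, not the file's actual variables:

#include <string>
#include <unordered_map>

// Illustrative only: a stripped-down version of the suffix lookup the tables above drive.
std::unordered_map<std::string, std::string> attn_suffix = {
    {"to_q", "q"}, {"to_k", "k"}, {"to_v", "v"}, {"to_out.0", "proj_out"},
    // new legacy-diffusers entries from this commit:
    {"query", "q"}, {"key", "k"}, {"value", "v"}, {"proj_attn", "proj_out"},
};

std::string convert_attn_suffix(const std::string& s) {
    auto it = attn_suffix.find(s);
    return it != attn_suffix.end() ? it->second : s;  // e.g. "query" -> "q"
}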
@@ -433,6 +441,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
         return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
     }
 
+    if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
+        return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
+    }
+
     if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
         std::string suffix = get_converted_suffix(m[1], m[3]);
         // LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
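The new rule routes SDXL's add_embedding (the extra conditioning MLP) onto the compvis label_emb block, reusing the same index arithmetic as the time_embed rule above: linear_N maps to Sequential index N * 2 - 2, with the activation sitting at index 1. A sketch of the expected mapping, derived from the regex and format string above, with seq == '.':

// Expected conversions (sketch):
convert_diffusers_name_to_compvis("unet.add_embedding.linear_1.weight", '.');
// -> "model.diffusion_model.label_emb.0.0.weight"   (stoi("1") * 2 - 2 == 0)
convert_diffusers_name_to_compvis("unet.add_embedding.linear_2.bias", '.');
// -> "model.diffusion_model.label_emb.0.2.bias"     (stoi("2") * 2 - 2 == 2)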
@@ -470,6 +482,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
         return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
     }
 
+    // clip-g
+    if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
+        return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1];
+    }
+
+    if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) {
+        return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0];
+    }
+
+    if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) {
+        return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq);
+    }
+
     // vae
     if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
         return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
@@ -606,6 +631,8 @@ std::string convert_tensor_name(std::string name) {
         std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
         if (new_key.empty()) {
             new_name = name;
+        } else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") {
+            new_name = new_key;
         } else {
             new_name = new_key + "." + network_part;
         }
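The extra branch keeps the converted text_projection key as-is instead of re-appending the network part (weight/bias) that was split off the original name. A sketch of the effect, assuming a hypothetical input name of "te.1.text_projection.weight":

std::string network_part = "weight";  // hypothetical split of "te.1.text_projection.weight"
std::string new_key = convert_diffusers_name_to_compvis("te.1.text_projection", '.');
// new_key == "cond_stage_model.1.transformer.text_model.text_projection"
std::string new_name = new_key;  // special case above: the ".weight" suffix is deliberately not re-added
// any other non-empty key: new_name = new_key + "." + network_part;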
@@ -1029,10 +1056,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F64") {
+        ttype = GGML_TYPE_F64;
     } else if (dtype == "F8_E4M3") {
         ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E5M2") {
         ttype = GGML_TYPE_F16;
+    } else if (dtype == "I64") {
+        ttype = GGML_TYPE_I64;
     }
     return ttype;
 }
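With the two new branches, 64-bit safetensors dtypes resolve to the matching ggml types instead of falling through unchanged. For instance:

ggml_type a = str_to_ggml_type("F64");      // GGML_TYPE_F64 (new)
ggml_type b = str_to_ggml_type("I64");      // GGML_TYPE_I64 (new)
ggml_type c = str_to_ggml_type("F8_E4M3");  // still widened to GGML_TYPE_F16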
@@ -1045,6 +1076,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
     std::ifstream file(file_path, std::ios::binary);
     if (!file.is_open()) {
         LOG_ERROR("failed to open '%s'", file_path.c_str());
+        file_paths_.pop_back();
         return false;
     }
 
@@ -1056,6 +1088,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
     // read header size
     if (file_size_ <= ST_HEADER_SIZE_LEN) {
         LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
+        file_paths_.pop_back();
         return false;
     }
 
@@ -1069,6 +1102,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
     size_t header_size_ = read_u64(header_size_buf);
     if (header_size_ >= file_size_) {
         LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
+        file_paths_.pop_back();
         return false;
     }
 
@@ -1079,6 +1113,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
     file.read(header_buf.data(), header_size_);
     if (!file) {
         LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str());
+        file_paths_.pop_back();
         return false;
     }
 
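The four error paths above get the same one-line cleanup: the file path is presumably appended to file_paths_ earlier in init_from_safetensors_file, so popping it on failure keeps the list consistent with the tensors that were actually registered, which matters now that init_from_diffusers_file tolerates missing files instead of aborting. A simplified sketch of the pattern, not the function's actual body:

// Sketch: register the file first, roll back on any early error return.
file_paths_.push_back(file_path);
size_t file_index = file_paths_.size() - 1;

std::ifstream file(file_path, std::ios::binary);
if (!file.is_open()) {
    LOG_ERROR("failed to open '%s'", file_path.c_str());
    file_paths_.pop_back();  // keep file_paths_ in sync with registered tensors
    return false;
}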
@@ -1134,6 +1169,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             n_dims = 1;
         }
 
+
         TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
         tensor_storage.reverse_ne();
 
@@ -1166,18 +1202,45 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
 /*================================================= DiffusersModelLoader ==================================================*/
 
 bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
-    std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
-    std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
-    std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
+    std::string unet_path   = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
+    std::string vae_path    = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
+    std::string clip_path   = path_join(file_path, "text_encoder/model.safetensors");
+    std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
 
     if (!init_from_safetensors_file(unet_path, "unet.")) {
         return false;
     }
+    for (auto ts : tensor_storages) {
+        if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) {
+            // probably SDXL
+            LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
+            for (auto& tensor_storage : tensor_storages) {
+                int len = 34;
+                auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
+                if (pos == std::string::npos) {
+                    len = 44;
+                    pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
+                }
+                if (pos != std::string::npos) {
+                    tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len);
+                    LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
+                    add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+                }
+            }
+            break;
+        }
+    }
+
     if (!init_from_safetensors_file(vae_path, "vae.")) {
-        return false;
+        LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
+        // return false;
     }
     if (!init_from_safetensors_file(clip_path, "te.")) {
-        return false;
+        LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
+        // return false;
+    }
+    if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
+        LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
     }
     return true;
 }
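Taken together, the loader now probes a diffusers-style folder for four components and only hard-fails when the UNet is missing; a missing VAE or first text encoder degrades to a warning, and the new text_encoder_2 (CLIP-G) is loaded under the te.1. prefix so the conversion rules added earlier apply. A sketch of the expected layout, taken from the paths in this hunk:

// <model_dir>/
//   unet/diffusion_pytorch_model.safetensors   -> prefix "unet."  (required)
//   vae/diffusion_pytorch_model.safetensors    -> prefix "vae."   (warns if missing)
//   text_encoder/model.safetensors             -> prefix "te."    (warns if missing)
//   text_encoder_2/model.safetensors           -> prefix "te.1."  (debug log if missing, SDXL CLIP-G)
ModelLoader loader;
bool ok = loader.init_from_diffusers_file("<model_dir>", "");  // hypothetical call, matching the signature above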
@@ -1571,7 +1634,7 @@ SDVersion ModelLoader::get_sd_version() {
         if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
             return VERSION_SD3;
         }
-        if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
+        if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
             is_unet = true;
             if (has_multiple_encoders) {
                 is_xl = true;
@@ -1580,7 +1643,7 @@ SDVersion ModelLoader::get_sd_version() {
                 }
             }
         }
-        if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
+        if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
             has_multiple_encoders = true;
             if (is_unet) {
                 is_xl = true;
@@ -1602,7 +1665,7 @@ SDVersion ModelLoader::get_sd_version() {
             token_embedding_weight = tensor_storage;
             // break;
         }
-        if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
+        if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
            input_block_weight = tensor_storage;
            input_block_checked = true;
            if (found_family) {
@@ -1687,7 +1750,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) {
+        if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos && tensor_storage.name.find("unet.") == std::string::npos) {
             continue;
         }
 
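The last four hunks widen the tensor-name checks so that unconverted diffusers prefixes (unet., te.1) also drive model detection and weight-type selection, which is what lets an SDXL diffusers folder be recognized as XL. A minimal standalone sketch of that detection logic, with illustrative tensor names (not the function's actual structure):

#include <string>
#include <vector>

bool looks_like_sdxl(const std::vector<std::string>& names) {
    bool is_unet = false, has_multiple_encoders = false;
    for (const std::string& name : names) {
        if (name.find("model.diffusion_model.input_blocks.") != std::string::npos ||
            name.find("unet.down_blocks.") != std::string::npos)   // widened check
            is_unet = true;
        if (name.find("conditioner.embedders.1") != std::string::npos ||
            name.find("cond_stage_model.1") != std::string::npos ||
            name.find("te.1") != std::string::npos)                 // widened check
            has_multiple_encoders = true;
    }
    return is_unet && has_multiple_encoders;
}

// looks_like_sdxl({"unet.down_blocks.0.resnets.0.conv1.weight",
//                  "te.1.text_model.embeddings.token_embedding.weight"}) -> true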
0 commit comments

Comments
 (0)