Skip to content

Commit 8004747

Browse files
committed
Added selectable model types (F5, E2) and an E2-TTS model loader.
#9
1 parent 88d4ca8 commit 8004747

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

F5TTS.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from comfy.utils import ProgressBar
1818
from cached_path import cached_path
1919
sys.path.append(Install.f5TTSPath)
20-
from model import DiT # noqa E402
20+
from model import DiT,UNetT # noqa E402
2121
from model.utils_infer import ( # noqa E402
2222
load_model,
2323
preprocess_ref_audio_text,
@@ -28,6 +28,7 @@
2828

2929
class F5TTSCreate:
3030
voice_reg = re.compile(r"\{(\w+)\}")
31+
model_types = ["F5", "E2"]
3132
tooltip_seed = "Seed. -1 = random"
3233

3334
def is_voice_name(self, word):
@@ -54,7 +55,33 @@ def load_voice(ref_audio, ref_text):
5455
)
5556
return main_voice
5657

57-
def load_model(self):
58+
def load_model(self, model):
59+
models = {
60+
"F5": self.load_f5_model,
61+
"E2": self.load_e2_model,
62+
}
63+
return models[model]()
64+
65+
def get_vocab_file(self):
66+
return os.path.join(
67+
Install.f5TTSPath, "data/Emilia_ZH_EN_pinyin/vocab.txt"
68+
)
69+
70+
def load_e2_model(self):
71+
model_cls = UNetT
72+
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
73+
repo_name = "E2-TTS"
74+
exp_name = "E2TTS_Base"
75+
ckpt_step = 1200000
76+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # noqa E501
77+
vocab_file = self.get_vocab_file()
78+
ema_model = load_model(
79+
model_cls, model_cfg,
80+
ckpt_file, vocab_file
81+
)
82+
return ema_model
83+
84+
def load_f5_model(self):
5885
model_cls = DiT
5986
model_cfg = dict(
6087
dim=1024, depth=22, heads=16,
@@ -64,10 +91,11 @@ def load_model(self):
6491
exp_name = "F5TTS_Base"
6592
ckpt_step = 1200000
6693
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # noqa E501
67-
vocab_file = os.path.join(
68-
Install.f5TTSPath, "data/Emilia_ZH_EN_pinyin/vocab.txt"
94+
vocab_file = self.get_vocab_file()
95+
ema_model = load_model(
96+
model_cls, model_cfg,
97+
ckpt_file, vocab_file
6998
)
70-
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
7199
return ema_model
72100

73101
def generate_audio(self, voices, model_obj, chunks, seed):
@@ -117,8 +145,8 @@ def generate_audio(self, voices, model_obj, chunks, seed):
117145
os.unlink(wave_file.name)
118146
return audio
119147

120-
def create(self, voices, chunks, seed=-1):
121-
model_obj = self.load_model()
148+
def create(self, voices, chunks, seed=-1, model="F5"):
149+
model_obj = self.load_model(model)
122150
return self.generate_audio(voices, model_obj, chunks, seed)
123151

124152

@@ -141,6 +169,7 @@ def INPUT_TYPES(s):
141169
"default": 1, "min": -1,
142170
"tooltip": F5TTSCreate.tooltip_seed,
143171
}),
172+
"model": (F5TTSCreate.model_types,),
144173
},
145174
}
146175

@@ -174,7 +203,7 @@ def remove_wave_file(self):
174203
print("F5TTS: Cannot remove? "+self.wave_file.name)
175204
print(e)
176205

177-
def create(self, sample_audio, sample_text, speech, seed=-1):
206+
def create(self, sample_audio, sample_text, speech, seed=-1, model="F5"):
178207
try:
179208
main_voice = self.load_voice_from_input(sample_audio, sample_text)
180209

@@ -184,7 +213,7 @@ def create(self, sample_audio, sample_text, speech, seed=-1):
184213
chunks = f5ttsCreate.split_text(speech)
185214
voices['main'] = main_voice
186215

187-
audio = f5ttsCreate.create(voices, chunks, seed)
216+
audio = f5ttsCreate.create(voices, chunks, seed, model)
188217
finally:
189218
self.remove_wave_file()
190219
return (audio, )
@@ -233,6 +262,7 @@ def INPUT_TYPES(s):
233262
"default": 1, "min": -1,
234263
"tooltip": F5TTSCreate.tooltip_seed,
235264
}),
265+
"model": (F5TTSCreate.model_types,),
236266
}
237267
}
238268

@@ -289,7 +319,7 @@ def load_voices_from_files(self, sample, voice_names):
289319
voices[voice_name] = self.load_voice_from_file(sample_file)
290320
return voices
291321

292-
def create(self, sample, speech, seed=-1):
322+
def create(self, sample, speech, seed=-1, model="F5"):
293323
# Install.check_install()
294324
main_voice = self.load_voice_from_file(sample)
295325

@@ -309,7 +339,7 @@ def create(self, sample, speech, seed=-1):
309339
voices = self.load_voices_from_files(sample, voice_names)
310340
voices['main'] = main_voice
311341

312-
audio = f5ttsCreate.create(voices, chunks, seed)
342+
audio = f5ttsCreate.create(voices, chunks, seed, model)
313343
return (audio, )
314344

315345
@classmethod

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-f5-tts"
33
description = "Text to speech with F5-TTS"
4-
version = "1.0.4"
4+
version = "1.0.5"
55
license = {text = "MIT License"}
66

77
[project.urls]

0 commit comments

Comments
 (0)