Skip to content

Commit 3bca21f

Browse files
committed
Updated to work with latest version of F5-TTS#6a3659dbf8332a08819e249755517377b52d0bc1
1 parent d39b0c5 commit 3bca21f

File tree

4 files changed

+157
-41
lines changed

4 files changed

+157
-41
lines changed

F5-TTS

Submodule F5-TTS updated 73 files

F5TTS.py

Lines changed: 155 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import io
1717
from comfy.utils import ProgressBar
1818
from cached_path import cached_path
19-
sys.path.append(Install.f5TTSPath)
20-
from model import DiT,UNetT # noqa E402
21-
from model.utils_infer import ( # noqa E402
19+
sys.path.append(os.path.join(Install.f5TTSPath, "src"))
20+
from f5_tts.model import DiT,UNetT # noqa E402
21+
from f5_tts.infer.utils_infer import ( # noqa E402
2222
load_model,
23+
load_vocoder,
2324
preprocess_ref_audio_text,
2425
infer_process,
2526
)
@@ -28,9 +29,37 @@
2829

2930
class F5TTSCreate:
3031
voice_reg = re.compile(r"\{([^\}]+)\}")
31-
model_types = ["F5", "E2"]
32+
model_types = ["F5", "F5-JP", "F5-FR", "E2"]
33+
vocoder_types = ["vocos", "bigvgan"]
3234
tooltip_seed = "Seed. -1 = random"
3335

36+
def get_model_types():
37+
model_types = F5TTSCreate.model_types[:]
38+
models_path = folder_paths.get_folder_paths("checkpoints")
39+
for model_path in models_path:
40+
f5_model_path = os.path.join(model_path, 'F5-TTS')
41+
if os.path.isdir(f5_model_path):
42+
for file in os.listdir(f5_model_path):
43+
p = Path(file)
44+
if (
45+
p.suffix in folder_paths.supported_pt_extensions
46+
and os.path.isfile(os.path.join(f5_model_path, file))
47+
):
48+
txtFile = F5TTSCreate.get_txt_file_path(
49+
os.path.join(f5_model_path, file)
50+
)
51+
52+
if (
53+
os.path.isfile(txtFile)
54+
):
55+
model_types.append("model://"+file)
56+
return model_types
57+
58+
@staticmethod
59+
def get_txt_file_path(file):
60+
p = Path(file)
61+
return os.path.join(os.path.dirname(file), p.stem + ".txt")
62+
3463
def is_voice_name(self, word):
3564
return self.voice_reg.match(word.strip())
3665

@@ -55,50 +84,118 @@ def load_voice(ref_audio, ref_text):
5584
)
5685
return main_voice
5786

58-
def load_model(self, model):
59-
models = {
87+
def get_model_funcs(self):
88+
return {
6089
"F5": self.load_f5_model,
90+
"F5-JP": self.load_f5_model_jp,
91+
"F5-FR": self.load_f5_model_fr,
6192
"E2": self.load_e2_model,
6293
}
63-
return models[model]()
94+
95+
def get_vocoder(self, vocoder_name):
96+
if vocoder_name == "vocos":
97+
os.path.join(Install.f5TTSPath, "checkpoints/vocos-mel-24khz")
98+
elif vocoder_name == "bigvgan":
99+
os.path.join(Install.f5TTSPath, "checkpoints/bigvgan_v2_24khz_100band_256x") # noqa E501
100+
101+
def load_vocoder(self, vocoder_name):
102+
return load_vocoder(vocoder_name=vocoder_name)
103+
104+
def load_model(self, model, vocoder_name):
105+
model_funcs = self.get_model_funcs()
106+
if model in model_funcs:
107+
return model_funcs[model](vocoder_name)
108+
else:
109+
return self.load_f5_model_url(model, vocoder_name)
64110

65111
def get_vocab_file(self):
66112
return os.path.join(
67113
Install.f5TTSPath, "data/Emilia_ZH_EN_pinyin/vocab.txt"
68114
)
69115

70-
def load_e2_model(self):
116+
def load_e2_model(self, vocoder):
71117
model_cls = UNetT
72118
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
73119
repo_name = "E2-TTS"
74120
exp_name = "E2TTS_Base"
75121
ckpt_step = 1200000
76122
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # noqa E501
77123
vocab_file = self.get_vocab_file()
124+
vocoder_name = "vocos"
78125
ema_model = load_model(
79126
model_cls, model_cfg,
80-
ckpt_file, vocab_file
127+
ckpt_file, vocab_file=vocab_file,
128+
mel_spec_type=vocoder_name,
129+
)
130+
vocoder = self.load_vocoder(vocoder_name)
131+
return (ema_model, vocoder, vocoder_name)
132+
133+
def load_f5_model(self, vocoder):
134+
repo_name = "F5-TTS"
135+
if vocoder == "bigvgan":
136+
exp_name = "F5TTS_Base_bigvgan"
137+
ckpt_step = 1250000
138+
else:
139+
exp_name = "F5TTS_Base"
140+
ckpt_step = 1200000
141+
return self.load_f5_model_url(
142+
f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors", # noqa E501
143+
vocoder,
144+
)
145+
146+
def load_f5_model_jp(self, vocoder):
147+
return self.load_f5_model_url(
148+
"hf://Jmica/F5TTS/JA_8500000/model_8499660.pt",
149+
vocoder,
150+
"hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt"
81151
)
82-
return ema_model
83152

84-
def load_f5_model(self):
153+
def load_f5_model_fr(self, vocoder):
154+
return self.load_f5_model_url(
155+
"hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_1374000.pt", # noqa E501
156+
vocoder,
157+
"hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt" # noqa E501
158+
)
159+
160+
def cached_path(self, url):
161+
if url.startswith("model:"):
162+
path = re.sub("^model:/*", "", url)
163+
models_path = folder_paths.get_folder_paths("checkpoints")
164+
for model_path in models_path:
165+
f5_model_path = os.path.join(model_path, 'F5-TTS')
166+
model_file = os.path.join(f5_model_path, path)
167+
if os.path.isfile(model_file):
168+
return model_file
169+
raise FileNotFoundError("No model found: " + url)
170+
return None
171+
return str(cached_path(url)) # noqa E501
172+
173+
def load_f5_model_url(self, url, vocoder_name, vocab_url=None):
174+
vocoder = self.load_vocoder(vocoder_name)
85175
model_cls = DiT
86176
model_cfg = dict(
87177
dim=1024, depth=22, heads=16,
88178
ff_mult=2, text_dim=512, conv_layers=4
89179
)
90-
repo_name = "F5-TTS"
91-
exp_name = "F5TTS_Base"
92-
ckpt_step = 1200000
93-
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # noqa E501
94-
vocab_file = self.get_vocab_file()
180+
ckpt_file = str(self.cached_path(url)) # noqa E501
181+
182+
if vocab_url is None:
183+
if url.startswith("model:"):
184+
vocab_file = F5TTSCreate.get_txt_file_path(ckpt_file)
185+
else:
186+
vocab_file = self.get_vocab_file()
187+
else:
188+
vocab_file = str(self.cached_path(vocab_url))
95189
ema_model = load_model(
96190
model_cls, model_cfg,
97-
ckpt_file, vocab_file
191+
ckpt_file, vocab_file=vocab_file,
192+
mel_spec_type=vocoder_name,
98193
)
99-
return ema_model
194+
return (ema_model, vocoder, vocoder_name)
100195

101-
def generate_audio(self, voices, model_obj, chunks, seed):
196+
def generate_audio(
197+
self, voices, model_obj, chunks, seed, vocoder, mel_spec_type
198+
):
102199
if seed >= 0:
103200
torch.manual_seed(seed)
104201
else:
@@ -127,7 +224,8 @@ def generate_audio(self, voices, model_obj, chunks, seed):
127224
print(f"Voice: {voice}")
128225
print("text:"+text)
129226
audio, final_sample_rate, spectragram = infer_process(
130-
ref_audio, ref_text, gen_text, model_obj
227+
ref_audio, ref_text, gen_text, model_obj,
228+
vocoder=vocoder, mel_spec_type=mel_spec_type
131229
)
132230
generated_audio_segments.append(audio)
133231
frame_rate = final_sample_rate
@@ -147,9 +245,20 @@ def generate_audio(self, voices, model_obj, chunks, seed):
147245
os.unlink(wave_file.name)
148246
return audio
149247

150-
def create(self, voices, chunks, seed=-1, model="F5"):
151-
model_obj = self.load_model(model)
152-
return self.generate_audio(voices, model_obj, chunks, seed)
248+
def create(
249+
self, voices, chunks, seed=-1, model="F5", vocoder_name="vocos"
250+
):
251+
(
252+
model_obj,
253+
vocoder,
254+
mel_spec_type
255+
) = self.load_model(model, vocoder_name)
256+
return self.generate_audio(
257+
voices,
258+
model_obj,
259+
chunks, seed,
260+
vocoder, mel_spec_type=mel_spec_type,
261+
)
153262

154263

155264
class F5TTSAudioInputs:
@@ -158,6 +267,7 @@ def __init__(self):
158267

159268
@classmethod
160269
def INPUT_TYPES(s):
270+
model_types = F5TTSCreate.get_model_types()
161271
return {
162272
"required": {
163273
"sample_audio": ("AUDIO",),
@@ -171,7 +281,8 @@ def INPUT_TYPES(s):
171281
"default": 1, "min": -1,
172282
"tooltip": F5TTSCreate.tooltip_seed,
173283
}),
174-
"model": (F5TTSCreate.model_types,),
284+
"model": (model_types,),
285+
# "vocoder": (F5TTSCreate.vocoder_types,),
175286
},
176287
}
177288

@@ -213,7 +324,10 @@ def remove_wave_file(self):
213324
print("F5TTS: Cannot remove? "+self.wave_file_name)
214325
print(e)
215326

216-
def create(self, sample_audio, sample_text, speech, seed=-1, model="F5"):
327+
def create(
328+
self, sample_audio, sample_text, speech, seed=-1, model="F5"
329+
):
330+
vocoder = "vocos"
217331
try:
218332
main_voice = self.load_voice_from_input(sample_audio, sample_text)
219333

@@ -223,7 +337,9 @@ def create(self, sample_audio, sample_text, speech, seed=-1, model="F5"):
223337
chunks = f5ttsCreate.split_text(speech)
224338
voices['main'] = main_voice
225339

226-
audio = f5ttsCreate.create(voices, chunks, seed, model)
340+
audio = f5ttsCreate.create(
341+
voices, chunks, seed, model, vocoder
342+
)
227343
finally:
228344
self.remove_wave_file()
229345
return (audio, )
@@ -243,11 +359,6 @@ class F5TTSAudio:
243359
def __init__(self):
244360
self.use_cli = False
245361

246-
@staticmethod
247-
def get_txt_file_path(file):
248-
p = Path(file)
249-
return os.path.join(os.path.dirname(file), p.stem + ".txt")
250-
251362
@classmethod
252363
def INPUT_TYPES(s):
253364
input_dir = folder_paths.get_input_directory()
@@ -256,11 +367,13 @@ def INPUT_TYPES(s):
256367
)
257368
filesWithTxt = []
258369
for file in files:
259-
txtFile = F5TTSAudio.get_txt_file_path(file)
370+
txtFile = F5TTSCreate.get_txt_file_path(file)
260371
if os.path.isfile(os.path.join(input_dir, txtFile)):
261372
filesWithTxt.append(file)
262373
filesWithTxt = sorted(filesWithTxt)
263374

375+
model_types = F5TTSCreate.get_model_types()
376+
264377
return {
265378
"required": {
266379
"sample": (filesWithTxt, {"audio_upload": True}),
@@ -273,7 +386,8 @@ def INPUT_TYPES(s):
273386
"default": 1, "min": -1,
274387
"tooltip": F5TTSCreate.tooltip_seed,
275388
}),
276-
"model": (F5TTSCreate.model_types,),
389+
"model": (model_types,),
390+
# "vocoder": (F5TTSCreate.vocoder_types,),
277391
}
278392
}
279393

@@ -304,12 +418,14 @@ def load_voice_from_file(self, sample):
304418
input_dir = folder_paths.get_input_directory()
305419
txt_file = os.path.join(
306420
input_dir,
307-
F5TTSAudio.get_txt_file_path(sample)
421+
F5TTSCreate.get_txt_file_path(sample)
308422
)
309423
audio_text = ''
310-
with open(txt_file, 'r') as file:
424+
with open(txt_file, 'r', encoding='utf-8') as file:
311425
audio_text = file.read()
312426
audio_path = folder_paths.get_annotated_filepath(sample)
427+
print("audio_text")
428+
print(audio_text)
313429
return F5TTSCreate.load_voice(audio_path, audio_text)
314430

315431
def load_voices_from_files(self, sample, voice_names):
@@ -330,7 +446,8 @@ def load_voices_from_files(self, sample, voice_names):
330446
voices[voice_name] = self.load_voice_from_file(sample_file)
331447
return voices
332448

333-
def create(self, sample, speech, seed=-1, model="F5"):
449+
def create(self, sample, speech, seed=-2, model="F5"):
450+
vocoder = "vocos"
334451
# Install.check_install()
335452
main_voice = self.load_voice_from_file(sample)
336453

@@ -350,14 +467,14 @@ def create(self, sample, speech, seed=-1, model="F5"):
350467
voices = self.load_voices_from_files(sample, voice_names)
351468
voices['main'] = main_voice
352469

353-
audio = f5ttsCreate.create(voices, chunks, seed, model)
470+
audio = f5ttsCreate.create(voices, chunks, seed, model, vocoder)
354471
return (audio, )
355472

356473
@classmethod
357474
def IS_CHANGED(s, sample, speech, seed, model):
358475
m = hashlib.sha256()
359476
audio_path = folder_paths.get_annotated_filepath(sample)
360-
audio_txt_path = F5TTSAudio.get_txt_file_path(audio_path)
477+
audio_txt_path = F5TTSCreate.get_txt_file_path(audio_path)
361478
last_modified_timestamp = os.path.getmtime(audio_path)
362479
txt_last_modified_timestamp = os.path.getmtime(audio_txt_path)
363480
m.update(audio_path)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-f5-tts"
33
description = "Text to speech with F5-TTS"
4-
version = "1.0.7"
4+
version = "1.0.8"
55
license = {text = "MIT License"}
66

77
[project.urls]

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,3 @@ transformers
2020
vocos
2121
wandb
2222
x_transformers>=1.31.14
23-

0 commit comments

Comments
 (0)