16
16
import io
17
17
from comfy .utils import ProgressBar
18
18
from cached_path import cached_path
19
- sys .path .append (Install .f5TTSPath )
20
- from model import DiT ,UNetT # noqa E402
21
- from model .utils_infer import ( # noqa E402
19
+ sys .path .append (os . path . join ( Install .f5TTSPath , "src" ) )
20
+ from f5_tts . model import DiT ,UNetT # noqa E402
21
+ from f5_tts . infer .utils_infer import ( # noqa E402
22
22
load_model ,
23
+ load_vocoder ,
23
24
preprocess_ref_audio_text ,
24
25
infer_process ,
25
26
)
28
29
29
30
class F5TTSCreate :
30
31
voice_reg = re .compile (r"\{([^\}]+)\}" )
31
- model_types = ["F5" , "E2" ]
32
+ model_types = ["F5" , "F5-JP" , "F5-FR" , "E2" ]
33
+ vocoder_types = ["vocos" , "bigvgan" ]
32
34
tooltip_seed = "Seed. -1 = random"
33
35
36
+ def get_model_types ():
37
+ model_types = F5TTSCreate .model_types [:]
38
+ models_path = folder_paths .get_folder_paths ("checkpoints" )
39
+ for model_path in models_path :
40
+ f5_model_path = os .path .join (model_path , 'F5-TTS' )
41
+ if os .path .isdir (f5_model_path ):
42
+ for file in os .listdir (f5_model_path ):
43
+ p = Path (file )
44
+ if (
45
+ p .suffix in folder_paths .supported_pt_extensions
46
+ and os .path .isfile (os .path .join (f5_model_path , file ))
47
+ ):
48
+ txtFile = F5TTSCreate .get_txt_file_path (
49
+ os .path .join (f5_model_path , file )
50
+ )
51
+
52
+ if (
53
+ os .path .isfile (txtFile )
54
+ ):
55
+ model_types .append ("model://" + file )
56
+ return model_types
57
+
58
+ @staticmethod
59
+ def get_txt_file_path (file ):
60
+ p = Path (file )
61
+ return os .path .join (os .path .dirname (file ), p .stem + ".txt" )
62
+
34
63
def is_voice_name (self , word ):
35
64
return self .voice_reg .match (word .strip ())
36
65
@@ -55,50 +84,118 @@ def load_voice(ref_audio, ref_text):
55
84
)
56
85
return main_voice
57
86
58
- def load_model (self , model ):
59
- models = {
87
+ def get_model_funcs (self ):
88
+ return {
60
89
"F5" : self .load_f5_model ,
90
+ "F5-JP" : self .load_f5_model_jp ,
91
+ "F5-FR" : self .load_f5_model_fr ,
61
92
"E2" : self .load_e2_model ,
62
93
}
63
- return models [model ]()
94
+
95
+ def get_vocoder (self , vocoder_name ):
96
+ if vocoder_name == "vocos" :
97
+ os .path .join (Install .f5TTSPath , "checkpoints/vocos-mel-24khz" )
98
+ elif vocoder_name == "bigvgan" :
99
+ os .path .join (Install .f5TTSPath , "checkpoints/bigvgan_v2_24khz_100band_256x" ) # noqa E501
100
+
101
+ def load_vocoder (self , vocoder_name ):
102
+ return load_vocoder (vocoder_name = vocoder_name )
103
+
104
+ def load_model (self , model , vocoder_name ):
105
+ model_funcs = self .get_model_funcs ()
106
+ if model in model_funcs :
107
+ return model_funcs [model ](vocoder_name )
108
+ else :
109
+ return self .load_f5_model_url (model , vocoder_name )
64
110
65
111
def get_vocab_file (self ):
66
112
return os .path .join (
67
113
Install .f5TTSPath , "data/Emilia_ZH_EN_pinyin/vocab.txt"
68
114
)
69
115
70
- def load_e2_model (self ):
116
+ def load_e2_model (self , vocoder ):
71
117
model_cls = UNetT
72
118
model_cfg = dict (dim = 1024 , depth = 24 , heads = 16 , ff_mult = 4 )
73
119
repo_name = "E2-TTS"
74
120
exp_name = "E2TTS_Base"
75
121
ckpt_step = 1200000
76
122
ckpt_file = str (cached_path (f"hf://SWivid/{ repo_name } /{ exp_name } /model_{ ckpt_step } .safetensors" )) # noqa E501
77
123
vocab_file = self .get_vocab_file ()
124
+ vocoder_name = "vocos"
78
125
ema_model = load_model (
79
126
model_cls , model_cfg ,
80
- ckpt_file , vocab_file
127
+ ckpt_file , vocab_file = vocab_file ,
128
+ mel_spec_type = vocoder_name ,
129
+ )
130
+ vocoder = self .load_vocoder (vocoder_name )
131
+ return (ema_model , vocoder , vocoder_name )
132
+
133
+ def load_f5_model (self , vocoder ):
134
+ repo_name = "F5-TTS"
135
+ if vocoder == "bigvgan" :
136
+ exp_name = "F5TTS_Base_bigvgan"
137
+ ckpt_step = 1250000
138
+ else :
139
+ exp_name = "F5TTS_Base"
140
+ ckpt_step = 1200000
141
+ return self .load_f5_model_url (
142
+ f"hf://SWivid/{ repo_name } /{ exp_name } /model_{ ckpt_step } .safetensors" , # noqa E501
143
+ vocoder ,
144
+ )
145
+
146
+ def load_f5_model_jp (self , vocoder ):
147
+ return self .load_f5_model_url (
148
+ "hf://Jmica/F5TTS/JA_8500000/model_8499660.pt" ,
149
+ vocoder ,
150
+ "hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt"
81
151
)
82
- return ema_model
83
152
84
- def load_f5_model (self ):
153
+ def load_f5_model_fr (self , vocoder ):
154
+ return self .load_f5_model_url (
155
+ "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_1374000.pt" , # noqa E501
156
+ vocoder ,
157
+ "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt" # noqa E501
158
+ )
159
+
160
+ def cached_path (self , url ):
161
+ if url .startswith ("model:" ):
162
+ path = re .sub ("^model:/*" , "" , url )
163
+ models_path = folder_paths .get_folder_paths ("checkpoints" )
164
+ for model_path in models_path :
165
+ f5_model_path = os .path .join (model_path , 'F5-TTS' )
166
+ model_file = os .path .join (f5_model_path , path )
167
+ if os .path .isfile (model_file ):
168
+ return model_file
169
+ raise FileNotFoundError ("No model found: " + url )
170
+ return None
171
+ return str (cached_path (url )) # noqa E501
172
+
173
+ def load_f5_model_url (self , url , vocoder_name , vocab_url = None ):
174
+ vocoder = self .load_vocoder (vocoder_name )
85
175
model_cls = DiT
86
176
model_cfg = dict (
87
177
dim = 1024 , depth = 22 , heads = 16 ,
88
178
ff_mult = 2 , text_dim = 512 , conv_layers = 4
89
179
)
90
- repo_name = "F5-TTS"
91
- exp_name = "F5TTS_Base"
92
- ckpt_step = 1200000
93
- ckpt_file = str (cached_path (f"hf://SWivid/{ repo_name } /{ exp_name } /model_{ ckpt_step } .safetensors" )) # noqa E501
94
- vocab_file = self .get_vocab_file ()
180
+ ckpt_file = str (self .cached_path (url )) # noqa E501
181
+
182
+ if vocab_url is None :
183
+ if url .startswith ("model:" ):
184
+ vocab_file = F5TTSCreate .get_txt_file_path (ckpt_file )
185
+ else :
186
+ vocab_file = self .get_vocab_file ()
187
+ else :
188
+ vocab_file = str (self .cached_path (vocab_url ))
95
189
ema_model = load_model (
96
190
model_cls , model_cfg ,
97
- ckpt_file , vocab_file
191
+ ckpt_file , vocab_file = vocab_file ,
192
+ mel_spec_type = vocoder_name ,
98
193
)
99
- return ema_model
194
+ return ( ema_model , vocoder , vocoder_name )
100
195
101
- def generate_audio (self , voices , model_obj , chunks , seed ):
196
+ def generate_audio (
197
+ self , voices , model_obj , chunks , seed , vocoder , mel_spec_type
198
+ ):
102
199
if seed >= 0 :
103
200
torch .manual_seed (seed )
104
201
else :
@@ -127,7 +224,8 @@ def generate_audio(self, voices, model_obj, chunks, seed):
127
224
print (f"Voice: { voice } " )
128
225
print ("text:" + text )
129
226
audio , final_sample_rate , spectragram = infer_process (
130
- ref_audio , ref_text , gen_text , model_obj
227
+ ref_audio , ref_text , gen_text , model_obj ,
228
+ vocoder = vocoder , mel_spec_type = mel_spec_type
131
229
)
132
230
generated_audio_segments .append (audio )
133
231
frame_rate = final_sample_rate
@@ -147,9 +245,20 @@ def generate_audio(self, voices, model_obj, chunks, seed):
147
245
os .unlink (wave_file .name )
148
246
return audio
149
247
150
- def create (self , voices , chunks , seed = - 1 , model = "F5" ):
151
- model_obj = self .load_model (model )
152
- return self .generate_audio (voices , model_obj , chunks , seed )
248
+ def create (
249
+ self , voices , chunks , seed = - 1 , model = "F5" , vocoder_name = "vocos"
250
+ ):
251
+ (
252
+ model_obj ,
253
+ vocoder ,
254
+ mel_spec_type
255
+ ) = self .load_model (model , vocoder_name )
256
+ return self .generate_audio (
257
+ voices ,
258
+ model_obj ,
259
+ chunks , seed ,
260
+ vocoder , mel_spec_type = mel_spec_type ,
261
+ )
153
262
154
263
155
264
class F5TTSAudioInputs :
@@ -158,6 +267,7 @@ def __init__(self):
158
267
159
268
@classmethod
160
269
def INPUT_TYPES (s ):
270
+ model_types = F5TTSCreate .get_model_types ()
161
271
return {
162
272
"required" : {
163
273
"sample_audio" : ("AUDIO" ,),
@@ -171,7 +281,8 @@ def INPUT_TYPES(s):
171
281
"default" : 1 , "min" : - 1 ,
172
282
"tooltip" : F5TTSCreate .tooltip_seed ,
173
283
}),
174
- "model" : (F5TTSCreate .model_types ,),
284
+ "model" : (model_types ,),
285
+ # "vocoder": (F5TTSCreate.vocoder_types,),
175
286
},
176
287
}
177
288
@@ -213,7 +324,10 @@ def remove_wave_file(self):
213
324
print ("F5TTS: Cannot remove? " + self .wave_file_name )
214
325
print (e )
215
326
216
- def create (self , sample_audio , sample_text , speech , seed = - 1 , model = "F5" ):
327
+ def create (
328
+ self , sample_audio , sample_text , speech , seed = - 1 , model = "F5"
329
+ ):
330
+ vocoder = "vocos"
217
331
try :
218
332
main_voice = self .load_voice_from_input (sample_audio , sample_text )
219
333
@@ -223,7 +337,9 @@ def create(self, sample_audio, sample_text, speech, seed=-1, model="F5"):
223
337
chunks = f5ttsCreate .split_text (speech )
224
338
voices ['main' ] = main_voice
225
339
226
- audio = f5ttsCreate .create (voices , chunks , seed , model )
340
+ audio = f5ttsCreate .create (
341
+ voices , chunks , seed , model , vocoder
342
+ )
227
343
finally :
228
344
self .remove_wave_file ()
229
345
return (audio , )
@@ -243,11 +359,6 @@ class F5TTSAudio:
243
359
def __init__ (self ):
244
360
self .use_cli = False
245
361
246
- @staticmethod
247
- def get_txt_file_path (file ):
248
- p = Path (file )
249
- return os .path .join (os .path .dirname (file ), p .stem + ".txt" )
250
-
251
362
@classmethod
252
363
def INPUT_TYPES (s ):
253
364
input_dir = folder_paths .get_input_directory ()
@@ -256,11 +367,13 @@ def INPUT_TYPES(s):
256
367
)
257
368
filesWithTxt = []
258
369
for file in files :
259
- txtFile = F5TTSAudio .get_txt_file_path (file )
370
+ txtFile = F5TTSCreate .get_txt_file_path (file )
260
371
if os .path .isfile (os .path .join (input_dir , txtFile )):
261
372
filesWithTxt .append (file )
262
373
filesWithTxt = sorted (filesWithTxt )
263
374
375
+ model_types = F5TTSCreate .get_model_types ()
376
+
264
377
return {
265
378
"required" : {
266
379
"sample" : (filesWithTxt , {"audio_upload" : True }),
@@ -273,7 +386,8 @@ def INPUT_TYPES(s):
273
386
"default" : 1 , "min" : - 1 ,
274
387
"tooltip" : F5TTSCreate .tooltip_seed ,
275
388
}),
276
- "model" : (F5TTSCreate .model_types ,),
389
+ "model" : (model_types ,),
390
+ # "vocoder": (F5TTSCreate.vocoder_types,),
277
391
}
278
392
}
279
393
@@ -304,12 +418,14 @@ def load_voice_from_file(self, sample):
304
418
input_dir = folder_paths .get_input_directory ()
305
419
txt_file = os .path .join (
306
420
input_dir ,
307
- F5TTSAudio .get_txt_file_path (sample )
421
+ F5TTSCreate .get_txt_file_path (sample )
308
422
)
309
423
audio_text = ''
310
- with open (txt_file , 'r' ) as file :
424
+ with open (txt_file , 'r' , encoding = 'utf-8' ) as file :
311
425
audio_text = file .read ()
312
426
audio_path = folder_paths .get_annotated_filepath (sample )
427
+ print ("audio_text" )
428
+ print (audio_text )
313
429
return F5TTSCreate .load_voice (audio_path , audio_text )
314
430
315
431
def load_voices_from_files (self , sample , voice_names ):
@@ -330,7 +446,8 @@ def load_voices_from_files(self, sample, voice_names):
330
446
voices [voice_name ] = self .load_voice_from_file (sample_file )
331
447
return voices
332
448
333
- def create (self , sample , speech , seed = - 1 , model = "F5" ):
449
+ def create (self , sample , speech , seed = - 2 , model = "F5" ):
450
+ vocoder = "vocos"
334
451
# Install.check_install()
335
452
main_voice = self .load_voice_from_file (sample )
336
453
@@ -350,14 +467,14 @@ def create(self, sample, speech, seed=-1, model="F5"):
350
467
voices = self .load_voices_from_files (sample , voice_names )
351
468
voices ['main' ] = main_voice
352
469
353
- audio = f5ttsCreate .create (voices , chunks , seed , model )
470
+ audio = f5ttsCreate .create (voices , chunks , seed , model , vocoder )
354
471
return (audio , )
355
472
356
473
@classmethod
357
474
def IS_CHANGED (s , sample , speech , seed , model ):
358
475
m = hashlib .sha256 ()
359
476
audio_path = folder_paths .get_annotated_filepath (sample )
360
- audio_txt_path = F5TTSAudio .get_txt_file_path (audio_path )
477
+ audio_txt_path = F5TTSCreate .get_txt_file_path (audio_path )
361
478
last_modified_timestamp = os .path .getmtime (audio_path )
362
479
txt_last_modified_timestamp = os .path .getmtime (audio_txt_path )
363
480
m .update (audio_path )
0 commit comments