17
17
from comfy .utils import ProgressBar
18
18
from cached_path import cached_path
19
19
sys .path .append (Install .f5TTSPath )
20
- from model import DiT # noqa E402
20
+ from model import DiT , UNetT # noqa E402
21
21
from model .utils_infer import ( # noqa E402
22
22
load_model ,
23
23
preprocess_ref_audio_text ,
28
28
29
29
class F5TTSCreate :
30
30
voice_reg = re .compile (r"\{(\w+)\}" )
31
+ model_types = ["F5" , "E2" ]
31
32
tooltip_seed = "Seed. -1 = random"
32
33
33
34
def is_voice_name (self , word ):
@@ -54,7 +55,33 @@ def load_voice(ref_audio, ref_text):
54
55
)
55
56
return main_voice
56
57
57
- def load_model (self ):
58
def load_model(self, model):
    """Load the TTS model selected by its short name.

    The name is looked up in a table that maps "F5" to
    ``self.load_f5_model`` and "E2" to ``self.load_e2_model``; the
    matching loader is invoked with no arguments and its result is
    returned. A name that is not in the table raises ``KeyError`` from
    the lookup itself.
    """
    loaders = {"F5": self.load_f5_model, "E2": self.load_e2_model}
    loader = loaders[model]
    return loader()
64
+
65
def get_vocab_file(self):
    """Return the vocab file path inside the F5-TTS checkout.

    The result is ``Install.f5TTSPath`` joined with the relative path
    ``data/Emilia_ZH_EN_pinyin/vocab.txt``.
    """
    vocab_rel_path = "data/Emilia_ZH_EN_pinyin/vocab.txt"
    return os.path.join(Install.f5TTSPath, vocab_rel_path)
69
+
70
def load_e2_model(self):
    """Build the E2 model from its published checkpoint.

    Uses the ``UNetT`` model class with a fixed configuration
    (dim=1024, depth=24, heads=16, ff_mult=4), resolves the
    ``E2TTS_Base`` checkpoint at step 1200000 through ``cached_path``,
    and combines it with the vocab file from ``self.get_vocab_file()``.

    Returns:
        Whatever the imported ``load_model`` helper returns for these
        arguments.
    """
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
    repo_name = "E2-TTS"
    exp_name = "E2TTS_Base"
    ckpt_step = 1200000
    # The URL must not contain any whitespace: a space before a "/" or
    # before ".safetensors" would name a file that does not exist.
    ckpt_url = (
        f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"
    )
    ckpt_file = str(cached_path(ckpt_url))
    vocab_file = self.get_vocab_file()
    # NOTE: this is the module-level ``load_model`` imported from
    # ``model.utils_infer``, not the ``load_model`` method of this class.
    return load_model(model_cls, model_cfg, ckpt_file, vocab_file)
83
+
84
+ def load_f5_model (self ):
58
85
model_cls = DiT
59
86
model_cfg = dict (
60
87
dim = 1024 , depth = 22 , heads = 16 ,
@@ -64,10 +91,11 @@ def load_model(self):
64
91
exp_name = "F5TTS_Base"
65
92
ckpt_step = 1200000
66
93
ckpt_file = str (cached_path (f"hf://SWivid/{ repo_name } /{ exp_name } /model_{ ckpt_step } .safetensors" )) # noqa E501
67
- vocab_file = os .path .join (
68
- Install .f5TTSPath , "data/Emilia_ZH_EN_pinyin/vocab.txt"
94
+ vocab_file = self .get_vocab_file ()
95
+ ema_model = load_model (
96
+ model_cls , model_cfg ,
97
+ ckpt_file , vocab_file
69
98
)
70
- ema_model = load_model (model_cls , model_cfg , ckpt_file , vocab_file )
71
99
return ema_model
72
100
73
101
def generate_audio (self , voices , model_obj , chunks , seed ):
@@ -117,8 +145,8 @@ def generate_audio(self, voices, model_obj, chunks, seed):
117
145
os .unlink (wave_file .name )
118
146
return audio
119
147
120
- def create (self , voices , chunks , seed = - 1 ):
121
- model_obj = self .load_model ()
148
def create(self, voices, chunks, seed=-1, model="F5"):
    """Load the requested model and generate audio with it.

    ``model`` is forwarded to ``self.load_model``; the object it returns
    is then handed to ``self.generate_audio`` together with ``voices``,
    ``chunks`` and ``seed``. The value produced by ``generate_audio`` is
    returned unchanged.
    """
    loaded_model = self.load_model(model)
    audio = self.generate_audio(voices, loaded_model, chunks, seed)
    return audio
123
151
124
152
@@ -141,6 +169,7 @@ def INPUT_TYPES(s):
141
169
"default" : 1 , "min" : - 1 ,
142
170
"tooltip" : F5TTSCreate .tooltip_seed ,
143
171
}),
172
+ "model" : (F5TTSCreate .model_types ,),
144
173
},
145
174
}
146
175
@@ -174,7 +203,7 @@ def remove_wave_file(self):
174
203
print ("F5TTS: Cannot remove? " + self .wave_file .name )
175
204
print (e )
176
205
177
- def create (self , sample_audio , sample_text , speech , seed = - 1 ):
206
+ def create (self , sample_audio , sample_text , speech , seed = - 1 , model = "F5" ):
178
207
try :
179
208
main_voice = self .load_voice_from_input (sample_audio , sample_text )
180
209
@@ -184,7 +213,7 @@ def create(self, sample_audio, sample_text, speech, seed=-1):
184
213
chunks = f5ttsCreate .split_text (speech )
185
214
voices ['main' ] = main_voice
186
215
187
- audio = f5ttsCreate .create (voices , chunks , seed )
216
+ audio = f5ttsCreate .create (voices , chunks , seed , model )
188
217
finally :
189
218
self .remove_wave_file ()
190
219
return (audio , )
@@ -233,6 +262,7 @@ def INPUT_TYPES(s):
233
262
"default" : 1 , "min" : - 1 ,
234
263
"tooltip" : F5TTSCreate .tooltip_seed ,
235
264
}),
265
+ "model" : (F5TTSCreate .model_types ,),
236
266
}
237
267
}
238
268
@@ -289,7 +319,7 @@ def load_voices_from_files(self, sample, voice_names):
289
319
voices [voice_name ] = self .load_voice_from_file (sample_file )
290
320
return voices
291
321
292
- def create (self , sample , speech , seed = - 1 ):
322
+ def create (self , sample , speech , seed = - 1 , model = "F5" ):
293
323
# Install.check_install()
294
324
main_voice = self .load_voice_from_file (sample )
295
325
@@ -309,7 +339,7 @@ def create(self, sample, speech, seed=-1):
309
339
voices = self .load_voices_from_files (sample , voice_names )
310
340
voices ['main' ] = main_voice
311
341
312
- audio = f5ttsCreate .create (voices , chunks , seed )
342
+ audio = f5ttsCreate .create (voices , chunks , seed , model )
313
343
return (audio , )
314
344
315
345
@classmethod
0 commit comments