12
12
import sys
13
13
import numpy as np
14
14
import re
15
+ import io
15
16
from comfy .utils import ProgressBar
16
17
from cached_path import cached_path
17
18
sys .path .append (Install .f5TTSPath )
24
25
sys .path .pop ()
25
26
26
27
27
class F5TTSCreate:
    """Text-splitting / voice-tag helpers for F5-TTS generation."""

    # Matches a chunk that begins with a {voice_name} tag; group 1 is the name.
    voice_reg = re.compile(r"\{(\w+)\}")

    def is_voice_name(self, word):
        """Return the regex match if *word* starts with a {voice} tag, else None."""
        return self.voice_reg.match(word.strip())

    def get_voice_names(self, chunks):
        """Return a dict keyed by every voice name tagged in *chunks*."""
        matches = (self.is_voice_name(chunk) for chunk in chunks)
        return {m[1]: True for m in matches if m}

    def split_text(self, speech):
        """Split *speech* into chunks, each chunk starting at a {voice} tag."""
        # Zero-width lookahead keeps the tag at the head of each chunk.
        return re.split(r"(?=\{\w+\})", speech)
60
45
61
- RETURN_TYPES = ("AUDIO" , )
62
- FUNCTION = "create"
46
+ @staticmethod
47
+ def load_voice (ref_audio , ref_text ):
48
+ main_voice = {"ref_audio" : ref_audio , "ref_text" : ref_text }
63
49
64
- def create_with_cli (self , audio_path , audio_text , speech , output_dir ):
65
- subprocess .run (
66
- [
67
- "python" , "inference-cli.py" , "--model" , "F5-TTS" ,
68
- "--ref_audio" , audio_path , "--ref_text" , audio_text ,
69
- "--gen_text" , speech ,
70
- "--output_dir" , output_dir
71
- ],
72
- cwd = Install .f5TTSPath
50
+ main_voice ["ref_audio" ], main_voice ["ref_text" ] = preprocess_ref_audio_text ( # noqa E501
51
+ ref_audio , ref_text
73
52
)
74
- output_audio = os .path .join (output_dir , "out.wav" )
75
- with wave .open (output_audio , "rb" ) as wave_file :
76
- frame_rate = wave_file .getframerate ()
77
-
78
- waveform , sample_rate = torchaudio .load (output_audio )
79
- audio = {"waveform" : waveform .unsqueeze (0 ), "sample_rate" : frame_rate }
80
- return audio
53
+ return main_voice
81
54
82
55
def load_model (self ):
83
56
model_cls = DiT
@@ -95,29 +68,6 @@ def load_model(self):
95
68
ema_model = load_model (model_cls , model_cfg , ckpt_file , vocab_file )
96
69
return ema_model
97
70
98
- def load_voice (self , ref_audio , ref_text ):
99
- main_voice = {"ref_audio" : ref_audio , "ref_text" : ref_text }
100
-
101
- main_voice ["ref_audio" ], main_voice ["ref_text" ] = preprocess_ref_audio_text ( # noqa E501
102
- ref_audio , ref_text
103
- )
104
- return main_voice
105
-
106
- def is_voice_name (self , word ):
107
- return self .voice_reg .match (word .strip ())
108
-
109
- def get_voice_names (self , chunks ):
110
- voice_names = {}
111
- for text in chunks :
112
- match = self .is_voice_name (text )
113
- if match :
114
- voice_names [match [1 ]] = True
115
- return voice_names
116
-
117
- def split_text (self , speech ):
118
- reg1 = r"(?=\{\w+\})"
119
- return re .split (reg1 , speech )
120
-
121
71
def generate_audio (self , voices , model_obj , chunks ):
122
72
frame_rate = 44100
123
73
generated_audio_segments = []
@@ -133,7 +83,7 @@ def generate_audio(self, voices, model_obj, chunks):
133
83
if voice not in voices :
134
84
print (f"Voice { voice } not found, using main." )
135
85
voice = "main"
136
- text = self .voice_reg .sub ("" , text )
86
+ text = F5TTSCreate .voice_reg .sub ("" , text )
137
87
gen_text = text .strip ()
138
88
ref_audio = voices [voice ]["ref_audio" ]
139
89
ref_text = voices [voice ]["ref_text" ]
@@ -160,6 +110,137 @@ def generate_audio(self, voices, model_obj, chunks):
160
110
os .unlink (wave_file .name )
161
111
return audio
162
112
113
+ def create (self , voices , chunks ):
114
+ model_obj = self .load_model ()
115
+ return self .generate_audio (voices , model_obj , chunks )
116
+
117
+
118
class F5TTSAudioInputs:
    """ComfyUI node: synthesize speech from an AUDIO input plus its transcript."""

    def __init__(self):
        # Temp wav written from the incoming AUDIO tensor; removed after use.
        self.wave_file = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "sample_audio": ("AUDIO",),
                "sample_text": ("STRING", {"default": "Text of sample_audio"}),
                "speech": ("STRING", {
                    "multiline": True,
                    "default": "This is what I want to say"
                }),
            },
        }

    CATEGORY = "audio"

    RETURN_TYPES = ("AUDIO", )
    FUNCTION = "create"

    def load_voice_from_input(self, sample_audio, sample_text):
        """Write the first batch of *sample_audio* to a temp wav file and
        preprocess it into a voice dict via F5TTSCreate.load_voice."""
        self.wave_file = tempfile.NamedTemporaryFile(
            suffix=".wav", delete=False
        )
        # Only the first waveform in the batch is used as the reference clip.
        for (batch_number, waveform) in enumerate(
                sample_audio["waveform"].cpu()):
            buff = io.BytesIO()
            torchaudio.save(
                buff, waveform, sample_audio["sample_rate"], format="WAV"
            )
            with open(self.wave_file.name, 'wb') as f:
                f.write(buff.getbuffer())
            break
        r = F5TTSCreate.load_voice(self.wave_file.name, sample_text)
        return r

    def remove_wave_file(self):
        """Best-effort cleanup of the temp reference wav."""
        if self.wave_file is not None:
            try:
                os.unlink(self.wave_file.name)
                self.wave_file = None
            except Exception as e:
                print("F5TTS: Cannot remove? " + self.wave_file.name)
                print(e)

    def create(self, sample_audio, sample_text, speech):
        """Generate speech audio; always clean up the temp wav, even on error."""
        try:
            main_voice = self.load_voice_from_input(sample_audio, sample_text)

            f5ttsCreate = F5TTSCreate()

            voices = {}
            chunks = f5ttsCreate.split_text(speech)
            voices['main'] = main_voice

            audio = f5ttsCreate.create(voices, chunks)
        finally:
            self.remove_wave_file()
        return (audio, )

    @classmethod
    def IS_CHANGED(s, sample_audio, sample_text, speech):
        m = hashlib.sha256()
        # BUG FIX: hashlib.update() requires bytes-like input; the original
        # passed str (sample_text, speech) and a dict (sample_audio), which
        # raises TypeError. Hash encoded strings and the raw waveform bytes.
        m.update(sample_text.encode("utf-8"))
        m.update(sample_audio["waveform"].cpu().numpy().tobytes())
        m.update(str(sample_audio["sample_rate"]).encode("utf-8"))
        m.update(speech.encode("utf-8"))
        return m.digest().hex()
187
+
188
+
189
class F5TTSAudio:
    """ComfyUI node: synthesize speech from a sample file in the input dir."""

    def __init__(self):
        # When True, shell out to inference-cli.py instead of the in-process API.
        self.use_cli = False

    @staticmethod
    def get_txt_file_path(file):
        """Return the path of the transcript (.txt) sitting next to *file*."""
        stem = Path(file).stem
        return os.path.join(os.path.dirname(file), stem + ".txt")
197
+
198
+ @classmethod
199
+ def INPUT_TYPES (s ):
200
+ input_dir = folder_paths .get_input_directory ()
201
+ files = folder_paths .filter_files_content_types (
202
+ os .listdir (input_dir ), ["audio" , "video" ]
203
+ )
204
+ filesWithTxt = []
205
+ for file in files :
206
+ txtFile = F5TTSAudio .get_txt_file_path (file )
207
+ if os .path .isfile (os .path .join (input_dir , txtFile )):
208
+ filesWithTxt .append (file )
209
+ filesWithTxt = sorted (filesWithTxt )
210
+
211
+ return {
212
+ "required" : {
213
+ "sample" : (filesWithTxt , {"audio_upload" : True }),
214
+ "speech" : ("STRING" , {
215
+ "multiline" : True ,
216
+ "default" : "This is what I want to say"
217
+ }),
218
+ }
219
+ }
220
+
221
+ CATEGORY = "audio"
222
+
223
+ RETURN_TYPES = ("AUDIO" , )
224
+ FUNCTION = "create"
225
+
226
+ def create_with_cli (self , audio_path , audio_text , speech , output_dir ):
227
+ subprocess .run (
228
+ [
229
+ "python" , "inference-cli.py" , "--model" , "F5-TTS" ,
230
+ "--ref_audio" , audio_path , "--ref_text" , audio_text ,
231
+ "--gen_text" , speech ,
232
+ "--output_dir" , output_dir
233
+ ],
234
+ cwd = Install .f5TTSPath
235
+ )
236
+ output_audio = os .path .join (output_dir , "out.wav" )
237
+ with wave .open (output_audio , "rb" ) as wave_file :
238
+ frame_rate = wave_file .getframerate ()
239
+
240
+ waveform , sample_rate = torchaudio .load (output_audio )
241
+ audio = {"waveform" : waveform .unsqueeze (0 ), "sample_rate" : frame_rate }
242
+ return audio
243
+
163
244
def load_voice_from_file (self , sample ):
164
245
input_dir = folder_paths .get_input_directory ()
165
246
txt_file = os .path .join (
@@ -170,7 +251,7 @@ def load_voice_from_file(self, sample):
170
251
with open (txt_file , 'r' ) as file :
171
252
audio_text = file .read ()
172
253
audio_path = folder_paths .get_annotated_filepath (sample )
173
- return self .load_voice (audio_path , audio_text )
254
+ return F5TTSCreate .load_voice (audio_path , audio_text )
174
255
175
256
def load_voices_from_files (self , sample , voice_names ):
176
257
voices = {}
@@ -194,6 +275,7 @@ def create(self, sample, speech):
194
275
# Install.check_install()
195
276
main_voice = self .load_voice_from_file (sample )
196
277
278
+ f5ttsCreate = F5TTSCreate ()
197
279
if self .use_cli :
198
280
# working...
199
281
output_dir = tempfile .mkdtemp ()
@@ -204,21 +286,23 @@ def create(self, sample, speech):
204
286
)
205
287
shutil .rmtree (output_dir )
206
288
else :
207
- model_obj = self .load_model ()
208
- chunks = self .split_text (speech )
209
- voice_names = self .get_voice_names (chunks )
289
+ chunks = f5ttsCreate .split_text (speech )
290
+ voice_names = f5ttsCreate .get_voice_names (chunks )
210
291
voices = self .load_voices_from_files (sample , voice_names )
211
292
voices ['main' ] = main_voice
212
293
213
- audio = self . generate_audio (voices , model_obj , chunks )
294
+ audio = f5ttsCreate . create (voices , chunks )
214
295
return (audio , )
215
296
216
297
@classmethod
217
298
def IS_CHANGED (s , sample , speech ):
218
299
m = hashlib .sha256 ()
219
300
audio_path = folder_paths .get_annotated_filepath (sample )
301
+ audio_txt_path = F5TTSAudio .get_txt_file_path (audio_path )
220
302
last_modified_timestamp = os .path .getmtime (audio_path )
303
+ txt_last_modified_timestamp = os .path .getmtime (audio_txt_path )
221
304
m .update (audio_path )
222
305
m .update (str (last_modified_timestamp ))
306
+ m .update (str (txt_last_modified_timestamp ))
223
307
m .update (speech )
224
308
return m .digest ().hex ()
0 commit comments