Skip to content

Commit 62e4d7b

Browse files
committed
Added a node that generates speech from audio + text inputs.
1 parent 6288654 commit 62e4d7b

File tree

5 files changed

+426
-83
lines changed

5 files changed

+426
-83
lines changed

F5TTS.py

Lines changed: 161 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import sys
1313
import numpy as np
1414
import re
15+
import io
1516
from comfy.utils import ProgressBar
1617
from cached_path import cached_path
1718
sys.path.append(Install.f5TTSPath)
@@ -24,60 +25,32 @@
2425
sys.path.pop()
2526

2627

27-
class F5TTSAudio:
28-
29-
def __init__(self):
30-
self.use_cli = False
31-
self.voice_reg = re.compile(r"\{(\w+)\}")
28+
class F5TTSCreate:
29+
voice_reg = re.compile(r"\{(\w+)\}")
3230

33-
@staticmethod
34-
def get_txt_file_path(file):
35-
p = Path(file)
36-
return os.path.join(os.path.dirname(file), p.stem + ".txt")
31+
def is_voice_name(self, word):
    """Return the regex match object if *word* starts with a {voice} tag, else None."""
    stripped = word.strip()
    return self.voice_reg.match(stripped)
3733

38-
@classmethod
39-
def INPUT_TYPES(s):
40-
input_dir = folder_paths.get_input_directory()
41-
files = folder_paths.filter_files_content_types(
42-
os.listdir(input_dir), ["audio", "video"]
43-
)
44-
filesWithTxt = []
45-
for file in files:
46-
txtFile = F5TTSAudio.get_txt_file_path(file)
47-
if os.path.isfile(os.path.join(input_dir, txtFile)):
48-
filesWithTxt.append(file)
49-
return {
50-
"required": {
51-
"sample": (sorted(filesWithTxt), {"audio_upload": True}),
52-
"speech": ("STRING", {
53-
"multiline": True,
54-
"default": "Hello World"
55-
}),
56-
}
57-
}
34+
def get_voice_names(self, chunks):
    """Collect the names of all {voice} tags that begin the given text chunks.

    Returns a dict mapping each voice name to True (insertion-ordered set).
    """
    return {
        match[1]: True
        for text in chunks
        if (match := self.is_voice_name(text))
    }
5841

59-
CATEGORY = "audio"
42+
def split_text(self, speech):
    """Split *speech* immediately before each {voice} tag, keeping the tag with its segment."""
    # Zero-width lookahead so the tag itself stays attached to its chunk.
    return re.split(r"(?=\{\w+\})", speech)
6045

61-
RETURN_TYPES = ("AUDIO", )
62-
FUNCTION = "create"
46+
@staticmethod
def load_voice(ref_audio, ref_text):
    """Preprocess a reference audio/text pair into a voice dict.

    Returns {"ref_audio": ..., "ref_text": ...} after running F5-TTS's
    preprocess_ref_audio_text on the inputs.
    """
    voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    voice["ref_audio"], voice["ref_text"] = preprocess_ref_audio_text(  # noqa E501
        ref_audio, ref_text
    )
    return voice
8154

8255
def load_model(self):
8356
model_cls = DiT
@@ -95,29 +68,6 @@ def load_model(self):
9568
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
9669
return ema_model
9770

98-
def load_voice(self, ref_audio, ref_text):
99-
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
100-
101-
main_voice["ref_audio"], main_voice["ref_text"] = preprocess_ref_audio_text( # noqa E501
102-
ref_audio, ref_text
103-
)
104-
return main_voice
105-
106-
def is_voice_name(self, word):
107-
return self.voice_reg.match(word.strip())
108-
109-
def get_voice_names(self, chunks):
110-
voice_names = {}
111-
for text in chunks:
112-
match = self.is_voice_name(text)
113-
if match:
114-
voice_names[match[1]] = True
115-
return voice_names
116-
117-
def split_text(self, speech):
118-
reg1 = r"(?=\{\w+\})"
119-
return re.split(reg1, speech)
120-
12171
def generate_audio(self, voices, model_obj, chunks):
12272
frame_rate = 44100
12373
generated_audio_segments = []
@@ -133,7 +83,7 @@ def generate_audio(self, voices, model_obj, chunks):
13383
if voice not in voices:
13484
print(f"Voice {voice} not found, using main.")
13585
voice = "main"
136-
text = self.voice_reg.sub("", text)
86+
text = F5TTSCreate.voice_reg.sub("", text)
13787
gen_text = text.strip()
13888
ref_audio = voices[voice]["ref_audio"]
13989
ref_text = voices[voice]["ref_text"]
@@ -160,6 +110,137 @@ def generate_audio(self, voices, model_obj, chunks):
160110
os.unlink(wave_file.name)
161111
return audio
162112

113+
def create(self, voices, chunks):
    """Load the TTS model and synthesize audio for the given text chunks."""
    model = self.load_model()
    audio = self.generate_audio(voices, model, chunks)
    return audio
116+
117+
118+
class F5TTSAudioInputs:
    """ComfyUI node: synthesize speech from an AUDIO input plus its transcript text."""

    def __init__(self):
        # Temp wav file holding the reference audio; removed after generation.
        self.wave_file = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "sample_audio": ("AUDIO",),
                "sample_text": ("STRING", {"default": "Text of sample_audio"}),
                "speech": ("STRING", {
                    "multiline": True,
                    "default": "This is what I want to say"
                }),
            },
        }

    CATEGORY = "audio"

    RETURN_TYPES = ("AUDIO", )
    FUNCTION = "create"

    def load_voice_from_input(self, sample_audio, sample_text):
        """Write the first batch item of sample_audio to a temp wav and load it as a voice."""
        self.wave_file = tempfile.NamedTemporaryFile(
            suffix=".wav", delete=False
        )
        # Close the handle before re-opening the path for writing; leaving it
        # open breaks the later open()/unlink() on Windows.
        self.wave_file.close()
        for (batch_number, waveform) in enumerate(
                sample_audio["waveform"].cpu()):
            buff = io.BytesIO()
            torchaudio.save(
                buff, waveform, sample_audio["sample_rate"], format="WAV"
            )
            with open(self.wave_file.name, 'wb') as f:
                f.write(buff.getbuffer())
            break  # only the first batch item is used as the reference voice
        return F5TTSCreate.load_voice(self.wave_file.name, sample_text)

    def remove_wave_file(self):
        """Best-effort removal of the temp reference wav; never raises."""
        if self.wave_file is not None:
            try:
                os.unlink(self.wave_file.name)
                self.wave_file = None
            except Exception as e:
                print("F5TTS: Cannot remove? "+self.wave_file.name)
                print(e)

    def create(self, sample_audio, sample_text, speech):
        """Generate speech using sample_audio/sample_text as the 'main' voice."""
        try:
            main_voice = self.load_voice_from_input(sample_audio, sample_text)

            f5ttsCreate = F5TTSCreate()

            voices = {}
            chunks = f5ttsCreate.split_text(speech)
            voices['main'] = main_voice

            audio = f5ttsCreate.create(voices, chunks)
        finally:
            # Always clean up the temp wav, even if generation fails.
            self.remove_wave_file()
        return (audio, )

    @classmethod
    def IS_CHANGED(s, sample_audio, sample_text, speech):
        """Cache key for ComfyUI: hash of transcript, waveform bytes, rate, speech."""
        # hashlib.update() requires bytes-like objects; the original passed
        # str values and the sample_audio dict, raising TypeError. Encode the
        # strings and hash the raw waveform bytes instead.
        m = hashlib.sha256()
        m.update(sample_text.encode("utf-8"))
        m.update(sample_audio["waveform"].cpu().numpy().tobytes())
        m.update(str(sample_audio["sample_rate"]).encode("utf-8"))
        m.update(speech.encode("utf-8"))
        return m.digest().hex()
187+
188+
189+
class F5TTSAudio:
190+
def __init__(self):
191+
self.use_cli = False
192+
193+
@staticmethod
194+
def get_txt_file_path(file):
195+
p = Path(file)
196+
return os.path.join(os.path.dirname(file), p.stem + ".txt")
197+
198+
@classmethod
def INPUT_TYPES(s):
    """Expose input-directory audio/video files that have a sidecar transcript .txt."""
    input_dir = folder_paths.get_input_directory()
    media_files = folder_paths.filter_files_content_types(
        os.listdir(input_dir), ["audio", "video"]
    )
    # Only samples with a transcript next to them are usable as voices.
    files_with_txt = sorted(
        f for f in media_files
        if os.path.isfile(
            os.path.join(input_dir, F5TTSAudio.get_txt_file_path(f))
        )
    )

    return {
        "required": {
            "sample": (files_with_txt, {"audio_upload": True}),
            "speech": ("STRING", {
                "multiline": True,
                "default": "This is what I want to say"
            }),
        }
    }

CATEGORY = "audio"

RETURN_TYPES = ("AUDIO", )
FUNCTION = "create"
225+
226+
def create_with_cli(self, audio_path, audio_text, speech, output_dir):
    """Run the external F5-TTS CLI and load the generated wav.

    Returns a ComfyUI AUDIO dict {"waveform": (1, C, N) tensor, "sample_rate": int}.
    Raises subprocess.CalledProcessError if the CLI fails.
    """
    subprocess.run(
        [
            "python", "inference-cli.py", "--model", "F5-TTS",
            "--ref_audio", audio_path, "--ref_text", audio_text,
            "--gen_text", speech,
            "--output_dir", output_dir
        ],
        cwd=Install.f5TTSPath,
        # Fail fast on CLI error instead of crashing later on a missing out.wav.
        check=True,
    )
    output_audio = os.path.join(output_dir, "out.wav")
    # torchaudio.load already reports the sample rate; the original reopened
    # the file with wave just to read the frame rate and left sample_rate unused.
    waveform, sample_rate = torchaudio.load(output_audio)
    return {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
243+
163244
def load_voice_from_file(self, sample):
164245
input_dir = folder_paths.get_input_directory()
165246
txt_file = os.path.join(
@@ -170,7 +251,7 @@ def load_voice_from_file(self, sample):
170251
with open(txt_file, 'r') as file:
171252
audio_text = file.read()
172253
audio_path = folder_paths.get_annotated_filepath(sample)
173-
return self.load_voice(audio_path, audio_text)
254+
return F5TTSCreate.load_voice(audio_path, audio_text)
174255

175256
def load_voices_from_files(self, sample, voice_names):
176257
voices = {}
@@ -194,6 +275,7 @@ def create(self, sample, speech):
194275
# Install.check_install()
195276
main_voice = self.load_voice_from_file(sample)
196277

278+
f5ttsCreate = F5TTSCreate()
197279
if self.use_cli:
198280
# working...
199281
output_dir = tempfile.mkdtemp()
@@ -204,21 +286,23 @@ def create(self, sample, speech):
204286
)
205287
shutil.rmtree(output_dir)
206288
else:
207-
model_obj = self.load_model()
208-
chunks = self.split_text(speech)
209-
voice_names = self.get_voice_names(chunks)
289+
chunks = f5ttsCreate.split_text(speech)
290+
voice_names = f5ttsCreate.get_voice_names(chunks)
210291
voices = self.load_voices_from_files(sample, voice_names)
211292
voices['main'] = main_voice
212293

213-
audio = self.generate_audio(voices, model_obj, chunks)
294+
audio = f5ttsCreate.create(voices, chunks)
214295
return (audio, )
215296

216297
@classmethod
def IS_CHANGED(s, sample, speech):
    """Cache key for ComfyUI: sample path, mtimes of audio and transcript, and speech."""
    m = hashlib.sha256()
    audio_path = folder_paths.get_annotated_filepath(sample)
    audio_txt_path = F5TTSAudio.get_txt_file_path(audio_path)
    last_modified_timestamp = os.path.getmtime(audio_path)
    txt_last_modified_timestamp = os.path.getmtime(audio_txt_path)
    # hashlib.update() requires bytes-like objects; the original passed str
    # values and raised TypeError. Encode everything first.
    m.update(audio_path.encode("utf-8"))
    m.update(str(last_modified_timestamp).encode("utf-8"))
    m.update(str(txt_last_modified_timestamp).encode("utf-8"))
    m.update(speech.encode("utf-8"))
    return m.digest().hex()

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ Using F5-TTS https://github.com/SWivid/F5-TTS
1010
* Press refresh to see it in the node
1111

1212
You can use the examples here...
13-
* [examples voices](examples/)
14-
* [simple workflow](examples/simple_ComfyUI_F5TTS_workflow.json)
13+
* [Examples voices](examples/)
14+
* [Simple workflow](examples/simple_ComfyUI_F5TTS_workflow.json)
15+
* [Workflow with input audio only, using OpenAI's Whisper to get the text](examples/F5TTS_whisper_workflow.json)
1516

1617

1718
### Multi voices...

__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11

2-
from .F5TTS import F5TTSAudio
2+
from .F5TTS import F5TTSAudio, F5TTSAudioInputs
33

44
NODE_CLASS_MAPPINGS = {
5-
"F5TTSAudio": F5TTSAudio
5+
"F5TTSAudio": F5TTSAudio,
6+
"F5TTSAudioInputs": F5TTSAudioInputs
67
}
78
NODE_DISPLAY_NAME_MAPPINGS = {
8-
"F5TTSAudio": "F5-TTS Audio"
9+
"F5TTSAudio": "F5-TTS Audio",
10+
"F5TTSAudioInputs": "F5-TTS Audio from inputs"
911
}

0 commit comments

Comments
 (0)