Skip to content

Commit 1f76b04

Browse files
committed
Changed speed to speed_type.
Added TDHS speed_type.
1 parent d78b747 commit 1f76b04

File tree

6 files changed

+142
-392
lines changed

6 files changed

+142
-392
lines changed

F5TTS.py

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,19 @@
2121

2222
f5tts_path = os.path.join(Install.f5TTSPath, "src")
2323
sys.path.insert(0, f5tts_path)
24-
from f5_tts.model import DiT,UNetT # noqa E402
25-
from f5_tts.infer.utils_infer import ( # noqa E402
24+
from f5_tts.model import DiT, UNetT # noqa: E402
25+
from f5_tts.infer.utils_infer import ( # noqa: E402
2626
load_model,
2727
load_vocoder,
2828
preprocess_ref_audio_text,
2929
infer_process,
30-
remove_silence_edges,
3130
)
3231
sys.path.remove(f5tts_path)
3332

3433

3534
class F5TTSCreate:
3635
voice_reg = re.compile(r"\{([^\}]+)\}")
36+
default_speed_type = "torch-time-stretch"
3737
model_names = [
3838
"F5",
3939
"F5-HI",
@@ -46,7 +46,7 @@ class F5TTSCreate:
4646
]
4747
vocoder_types = ["auto", "vocos", "bigvgan"]
4848
tooltip_seed = "Seed. -1 = random"
49-
tooltip_speed = "Speed. >1.0 slower. <1.0 faster. Using torchaudio.transforms.TimeStretch" # noqa E501
49+
tooltip_speed = "Speed. >1.0 slower. <1.0 faster"
5050

5151
def get_model_names():
5252
model_names = F5TTSCreate.model_names[:]
@@ -283,9 +283,8 @@ def load_f5_model_url(
283283

284284
def generate_audio(
285285
self, voices, model_obj, chunks, seed, vocoder, mel_spec_type,
286-
speed, infer_args={}
286+
infer_args={}
287287
):
288-
print(voices, model_obj, chunks, seed, vocoder, mel_spec_type, speed, infer_args)
289288
if seed >= 0:
290289
torch.manual_seed(seed)
291290
else:
@@ -325,11 +324,6 @@ def generate_audio(
325324

326325
if generated_audio_segments:
327326
final_wave = np.concatenate(generated_audio_segments)
328-
# if speed != 1.0:
329-
# final_wave = librosa.effects.time_stretch(final_wave, rate=speed)
330-
# wave_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
331-
# sf.write(wave_file.name, final_wave, frame_rate)
332-
# wave_file.close()
333327

334328
# waveform, sample_rate = torchaudio.load(wave_file.name)
335329
waveform = torch.from_numpy(final_wave).unsqueeze(0)
@@ -340,9 +334,18 @@ def generate_audio(
340334
# os.unlink(wave_file.name)
341335
return audio
342336

337+
def time_shift(self, audio, speed, speed_type):
338+
if speed == 1:
339+
return audio
340+
elif speed_type == "TDHS":
341+
return self.time_shift_audiostretchy(audio, speed)
342+
elif speed_type == "torch-time-stretch":
343+
return self.time_shift_torch_ts(audio, speed)
344+
return audio
345+
343346
def create(
344347
self, voices, chunks, seed=-1, model="F5",
345-
vocoder_name="vocos", speed=1,
348+
vocoder_name="vocos",
346349
model_type='F5TTS_Base', infer_args={}
347350
):
348351
(
@@ -355,10 +358,35 @@ def create(
355358
model_obj,
356359
chunks, seed,
357360
vocoder, mel_spec_type=mel_spec_type,
358-
speed=speed, infer_args=infer_args,
361+
infer_args=infer_args,
359362
)
360363

361-
def time_shift(self, audio, speed):
364+
def time_shift_audiostretchy(self, audio, speed):
    """Time-stretch *audio* with TDHS (time-domain harmonic scaling).

    Uses the ``audiostretchy`` package, processing one channel at a
    time. Only batch index 0 of the waveform is processed.

    Args:
        audio: dict with ``waveform`` — float tensor shaped
            (batch, channels, frames) in [-1, 1] (assumed ComfyUI
            AUDIO layout — confirm against callers) — and
            ``sample_rate``.
        speed: stretch ratio; >1.0 slower, <1.0 faster.

    Returns:
        A new audio dict {"waveform", "sample_rate"}; the sample rate
        is unchanged (TDHS changes duration, not playback rate).
    """
    from audiostretchy.stretch import AudioStretch

    rate = audio['sample_rate']
    waveform = audio['waveform']

    new_waveforms = []
    # Channels live on dim 1; dim 0 is the batch dimension.
    # (Iterating shape[0] here would drop channels of stereo input.)
    for channel in range(waveform.shape[1]):
        # audiostretchy operates on 16-bit PCM, so scale the float
        # [-1, 1] samples up to the int16 range first.
        ta_audio16 = waveform[0][channel] * 32768

        audio_stretch = AudioStretch()
        audio_stretch.samples = audio_stretch.in_samples = \
            ta_audio16.numpy().astype('int16')
        audio_stretch.nchannels = 1
        audio_stretch.sampwidth = 2
        audio_stretch.framerate = rate
        audio_stretch.nframes = waveform.shape[2]
        audio_stretch.stretch(ratio=speed)

        # Scale back to the float [-1, 1] waveform the rest of the
        # pipeline works with; returning raw int16 magnitudes would
        # clip/overdrive downstream consumers.
        stretched = torch.from_numpy(audio_stretch.samples)
        new_waveforms.append(stretched.float() / 32768.0)
    new_waveform = torch.stack(new_waveforms)
    # Restore the leading batch dimension.
    new_waveform = torch.stack([new_waveform])

    return {"waveform": new_waveform, "sample_rate": rate}
389+
def time_shift_torch_ts(self, audio, speed):
362390
import torch_time_stretch
363391
rate = audio['sample_rate']
364392
waveform = audio['waveform']
@@ -382,8 +410,7 @@ def load_voice_from_file(sample):
382410
with open(txt_file, 'r', encoding='utf-8') as file:
383411
audio_text = file.read()
384412
audio_path = folder_paths.get_annotated_filepath(sample)
385-
print("audio_text")
386-
print(audio_text)
413+
print(f"audio_text {audio_text}")
387414
return F5TTSCreate.load_voice(audio_path, audio_text)
388415

389416
@staticmethod
@@ -522,11 +549,12 @@ def create(
522549
voices['main'] = main_voice
523550

524551
audio = f5ttsCreate.create(
525-
voices, chunks, seed, model, vocoder, speed,
552+
voices, chunks, seed, model, vocoder,
526553
model_type
527554
)
528-
if speed != 1:
529-
audio = f5ttsCreate.time_shift(audio, speed)
555+
audio = f5ttsCreate.time_shift(
556+
audio, speed, F5TTSCreate.default_speed_type
557+
)
530558
finally:
531559
if wave_file_name is not None:
532560
F5TTSCreate.remove_wave_file(wave_file_name)
@@ -616,7 +644,7 @@ def INPUT_TYPES(s):
616644

617645
def create(
618646
self,
619-
sample, speech, seed=-2, model="F5", vocoder="vocos",
647+
sample, speech, seed=-1, model="F5", vocoder="vocos",
620648
speed=1,
621649
model_type=None,
622650
):
@@ -632,11 +660,13 @@ def create(
632660
voices['main'] = main_voice
633661

634662
audio = f5ttsCreate.create(
635-
voices, chunks, seed, model, vocoder, speed,
663+
voices, chunks, seed, model, vocoder,
636664
model_type
637665
)
638-
if speed != 1:
639-
audio = f5ttsCreate.time_shift(audio, speed)
666+
667+
audio = f5ttsCreate.time_shift(
668+
audio, speed, F5TTSCreate.default_speed_type
669+
)
640670
return (audio, )
641671

642672
@classmethod
@@ -713,7 +743,7 @@ def INPUT_TYPES(s):
713743
}),
714744
"speed": ("FLOAT", {
715745
"default": 1.0,
716-
"tooltip": F5TTSCreate.tooltip_speed,
746+
"tooltip": F5TTSCreate.tooltip_speed
717747
}),
718748
"model_type": (model_types, {
719749
"tooltip": "Type of model",
@@ -722,25 +752,25 @@ def INPUT_TYPES(s):
722752
},
723753
"optional": {
724754
"sample_audio": ("AUDIO", {
725-
"tooltip": "When this is connected, sample is ignored. Also put the words into sample_text", # noqa E501
755+
"tooltip": "When this is connected, sample is ignored. Also put the words into sample_text", # noqa: E501
726756
}),
727757
"sample_text": ("STRING", {
728758
"default": F5TTSAudioAdvanced.default_sample_text,
729759
"multiline": True,
730760
}),
731761
"target_rms": ("FLOAT", {
732762
"default": 0.1,
733-
"tooltip": "Target output speech loudness normalization value", # noqa E501
763+
"tooltip": "Target output speech loudness normalization value", # noqa: E501
734764
"step": 0.01,
735765
}),
736766
"cross_fade_duration": ("FLOAT", {
737767
"default": 0.15,
738-
"tooltip": "Duration of cross-fade between audio segments in seconds", # noqa E501
768+
"tooltip": "Duration of cross-fade between audio segments in seconds", # noqa: E501
739769
"step": 0.01,
740770
}),
741771
"nfe_step": ("INT", {
742772
"default": 32,
743-
"tooltip": "The number of function evaluation (denoising steps)", # noqa E501
773+
"tooltip": "The number of function evaluation (denoising steps)", # noqa: E501
744774
}),
745775
"cfg_strength": ("FLOAT", {
746776
"default": 2,
@@ -752,14 +782,13 @@ def INPUT_TYPES(s):
752782
"min": -10,
753783
"step": 0.001,
754784
}),
755-
"f5tts_speed": ("FLOAT", {
756-
"default": 1.0,
757-
"tooltip": "The speed of the generated audio. Using F5-TTS. Speed. >1.0 slower. <1.0 faster.", # noqa E501
758-
"step": 0.01,
785+
"speed_type": (["torch-time-stretch", "F5TTS", "TDHS"], {
786+
"default": "torch-time-stretch",
787+
"tooltip": "TDHS - Time-domain harmonic scaling. torch-time-stretch - torchaudio.transforms.TimeStretch. F5TTS's default time stretch. ", # noqa: E501
759788
}),
760789
"fix_duration": ("FLOAT", {
761790
"default": -1,
762-
"tooltip": "Fix the total duration (ref and gen audios) in second. -1 = disable", # noqa E501
791+
"tooltip": "Fix the total duration (ref and gen audios) in second. -1 = disable", # noqa: E501
763792
"min": -1,
764793
"step": 0.01,
765794
}),
@@ -774,7 +803,7 @@ def INPUT_TYPES(s):
774803

775804
def create(
776805
self,
777-
sample, speech, seed=-2, model="F5", vocoder="vocos",
806+
sample, speech, seed=-1, model="F5", vocoder="vocos",
778807
speed=1,
779808
model_type=None,
780809
sample_audio=None,
@@ -784,7 +813,7 @@ def create(
784813
nfe_step=32,
785814
cfg_strength=2,
786815
sway_sampling_coef=-1,
787-
f5tts_speed=1.0,
816+
speed_type="torch-time-stretch",
788817
fix_duration=-1,
789818
):
790819
wave_file_name = None
@@ -817,17 +846,16 @@ def create(
817846
infer_args['nfe_step'] = nfe_step
818847
infer_args['cfg_strength'] = cfg_strength
819848
infer_args['sway_sampling_coef'] = sway_sampling_coef
820-
if (f5tts_speed != 1):
821-
infer_args['speed'] = 1 / f5tts_speed
849+
if (speed_type == "F5TTS" and speed != 1):
850+
infer_args['speed'] = 1 / speed
822851
if (fix_duration >= 0):
823852
infer_args['fix_duration'] = fix_duration
824853

825854
audio = f5ttsCreate.create(
826-
voices, chunks, seed, model, vocoder, speed,
855+
voices, chunks, seed, model, vocoder,
827856
model_type, infer_args
828857
)
829-
if speed != 1:
830-
audio = f5ttsCreate.time_shift(audio, speed)
858+
audio = f5ttsCreate.time_shift(audio, speed, speed_type)
831859
finally:
832860
if wave_file_name is not None:
833861
F5TTSCreate.remove_wave_file(wave_file_name)
@@ -846,7 +874,7 @@ def IS_CHANGED(
846874
nfe_step,
847875
cfg_strength,
848876
sway_sampling_coef,
849-
f5tts_speed,
877+
speed_type,
850878
fix_duration,
851879
):
852880
m = hashlib.sha256()
@@ -870,6 +898,6 @@ def IS_CHANGED(
870898
m.update(nfe_step)
871899
m.update(cfg_strength)
872900
m.update(sway_sampling_coef)
873-
m.update(f5tts_speed)
901+
m.update(speed_type)
874902
m.update(fix_duration)
875903
return m.digest().hex()

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ I haven't tried these...
4242
[Russian](https://huggingface.co/hotstone228/F5-TTS-Russian)
4343
[Spanish](https://huggingface.co/jpgallegoar/F5-Spanish)
4444
[Turkish](https://huggingface.co/marduk-ra/F5-TTS-Turkish)
45+
[Thai](https://huggingface.co/VIZINTZOR/F5-TTS-THAI)
4546
[Vietnamese](https://huggingface.co/yukiakai/F5-TTS-Vietnamese)
4647
[Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Odia, Punjabi, Tamil, Telugu](https://huggingface.co/ShriAishu/hindiSpeech)
4748

@@ -106,5 +107,6 @@ pip install -r requirements.txt
106107

107108
### Changes
108109

110+
1.0.22: Added TDHS (Time-domain harmonic scaling) to advanced node.
109111
1.0.21: Added advanced node
110112
1.0.19: Added model\_type.

0 commit comments

Comments
 (0)