f5tts_path = os.path.join(Install.f5TTSPath, "src")
sys.path.insert(0, f5tts_path)
-from f5_tts.model import DiT,UNetT  # noqa E402
-from f5_tts.infer.utils_infer import (  # noqa E402
+from f5_tts.model import DiT, UNetT  # noqa: E402
+from f5_tts.infer.utils_infer import (  # noqa: E402
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    infer_process,
-    remove_silence_edges,
)
sys.path.remove(f5tts_path)


class F5TTSCreate:
    voice_reg = re.compile(r"\{([^\}]+)\}")
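+    # Stretch method used by the nodes that do not expose a speed_type input.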
+    default_speed_type = "torch-time-stretch"
    model_names = [
        "F5",
        "F5-HI",
@@ -46,7 +46,7 @@ class F5TTSCreate:
    ]
    vocoder_types = ["auto", "vocos", "bigvgan"]
    tooltip_seed = "Seed. -1 = random"
-    tooltip_speed = "Speed. >1.0 slower. <1.0 faster. Using torchaudio.transforms.TimeStretch"  # noqa E501
+    tooltip_speed = "Speed. >1.0 slower. <1.0 faster"

    def get_model_names():
        model_names = F5TTSCreate.model_names[:]
@@ -283,9 +283,8 @@ def load_f5_model_url(

    def generate_audio(
        self, voices, model_obj, chunks, seed, vocoder, mel_spec_type,
-        speed, infer_args={}
+        infer_args={}
    ):
-        print(voices, model_obj, chunks, seed, vocoder, mel_spec_type, speed, infer_args)
        if seed >= 0:
            torch.manual_seed(seed)
        else:
@@ -325,11 +324,6 @@ def generate_audio(

        if generated_audio_segments:
            final_wave = np.concatenate(generated_audio_segments)
-            # if speed != 1.0:
-            #     final_wave = librosa.effects.time_stretch(final_wave, rate=speed)
-            # wave_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
-            # sf.write(wave_file.name, final_wave, frame_rate)
-            # wave_file.close()

            # waveform, sample_rate = torchaudio.load(wave_file.name)
            waveform = torch.from_numpy(final_wave).unsqueeze(0)
@@ -340,9 +334,18 @@ def generate_audio(
        # os.unlink(wave_file.name)
        return audio

+    def time_shift(self, audio, speed, speed_type):
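+        # Dispatch on speed_type: "TDHS" stretches with audiostretchy
+        # (time-domain harmonic scaling), "torch-time-stretch" uses
+        # torch_time_stretch; anything else (e.g. "F5TTS", which is applied
+        # inside inference via infer_args['speed']) returns the audio as-is.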
+        if speed == 1:
+            return audio
+        elif speed_type == "TDHS":
+            return self.time_shift_audiostretchy(audio, speed)
+        elif speed_type == "torch-time-stretch":
+            return self.time_shift_torch_ts(audio, speed)
+        return audio
+
    def create(
        self, voices, chunks, seed=-1, model="F5",
-        vocoder_name="vocos", speed=1,
+        vocoder_name="vocos",
        model_type='F5TTS_Base', infer_args={}
    ):
        (
@@ -355,10 +358,35 @@ def create(
            model_obj,
            chunks, seed,
            vocoder, mel_spec_type=mel_spec_type,
-            speed=speed, infer_args=infer_args,
+            infer_args=infer_args,
        )

-    def time_shift(self, audio, speed):
+    def time_shift_audiostretchy(self, audio, speed):
+        from audiostretchy.stretch import AudioStretch
+
+        rate = audio['sample_rate']
+        waveform = audio['waveform']
+
+        new_waveforms = []
+        for channel in range(0, waveform.shape[0]):
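+            # AudioStretch (TDHS) works on 16-bit integer PCM, so scale the
+            # float waveform into int16 range before stretching.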
+            ta_audio16 = waveform[0][channel] * 32768
+
+            audio_stretch = AudioStretch()
+            audio_stretch.samples = audio_stretch.in_samples = \
+                ta_audio16.numpy().astype('int16')
+            audio_stretch.nchannels = 1
+            audio_stretch.sampwidth = 2
+            audio_stretch.framerate = rate
+            audio_stretch.nframes = waveform.shape[2]
+            audio_stretch.stretch(ratio=speed)
+
+            new_waveforms.append(torch.from_numpy(audio_stretch.samples))
+        new_waveform = torch.stack(new_waveforms)
+        new_waveform = torch.stack([new_waveform])
+
+        return {"waveform": new_waveform, "sample_rate": rate}
+
+    def time_shift_torch_ts(self, audio, speed):
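+        # "torch-time-stretch" option: stretch via the torch_time_stretch
+        # package (torchaudio.transforms.TimeStretch, per the speed tooltip).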
        import torch_time_stretch
        rate = audio['sample_rate']
        waveform = audio['waveform']
@@ -382,8 +410,7 @@ def load_voice_from_file(sample):
        with open(txt_file, 'r', encoding='utf-8') as file:
            audio_text = file.read()
        audio_path = folder_paths.get_annotated_filepath(sample)
-        print("audio_text")
-        print(audio_text)
+        print(f"audio_text {audio_text}")
        return F5TTSCreate.load_voice(audio_path, audio_text)

    @staticmethod
@@ -522,11 +549,12 @@ def create(
            voices['main'] = main_voice

            audio = f5ttsCreate.create(
-                voices, chunks, seed, model, vocoder, speed,
+                voices, chunks, seed, model, vocoder,
                model_type
            )
-            if speed != 1:
-                audio = f5ttsCreate.time_shift(audio, speed)
+            audio = f5ttsCreate.time_shift(
+                audio, speed, F5TTSCreate.default_speed_type
+            )
        finally:
            if wave_file_name is not None:
                F5TTSCreate.remove_wave_file(wave_file_name)
@@ -616,7 +644,7 @@ def INPUT_TYPES(s):

    def create(
        self,
-        sample, speech, seed=-2, model="F5", vocoder="vocos",
+        sample, speech, seed=-1, model="F5", vocoder="vocos",
        speed=1,
        model_type=None,
    ):
@@ -632,11 +660,13 @@ def create(
        voices['main'] = main_voice

        audio = f5ttsCreate.create(
-            voices, chunks, seed, model, vocoder, speed,
+            voices, chunks, seed, model, vocoder,
            model_type
        )
-        if speed != 1:
-            audio = f5ttsCreate.time_shift(audio, speed)
+
+        audio = f5ttsCreate.time_shift(
+            audio, speed, F5TTSCreate.default_speed_type
+        )
        return (audio, )

    @classmethod
@@ -713,7 +743,7 @@ def INPUT_TYPES(s):
                }),
                "speed": ("FLOAT", {
                    "default": 1.0,
                    "tooltip": F5TTSCreate.tooltip_speed,
                }),
                "model_type": (model_types, {
                    "tooltip": "Type of model",
@@ -722,25 +752,25 @@ def INPUT_TYPES(s):
            },
            "optional": {
                "sample_audio": ("AUDIO", {
-                    "tooltip": "When this is connected, sample is ignored. Also put the words into sample_text",  # noqa E501
+                    "tooltip": "When this is connected, sample is ignored. Also put the words into sample_text",  # noqa: E501
                }),
                "sample_text": ("STRING", {
                    "default": F5TTSAudioAdvanced.default_sample_text,
                    "multiline": True,
                }),
                "target_rms": ("FLOAT", {
                    "default": 0.1,
-                    "tooltip": "Target output speech loudness normalization value",  # noqa E501
+                    "tooltip": "Target output speech loudness normalization value",  # noqa: E501
                    "step": 0.01,
                }),
                "cross_fade_duration": ("FLOAT", {
                    "default": 0.15,
-                    "tooltip": "Duration of cross-fade between audio segments in seconds",  # noqa E501
+                    "tooltip": "Duration of cross-fade between audio segments in seconds",  # noqa: E501
                    "step": 0.01,
                }),
                "nfe_step": ("INT", {
                    "default": 32,
-                    "tooltip": "The number of function evaluation (denoising steps)",  # noqa E501
+                    "tooltip": "The number of function evaluations (denoising steps)",  # noqa: E501
                }),
                "cfg_strength": ("FLOAT", {
                    "default": 2,
@@ -752,14 +782,13 @@ def INPUT_TYPES(s):
                    "min": -10,
                    "step": 0.001,
                }),
-                "f5tts_speed": ("FLOAT", {
-                    "default": 1.0,
-                    "tooltip": "The speed of the generated audio. Using F5-TTS. Speed. >1.0 slower. <1.0 faster.",  # noqa E501
-                    "step": 0.01,
+                "speed_type": (["torch-time-stretch", "F5TTS", "TDHS"], {
+                    "default": "torch-time-stretch",
+                    "tooltip": "TDHS - time-domain harmonic scaling. torch-time-stretch - torchaudio.transforms.TimeStretch. F5TTS - F5-TTS's own time stretch.",  # noqa: E501
                }),
                "fix_duration": ("FLOAT", {
                    "default": -1,
-                    "tooltip": "Fix the total duration (ref and gen audios) in second. -1 = disable",  # noqa E501
+                    "tooltip": "Fix the total duration (ref and gen audios) in seconds. -1 = disable",  # noqa: E501
                    "min": -1,
                    "step": 0.01,
                }),
@@ -774,7 +803,7 @@ def INPUT_TYPES(s):

    def create(
        self,
-        sample, speech, seed=-2, model="F5", vocoder="vocos",
+        sample, speech, seed=-1, model="F5", vocoder="vocos",
        speed=1,
        model_type=None,
        sample_audio=None,
@@ -784,7 +813,7 @@ def create(
        nfe_step=32,
        cfg_strength=2,
        sway_sampling_coef=-1,
-        f5tts_speed=1.0,
+        speed_type="torch-time-stretch",
        fix_duration=-1,
    ):
        wave_file_name = None
@@ -817,17 +846,16 @@ def create(
            infer_args['nfe_step'] = nfe_step
            infer_args['cfg_strength'] = cfg_strength
            infer_args['sway_sampling_coef'] = sway_sampling_coef
-            if (f5tts_speed != 1):
-                infer_args['speed'] = 1 / f5tts_speed
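+            # The node's speed convention is ">1.0 slower, <1.0 faster", so it
+            # is inverted before being passed to F5-TTS's native speed argument.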
+            if (speed_type == "F5TTS" and speed != 1):
+                infer_args['speed'] = 1 / speed
            if (fix_duration >= 0):
                infer_args['fix_duration'] = fix_duration

            audio = f5ttsCreate.create(
-                voices, chunks, seed, model, vocoder, speed,
+                voices, chunks, seed, model, vocoder,
                model_type, infer_args
            )
-            if speed != 1:
-                audio = f5ttsCreate.time_shift(audio, speed)
+            audio = f5ttsCreate.time_shift(audio, speed, speed_type)
        finally:
            if wave_file_name is not None:
                F5TTSCreate.remove_wave_file(wave_file_name)
@@ -846,7 +874,7 @@ def IS_CHANGED(
        nfe_step,
        cfg_strength,
        sway_sampling_coef,
-        f5tts_speed,
+        speed_type,
        fix_duration,
    ):
        m = hashlib.sha256()
@@ -870,6 +898,6 @@ def IS_CHANGED(
        m.update(nfe_step)
        m.update(cfg_strength)
        m.update(sway_sampling_coef)
-        m.update(f5tts_speed)
+        m.update(speed_type)
        m.update(fix_duration)
        return m.digest().hex()