14
14
import torch
15
15
import torchaudio as ta
16
16
from chatterbox .tts import ChatterboxTTS
17
-
17
+ from chatterbox . mtl_tts import ChatterboxMultilingualTTS
18
18
import grpc
19
19
20
def is_float(s) -> bool:
    """Return True if *s* can be converted to a float, False otherwise.

    Anything accepted by the built-in ``float()`` counts, which includes
    integer-looking strings such as ``"3"`` and the special spellings
    ``"nan"`` / ``"inf"``. Callers that need to tell integers apart from
    floats should therefore test for an integer first.

    Non-string, non-numeric inputs (for example ``None``) yield False
    instead of propagating ``TypeError``.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        # ValueError: unparsable text; TypeError: unsupported type (e.g. None).
        return False
    return True
27
def is_int(s) -> bool:
    """Return True if *s* can be converted to an int, False otherwise.

    Anything accepted by the built-in ``int()`` counts. Decimal strings
    such as ``"1.5"`` are rejected because ``int()`` does not parse them.

    Non-string, non-numeric inputs (for example ``None``) yield False
    instead of propagating ``TypeError``.
    """
    try:
        int(s)
    except (TypeError, ValueError):
        # ValueError: unparsable text; TypeError: unsupported type (e.g. None).
        return False
    return True
20
34
21
35
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
22
36
@@ -47,6 +61,28 @@ def LoadModel(self, request, context):
47
61
if not torch .cuda .is_available () and request .CUDA :
48
62
return backend_pb2 .Result (success = False , message = "CUDA is not available" )
49
63
64
+
65
+ options = request .Options
66
+
67
+ # empty dict
68
+ self .options = {}
69
+
70
+ # The options are a list of strings in this form optname:optvalue
71
+ # We are storing all the options in a dict so we can use it later when
72
+ # generating the audio
73
+ for opt in options :
74
+ if ":" not in opt :
75
+ continue
76
+ key , value = opt .split (":" )
77
+ # if value is a number, convert it to the appropriate type
78
+ if is_float (value ):
79
+ value = float (value )
80
+ elif is_int (value ):
81
+ value = int (value )
82
+ elif value .lower () in ["true" , "false" ]:
83
+ value = value .lower () == "true"
84
+ self .options [key ] = value
85
+
50
86
self .AudioPath = None
51
87
52
88
if os .path .isabs (request .AudioPath ):
@@ -56,10 +92,14 @@ def LoadModel(self, request, context):
56
92
modelFileBase = os .path .dirname (request .ModelFile )
57
93
# modify AudioPath to be relative to modelFileBase
58
94
self .AudioPath = os .path .join (modelFileBase , request .AudioPath )
59
-
60
95
try :
61
96
print ("Preparing models, please wait" , file = sys .stderr )
62
- self .model = ChatterboxTTS .from_pretrained (device = device )
97
+ if "multilingual" in self .options :
98
+ # remove key from options
99
+ del self .options ["multilingual" ]
100
+ self .model = ChatterboxMultilingualTTS .from_pretrained (device = device )
101
+ else :
102
+ self .model = ChatterboxTTS .from_pretrained (device = device )
63
103
except Exception as err :
64
104
return backend_pb2 .Result (success = False , message = f"Unexpected { err = } , { type (err )= } " )
65
105
# Implement your logic here for the LoadModel service
@@ -68,12 +108,18 @@ def LoadModel(self, request, context):
68
108
69
109
def TTS (self , request , context ):
70
110
try :
71
- # Generate audio using ChatterboxTTS
111
+ kwargs = {}
112
+
113
+ if "language" in self .options :
114
+ kwargs ["language_id" ] = self .options ["language" ]
72
115
if self .AudioPath is not None :
73
- wav = self .model .generate (request .text , audio_prompt_path = self .AudioPath )
74
- else :
75
- wav = self .model .generate (request .text )
76
-
116
+ kwargs ["audio_prompt_path" ] = self .AudioPath
117
+
118
+ # add options to kwargs
119
+ kwargs .update (self .options )
120
+
121
+ # Generate audio using ChatterboxTTS
122
+ wav = self .model .generate (request .text , ** kwargs )
77
123
# Save the generated audio
78
124
ta .save (request .dst , wav , self .model .sr )
79
125
0 commit comments