feat: Add ONNX support for inference acceleration #276

Open · wants to merge 1 commit into main

54 changes: 54 additions & 0 deletions README.md
@@ -30,6 +30,60 @@ Some other features include:

The Python API and model cards can be found in [this repo](https://github.com/myshell-ai/MeloTTS/blob/main/docs/install.md#python-api) or on [HuggingFace](https://huggingface.co/myshell-ai).

## ONNX Inference Acceleration

MeloTTS supports ONNX (Open Neural Network Exchange) for potentially faster inference, especially on CPU, and for easier deployment across different platforms.

### Installation for ONNX

The necessary dependencies for ONNX inference (`onnx`, `onnxruntime`) are included in the `requirements.txt` file. You can install them along with other dependencies:

```bash
pip install -r requirements.txt
```

If you have a CUDA-compatible GPU and want to use it for ONNX inference, you can swap the default CPU build of `onnxruntime` for the GPU-enabled one after installing the base requirements:

```bash
pip uninstall onnxruntime
pip install onnxruntime-gpu
```
Make sure your CUDA and cuDNN versions are compatible with the chosen `onnxruntime-gpu` package.
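
You can quickly check which execution providers your `onnxruntime` installation exposes (this is plain ONNX Runtime usage, independent of MeloTTS):

```python
import onnxruntime

# 'CUDAExecutionProvider' should appear here when onnxruntime-gpu is installed
# and your CUDA/cuDNN versions are compatible with it.
print(onnxruntime.get_available_providers())
```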

### Exporting a Model to ONNX

To use ONNX, you first need to export a pre-trained MeloTTS model to the ONNX format. This is done using the `export_onnx.py` script.

**Example Command:**

```bash
python export_onnx.py --language EN --output_path melo_en.onnx --device cpu
```

**Arguments:**
- `--language`: (Required) The language of the model to export (e.g., `EN`, `ZH`, `ES`). This determines which default pre-trained model is loaded for export if `--ckpt_path` is not specified.
- `--output_path`: (Optional) The path where the exported ONNX model will be saved. Defaults to `melo.onnx`.
- `--ckpt_path`: (Optional) Path to a specific PyTorch model checkpoint (`.pth` file). If not provided, the script will attempt to download the default pre-trained model for the specified language from HuggingFace.
- `--device`: (Optional) The device to use for loading the PyTorch model during export (e.g., `cpu`, `cuda`). Defaults to `cpu`. The exported ONNX model itself will be runnable on various devices supported by ONNX Runtime.
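
The export can also be scripted from Python via the `export_model_to_onnx` function in `export_onnx.py`, and the resulting file can be sanity-checked with the `onnx` package. A minimal sketch, assuming it is run from the repository root and using an illustrative output path:

```python
import onnx
from export_onnx import export_model_to_onnx

# Export the default English model to ONNX on CPU.
export_model_to_onnx(language='EN', output_path='melo_en.onnx', device='cpu')

# Validate the structure of the exported graph.
onnx.checker.check_model(onnx.load('melo_en.onnx'))
print("ONNX model is well-formed.")
```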

### Running Inference with an ONNX Model

Once you have an ONNX model file, you can use the `melo/infer.py` script with the `--use_onnx` flag to perform text-to-speech synthesis.

**Example Command:**

```bash
python melo/infer.py --text "Hello world, this is a test using ONNX." --language EN --use_onnx --onnx_path melo_en.onnx --output_dir outputs_onnx
```

**Arguments:**
- `--use_onnx`: A flag to indicate that an ONNX model should be used for inference.
- `--onnx_path`: The path to your `.onnx` model file; only used when `--use_onnx` is set. Defaults to `melo.onnx`.
- `--language`: The language of the model (must match the language the ONNX model was exported for).
- `--text`: The text you want to synthesize.
- `--output_dir`: The directory where the output audio file(s) will be saved.
- `--config_path`: (Optional) Path to the `config.json` for the model. If not provided, the script will attempt to load/download the default config for the specified language. This is important for `TTS_ONNX` to correctly load model hyperparameters (`hps`).
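
The same ONNX inference path is available programmatically through the `TTS_ONNX` class added to `melo/api.py`. A minimal sketch; the ONNX file path and the `EN-US` speaker key are assumptions matching the default English model:

```python
from melo.api import TTS_ONNX

# Load the exported graph; 'auto' picks CUDA when available, otherwise CPU.
model = TTS_ONNX(language='EN', onnx_path='melo_en.onnx', device='auto')

# Speaker names come from the model config, e.g. 'EN-US' for the default English model.
speaker_ids = model.hps.data.spk2id
model.tts_to_file("Hello world, this is a test using ONNX.", speaker_ids['EN-US'], 'outputs_onnx/en_us.wav')
```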

## Join the Community

**Discord**
95 changes: 95 additions & 0 deletions export_onnx.py
@@ -0,0 +1,95 @@
import torch
import os
import argparse
import melo.api

def export_model_to_onnx(language: str, ckpt_path: str = None, output_path: str = 'melo.onnx', device: str = 'cpu'):
"""
Exports the TTS model to ONNX format.

Args:
language (str): Language of the model (e.g., 'EN').
ckpt_path (str, optional): Path to the model checkpoint. Defaults to None.
output_path (str, optional): Path to save the ONNX model. Defaults to 'melo.onnx'.
device (str, optional): Device to use for export (e.g., 'cpu', 'cuda'). Defaults to 'cpu'.
"""
print(f"Loading model for language: {language} with device: {device}")
if ckpt_path:
print(f"Using checkpoint: {ckpt_path}")

TTS_model = melo.api.TTS(language=language, ckpt_path=ckpt_path, device=device, use_hf=True)
model = TTS_model.model
hps = TTS_model.hps

print("Model loaded successfully.")

# Prepare dummy input tensors for model.infer()
x = torch.randint(0, len(hps.symbols), (1, 50), device=device)
x_lengths = torch.tensor([50], dtype=torch.long, device=device)
sid = torch.tensor([0], dtype=torch.long, device=device)

num_tones_for_dummy = hps.num_tones if hps.num_tones is not None and hps.num_tones > 0 else 1
tone_input = torch.randint(0, num_tones_for_dummy, (1, 50), device=device)

language_input = torch.tensor([0], dtype=torch.long, device=device)  # Assuming 0 is a valid language ID for the dummy input

# Adjust bert_input and ja_bert_input based on model configuration
# These dimensions might need to be verified against the specific model's expected input shapes
bert_feature_dim = 1024  # typical BERT feature dimension for MeloTTS checkpoints; verify against your model
ja_bert_feature_dim = 768  # common dimension for Japanese BERT features; verify against your model

bert_input = torch.randn(1, bert_feature_dim, 50, device=device) if getattr(hps.data, 'use_bert', True) else torch.zeros(1, bert_feature_dim, 50, device=device)
ja_bert_input = torch.randn(1, ja_bert_feature_dim, 50, device=device) if getattr(hps.data, 'use_japanese_bert', True) else torch.zeros(1, ja_bert_feature_dim, 50, device=device)

# Default values for inference parameters
noise_scale = 0.667
length_scale = 1.0
noise_scale_w = 0.8
max_len = None
sdp_ratio = 0.0 # sdp_ratio is often 0.0 for VITS models if not using specific features

# torch.onnx.export traces the module's forward(); point it at the inference path so the
# exported graph corresponds to model.infer rather than the training forward pass.
model.forward = model.infer

args_for_onnx_export = (
x, x_lengths, sid, tone_input, language_input, bert_input, ja_bert_input,
noise_scale, length_scale, noise_scale_w, max_len, sdp_ratio
)

input_names = ['x', 'x_lengths', 'sid', 'tone', 'language', 'bert', 'ja_bert']
output_names = ['audio_output']

dynamic_axes = {
'x': {1: 'sequence'},
'tone': {1: 'sequence'},
'bert': {2: 'sequence'}, # Assuming bert features align with sequence length
'ja_bert': {2: 'sequence'}, # Assuming ja_bert features align with sequence length
'audio_output': {2: 'audio_length'}
}

print(f"Exporting model to ONNX at {output_path}...")
torch.onnx.export(
model,
args_for_onnx_export,
output_path,
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes
)
print(f"Model exported successfully to {output_path}")

if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Export MeloTTS model to ONNX format.")
parser.add_argument('--language', type=str, required=True, help="Language of the model (e.g., 'EN', 'ZH', 'ES', 'FR', 'JA', 'KO').")
parser.add_argument('--ckpt_path', type=str, default=None, help="Path to the model checkpoint file. If None, downloads the default model for the language.")
parser.add_argument('--output_path', type=str, default='melo.onnx', help="Path to save the exported ONNX model.")
parser.add_argument('--device', type=str, default='cpu', help="Device to use for export (e.g., 'cpu', 'cuda').")

args = parser.parse_args()

export_model_to_onnx(
language=args.language,
ckpt_path=args.ckpt_path,
output_path=args.output_path,
device=args.device
)
128 changes: 127 additions & 1 deletion melo/api.py
Expand Up @@ -16,9 +16,11 @@
from .split_utils import split_sentence
from .mel_processing import spectrogram_torch, spectrogram_torch_conv
from .download_utils import load_or_download_config, load_or_download_model
import onnxruntime


class TTS(nn.Module):
def __init__(self,
language,
device='auto',
use_hf=True,
@@ -133,3 +135,127 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
else:
soundfile.write(output_path, audio, self.hps.data.sampling_rate)


class TTS_ONNX(object): # Changed to inherit from object
def __init__(self,
language,
onnx_path,
device='auto',
use_hf=True,
config_path=None):
super().__init__() # Ensure super().__init__() is called
if device == 'auto':
device = 'cpu'
if torch.cuda.is_available(): device = 'cuda'
if torch.backends.mps.is_available(): device = 'mps'

self.device = device # device for text processing, ONNX session handles its own device via providers

self.hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
self.symbol_to_id = {s: i for i, s in enumerate(self.hps.symbols)}

_language = language.split('_')[0]
self.language = 'ZH_MIX_EN' if _language == 'ZH' else _language

providers = ['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider', 'CPUExecutionProvider']
self.ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
self.onnx_input_names = [inp.name for inp in self.ort_session.get_inputs()]
self.onnx_output_names = [out.name for out in self.ort_session.get_outputs()]


def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
language = self.language
texts = TTS.split_sentences_into_pieces(text, language, quiet)
audio_list = []
if pbar:
tx = pbar(texts)
else:
if position:
tx = tqdm(texts, position=position)
elif quiet:
tx = texts
else:
tx = tqdm(texts)

# Use self.device for text processing, then convert to CPU for ONNX if needed by specific inputs
# The main ONNX session runs on its configured device (CPU or CUDA)
processing_device = self.device

for t in tx:
if language in ['EN', 'ZH_MIX_EN']:
t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)

bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, processing_device, self.symbol_to_id)

# Convert tensors to CPU and then to numpy for ONNX Runtime input
x_tst = phones.unsqueeze(0).cpu().numpy()
tones_tst = tones.unsqueeze(0).cpu().numpy()
lang_ids_tst = lang_ids.unsqueeze(0).cpu().numpy() # Ensure lang_ids is processed correctly
bert_tst = bert.unsqueeze(0).cpu().numpy()
ja_bert_tst = ja_bert.unsqueeze(0).cpu().numpy()
x_tst_lengths = torch.LongTensor([phones.size(0)]).cpu().numpy()
speakers_tst = torch.LongTensor([speaker_id]).cpu().numpy()

input_feed = {
'x': x_tst,
'x_lengths': x_tst_lengths,
'sid': speakers_tst,
'tone': tones_tst,
'language': lang_ids_tst,
'bert': bert_tst,
'ja_bert': ja_bert_tst,
# Parameters for inference, ensure these are part of the ONNX model's inputs if they need to be dynamic
# For now, assuming they are fixed or handled differently in the ONNX model
# 'noise_scale': np.array([noise_scale], dtype=np.float32),
# 'length_scale': np.array([1.0 / speed], dtype=np.float32),
# 'noise_scale_w': np.array([noise_scale_w], dtype=np.float32),
}

# Only the tensor inputs named in export_onnx.py ('x', 'x_lengths', 'sid', 'tone',
# 'language', 'bert', 'ja_bert') are runtime inputs of the exported graph. The scalar
# parameters (noise_scale, length_scale, noise_scale_w, max_len, sdp_ratio) were passed
# to torch.onnx.export as plain Python values, so they were baked into the graph as
# constants at export time and cannot be changed per call here. To control them at
# runtime, re-export the model with those values declared as graph inputs and add the
# corresponding arrays to input_feed.

audio = self.ort_session.run(self.onnx_output_names, input_feed)[0] # Get the first output
audio = audio.squeeze().astype(np.float32) # Remove batch/channel dims, ensure float32
audio_list.append(audio)

if torch.cuda.is_available(): # General cleanup, not specific to ONNX session device
torch.cuda.empty_cache()

audio = TTS.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)

if output_path is None:
return audio
else:
if format:
soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
else:
soundfile.write(output_path, audio, self.hps.data.sampling_rate)
67 changes: 53 additions & 14 deletions melo/infer.py
@@ -1,25 +1,64 @@
import os
import click
from melo.api import TTS
from melo.api import TTS, TTS_ONNX




@click.command()
@click.option('--ckpt_path', '-m', type=str, default=None, help="Path to the checkpoint file")
@click.option('--text', '-t', type=str, default=None, help="Text to speak")
@click.option('--language', '-l', type=str, default="EN", help="Language of the model")
@click.option('--output_dir', '-o', type=str, default="outputs", help="Path to the output")
def main(ckpt_path, text, language, output_dir):
if ckpt_path is None:
raise ValueError("The model_path must be specified")

config_path = os.path.join(os.path.dirname(ckpt_path), 'config.json')
model = TTS(language=language, config_path=config_path, ckpt_path=ckpt_path)
@click.option('--output_dir', '-o', type=str, default="outputs", help="Path to the output directory")
@click.option('--use_onnx/--no-use_onnx', default=False, help="Use ONNX model for inference.")
@click.option('--onnx_path', type=str, default='melo.onnx', help="Path to the ONNX model file (required if --use_onnx).")
@click.option('--ckpt_path', '-m', type=str, default=None, help="Path to the PyTorch checkpoint file (required if not using ONNX).")
@click.option('--config_path', type=str, default=None, help="Path to the model config.json file. Optional, will be inferred or downloaded if not set.")
def main(text, language, output_dir, use_onnx, onnx_path, ckpt_path, config_path):
if text is None:
raise ValueError("The --text option must be specified.")

if use_onnx:
if onnx_path is None: # Should not happen if default is set, but good practice
raise ValueError("The --onnx_path must be specified when using --use_onnx.")
model = TTS_ONNX(
language=language,
onnx_path=onnx_path,
config_path=config_path,
device='auto',
use_hf=True
)
print(f"Using ONNX model: {onnx_path}")
else:
if ckpt_path is None:
raise ValueError("The --ckpt_path must be specified if not using ONNX.")
# If config_path is not provided, TTS will try to infer it from ckpt_path or download
model = TTS(
language=language,
ckpt_path=ckpt_path,
config_path=config_path,
device='auto',
use_hf=True
)
print(f"Using PyTorch model: {ckpt_path}")

for spk_name, spk_id in model.hps.data.spk2id.items():
save_path = f'{output_dir}/{spk_name}/output.wav'
# Ensure model was loaded
if model is None or model.hps is None:
raise RuntimeError("Model could not be loaded. Check paths and configurations.")

# Ensure spk2id is available
if not hasattr(model.hps.data, 'spk2id') or not model.hps.data.spk2id:
# Fallback for models that might not have spk2id (e.g., single-speaker models or different config structure)
# Defaulting to speaker ID 0 for such cases.
# This part might need adjustment based on how speaker IDs are handled in various model configs.
# For now, we assume a default speaker if spk2id is missing.
print("Warning: spk2id not found in model config. Using default speaker ID 0.")
spk_items = {"default_speaker": 0}.items()
else:
spk_items = model.hps.data.spk2id.items()

for spk_name, spk_id in spk_items:
save_path = f'{output_dir}/{language.lower()}_{spk_name}/output.wav'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
model.tts_to_file(text, spk_id, save_path)
print(f"Synthesizing for speaker: {spk_name} (ID: {spk_id})")
model.tts_to_file(text, spk_id, save_path, quiet=True) # Added quiet=True to reduce console noise per synthesis

if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions requirements.txt
@@ -27,3 +27,6 @@ langid==1.1.6
tqdm
tensorboard==2.16.2
loguru==0.7.2
onnx
onnxruntime
soundfile