diff --git a/README.md b/README.md
index 661ca2125..5174fe83a 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,60 @@ Some other features include:
 The Python API and model cards can be found in [this repo](https://github.com/myshell-ai/MeloTTS/blob/main/docs/install.md#python-api) or on [HuggingFace](https://huggingface.co/myshell-ai).
 
+## ONNX Inference Acceleration
+
+MeloTTS supports ONNX (Open Neural Network Exchange) for potentially faster inference, especially on CPU, and for easier deployment across different platforms.
+
+### Installation for ONNX
+
+The dependencies needed for ONNX inference (`onnx`, `onnxruntime`) are listed in `requirements.txt`. Install them along with the other dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+If you have a CUDA-compatible GPU and want GPU acceleration for ONNX, you can switch to the GPU-enabled build of `onnxruntime` after installing the base requirements:
+
+```bash
+pip uninstall onnxruntime
+pip install onnxruntime-gpu
+```
+
+Make sure your CUDA and cuDNN versions are compatible with the chosen `onnxruntime-gpu` package.
+
+### Exporting a Model to ONNX
+
+To use ONNX, first export a pre-trained MeloTTS model to the ONNX format with the `export_onnx.py` script.
+
+**Example Command:**
+
+```bash
+python export_onnx.py --language EN --output_path melo_en.onnx --device cpu
+```
+
+**Arguments:**
+- `--language`: (Required) The language of the model to export (e.g., `EN`, `ZH`, `ES`). This determines which default pre-trained model is downloaded for export if `--ckpt_path` is not specified.
+- `--output_path`: (Optional) The path where the exported ONNX model will be saved. Defaults to `melo.onnx`.
+- `--ckpt_path`: (Optional) Path to a specific PyTorch model checkpoint (`.pth` file). If not provided, the script downloads the default pre-trained model for the specified language from HuggingFace.
+- `--device`: (Optional) The device used to load the PyTorch model during export (e.g., `cpu`, `cuda`). Defaults to `cpu`. The exported ONNX model itself can run on any device supported by ONNX Runtime.
+
+### Running Inference with an ONNX Model
+
+Once you have an ONNX model file, use the `melo/infer.py` script with the `--use_onnx` flag to perform text-to-speech synthesis from the command line; a Python API sketch follows the argument list below.
+
+**Example Command:**
+
+```bash
+python melo/infer.py --text "Hello world, this is a test using ONNX." --language EN --use_onnx --onnx_path melo_en.onnx --output_dir outputs_onnx
+```
+
+**Arguments:**
+- `--use_onnx`: Flag indicating that an ONNX model should be used for inference.
+- `--onnx_path`: Path to your `.onnx` model file. Used when `--use_onnx` is specified; defaults to `melo.onnx`.
+- `--language`: The language of the model (must match the language the ONNX model was exported for).
+- `--text`: The text to synthesize.
+- `--output_dir`: The directory where the output audio file(s) will be saved.
+- `--config_path`: (Optional) Path to the model's `config.json`. If not provided, the script loads or downloads the default config for the specified language. This is needed for `TTS_ONNX` to load the model hyperparameters (`hps`).
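+
+### Using the ONNX Model from Python
+
+The `TTS_ONNX` class in `melo/api.py` mirrors the `TTS` interface. Below is a minimal sketch, assuming a model exported as `melo_en.onnx` (per the export example above) and a downloadable default English config; the speaker name `EN-US` is only an example, so check `model.hps.data.spk2id` for the speakers your config actually defines:
+
+```python
+from melo.api import TTS_ONNX
+
+# Load the exported ONNX model. The config for the language is loaded or
+# downloaded automatically unless config_path is given.
+model = TTS_ONNX(language='EN', onnx_path='melo_en.onnx', device='auto')
+
+# Look up a speaker ID from the config and synthesize to a file.
+speaker_ids = model.hps.data.spk2id
+model.tts_to_file("Hello from ONNX Runtime.", speaker_ids['EN-US'], 'onnx_output.wav')
+```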
+
 ## Join the Community
 
 **Discord**
diff --git a/export_onnx.py b/export_onnx.py
new file mode 100644
index 000000000..2ec7c87f4
--- /dev/null
+++ b/export_onnx.py
@@ -0,0 +1,95 @@
+import torch
+import os
+import argparse
+import melo.api
+
+def export_model_to_onnx(language: str, ckpt_path: str = None, output_path: str = 'melo.onnx', device: str = 'cpu'):
+    """
+    Exports the TTS model to ONNX format.
+
+    Args:
+        language (str): Language of the model (e.g., 'EN').
+        ckpt_path (str, optional): Path to the model checkpoint. Defaults to None.
+        output_path (str, optional): Path to save the ONNX model. Defaults to 'melo.onnx'.
+        device (str, optional): Device to use for export (e.g., 'cpu', 'cuda'). Defaults to 'cpu'.
+    """
+    print(f"Loading model for language: {language} with device: {device}")
+    if ckpt_path:
+        print(f"Using checkpoint: {ckpt_path}")
+
+    TTS_model = melo.api.TTS(language=language, ckpt_path=ckpt_path, device=device, use_hf=True)
+    model = TTS_model.model
+    hps = TTS_model.hps
+
+    print("Model loaded successfully.")
+
+    # Prepare dummy input tensors for model.infer(); only their shapes and dtypes matter for tracing.
+    x = torch.randint(0, len(hps.symbols), (1, 50), device=device)
+    x_lengths = torch.tensor([50], dtype=torch.long, device=device)
+    sid = torch.tensor([0], dtype=torch.long, device=device)
+
+    num_tones_for_dummy = hps.num_tones if hps.num_tones is not None and hps.num_tones > 0 else 1
+    tone_input = torch.randint(0, num_tones_for_dummy, (1, 50), device=device)
+
+    # Per-phoneme language IDs, matching the shape of the lang_ids tensor produced by
+    # utils.get_text_for_tts_infer at inference time. 0 is assumed to be a valid language ID.
+    language_input = torch.zeros((1, 50), dtype=torch.long, device=device)
+
+    # BERT feature dimensions: MeloTTS text preprocessing produces `bert` features of size 1024
+    # and `ja_bert` features of size 768 per phoneme; verify these against your specific model
+    # if its configuration differs.
+    bert_feature_dim = 1024
+    ja_bert_feature_dim = 768
+
+    bert_input = torch.zeros(1, bert_feature_dim, 50, device=device)
+    ja_bert_input = torch.zeros(1, ja_bert_feature_dim, 50, device=device)
+
+    # Default values for inference parameters. Because they are not declared as named ONNX
+    # inputs below, they are captured as constants in the exported graph.
+    noise_scale = 0.667
+    length_scale = 1.0
+    noise_scale_w = 0.8
+    max_len = None
+    sdp_ratio = 0.0  # sdp_ratio is often 0.0 for VITS-style models
+
+    args_for_onnx_export = (
+        x, x_lengths, sid, tone_input, language_input, bert_input, ja_bert_input,
+        noise_scale, length_scale, noise_scale_w, max_len, sdp_ratio
+    )
+
+    input_names = ['x', 'x_lengths', 'sid', 'tone', 'language', 'bert', 'ja_bert']
+    output_names = ['audio_output']
+
+    dynamic_axes = {
+        'x': {1: 'sequence'},
+        'tone': {1: 'sequence'},
+        'language': {1: 'sequence'},
+        'bert': {2: 'sequence'},     # bert features align with sequence length
+        'ja_bert': {2: 'sequence'},  # ja_bert features align with sequence length
+        'audio_output': {2: 'audio_length'}
+    }
+
+    # torch.onnx.export traces the module's forward(); point it at infer() so the inference
+    # path (rather than the training forward pass) is captured in the graph.
+    model.forward = model.infer
+
+    print(f"Exporting model to ONNX at {output_path}...")
+    torch.onnx.export(
+        model,
+        args_for_onnx_export,
+        output_path,
+        export_params=True,
+        opset_version=12,
+        do_constant_folding=True,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes=dynamic_axes
+    )
+    print(f"Model exported successfully to {output_path}")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Export MeloTTS model to ONNX format.")
+    parser.add_argument('--language', type=str, required=True, help="Language of the model (e.g., 'EN', 'ZH', 'ES', 'FR', 'JA', 'KO').")
+    parser.add_argument('--ckpt_path', type=str, default=None, help="Path to the model checkpoint file. If None, downloads the default model for the language.")
+    parser.add_argument('--output_path', type=str, default='melo.onnx', help="Path to save the exported ONNX model.")
+    parser.add_argument('--device', type=str, default='cpu', help="Device to use for export (e.g., 'cpu', 'cuda').")
+
+    args = parser.parse_args()
+
+    export_model_to_onnx(
+        language=args.language,
+        ckpt_path=args.ckpt_path,
+        output_path=args.output_path,
+        device=args.device
+    )
diff --git a/melo/api.py b/melo/api.py
index 236ea8f17..2963b22f2 100644
--- a/melo/api.py
+++ b/melo/api.py
@@ -16,9 +16,11 @@ from .split_utils import split_sentence
 from .mel_processing import spectrogram_torch, spectrogram_torch_conv
 from .download_utils import load_or_download_config, load_or_download_model
 
+import onnxruntime
+
 class TTS(nn.Module):
-    def __init__(self, 
+    def __init__(self,
                 language,
                 device='auto',
                 use_hf=True,
@@ -133,3 +135,127 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
             soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
         else:
             soundfile.write(output_path, audio, self.hps.data.sampling_rate)
+
+
+class TTS_ONNX(object):
+    def __init__(self,
+                 language,
+                 onnx_path,
+                 device='auto',
+                 use_hf=True,
+                 config_path=None):
+        super().__init__()
+        if device == 'auto':
+            device = 'cpu'
+            if torch.cuda.is_available(): device = 'cuda'
+            if torch.backends.mps.is_available(): device = 'mps'
+
+        # Device for text processing; the ONNX session handles its own device via providers.
+        self.device = device
+
+        self.hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
+        self.symbol_to_id = {s: i for i, s in enumerate(self.hps.symbols)}
+
+        _language = language.split('_')[0]
+        self.language = 'ZH_MIX_EN' if _language == 'ZH' else _language
+
+        providers = ['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider', 'CPUExecutionProvider']
+        self.ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
+        self.onnx_input_names = [inp.name for inp in self.ort_session.get_inputs()]
+        self.onnx_output_names = [out.name for out in self.ort_session.get_outputs()]
+
+    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+        language = self.language
+        texts = TTS.split_sentences_into_pieces(text, language, quiet)
+        audio_list = []
+        if pbar:
+            tx = pbar(texts)
+        else:
+            if position:
+                tx = tqdm(texts, position=position)
+            elif quiet:
+                tx = texts
+            else:
+                tx = tqdm(texts)
+
+        # Text processing runs on self.device; tensors are then moved to CPU and converted to
+        # numpy before being fed to the ONNX session, which runs on its configured device
+        # (CPU or CUDA, depending on the providers above).
+        processing_device = self.device
+
+        for t in tx:
+            if language in ['EN', 'ZH_MIX_EN']:
+                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, processing_device, self.symbol_to_id)
+
+            # Convert tensors to CPU numpy arrays for ONNX Runtime input
+            x_tst = phones.unsqueeze(0).cpu().numpy()
+            tones_tst = tones.unsqueeze(0).cpu().numpy()
+            lang_ids_tst = lang_ids.unsqueeze(0).cpu().numpy()
+            bert_tst = bert.unsqueeze(0).cpu().numpy()
+            ja_bert_tst = ja_bert.unsqueeze(0).cpu().numpy()
+            x_tst_lengths = torch.LongTensor([phones.size(0)]).cpu().numpy()
+            speakers_tst = torch.LongTensor([speaker_id]).cpu().numpy()
+
+            input_feed = {
+                'x': x_tst,
+                'x_lengths': x_tst_lengths,
+                'sid': speakers_tst,
+                'tone': tones_tst,
+                'language': lang_ids_tst,
+                'bert': bert_tst,
+                'ja_bert': ja_bert_tst,
+                # If the model is re-exported with the sampling parameters declared as graph
+                # inputs, they can be fed here as well:
+                # 'noise_scale': np.array([noise_scale], dtype=np.float32),
+                # 'length_scale': np.array([1.0 / speed], dtype=np.float32),
+                # 'noise_scale_w': np.array([noise_scale_w], dtype=np.float32),
+            }
+
+            # The graph exported by export_onnx.py declares only the tensor inputs above
+            # ('x', 'x_lengths', 'sid', 'tone', 'language', 'bert', 'ja_bert'). The scalar
+            # parameters of SynthesizerTrn.infer (noise_scale, length_scale, noise_scale_w,
+            # max_len, sdp_ratio) were passed as plain Python values during export, so they
+            # are baked into the graph as constants. Consequently, the sdp_ratio, noise_scale,
+            # noise_scale_w and speed arguments of this method do not affect the ONNX path
+            # unless the model is re-exported with those values declared as graph inputs.
+
+            audio = self.ort_session.run(self.onnx_output_names, input_feed)[0]  # Get the first output
+            audio = audio.squeeze().astype(np.float32)  # Remove batch/channel dims, ensure float32
+            audio_list.append(audio)
+
+            if torch.cuda.is_available():  # General cleanup, not specific to ONNX session device
+                torch.cuda.empty_cache()
+
+        audio = TTS.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+
+        if output_path is None:
+            return audio
+        else:
+            if format:
+                soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
+            else:
+                soundfile.write(output_path, audio, self.hps.data.sampling_rate)
diff --git a/melo/infer.py b/melo/infer.py
index 7ac1de943..42092382c 100644
--- a/melo/infer.py
+++ b/melo/infer.py
@@ -1,25 +1,64 @@
 import os
 import click
-from melo.api import TTS
+from melo.api import TTS, TTS_ONNX
+
-
-
 @click.command()
-@click.option('--ckpt_path', '-m', type=str, default=None, help="Path to the checkpoint file")
 @click.option('--text', '-t', type=str, default=None, help="Text to speak")
 @click.option('--language', '-l', type=str, default="EN", help="Language of the model")
-@click.option('--output_dir', '-o', type=str, default="outputs", help="Path to the output")
-def main(ckpt_path, text, language, output_dir):
-    if ckpt_path is None:
-        raise ValueError("The model_path must be specified")
-
-    config_path = os.path.join(os.path.dirname(ckpt_path), 'config.json')
-    model = TTS(language=language, config_path=config_path, ckpt_path=ckpt_path)
+@click.option('--output_dir', '-o', type=str, default="outputs", help="Path to the output directory")
+@click.option('--use_onnx/--no-use_onnx', default=False, help="Use ONNX model for inference.")
+@click.option('--onnx_path', type=str, default='melo.onnx', help="Path to the ONNX model file (required if --use_onnx).")
+@click.option('--ckpt_path', '-m', type=str, default=None, help="Path to the PyTorch checkpoint file (required if not using ONNX).")
+@click.option('--config_path', type=str, default=None, help="Path to the model config.json file. Optional, will be inferred or downloaded if not set.")
+def main(text, language, output_dir, use_onnx, onnx_path, ckpt_path, config_path):
+    if text is None:
+        raise ValueError("The --text option must be specified.")
+
+    if use_onnx:
+        if onnx_path is None:  # Should not happen if default is set, but good practice
+            raise ValueError("The --onnx_path must be specified when using --use_onnx.")
+        model = TTS_ONNX(
+            language=language,
+            onnx_path=onnx_path,
+            config_path=config_path,
+            device='auto',
+            use_hf=True
+        )
+        print(f"Using ONNX model: {onnx_path}")
+    else:
+        if ckpt_path is None:
+            raise ValueError("The --ckpt_path must be specified if not using ONNX.")
+        # If config_path is not provided, TTS will try to infer it from ckpt_path or download
+        model = TTS(
+            language=language,
+            ckpt_path=ckpt_path,
+            config_path=config_path,
+            device='auto',
+            use_hf=True
+        )
+        print(f"Using PyTorch model: {ckpt_path}")
 
-    for spk_name, spk_id in model.hps.data.spk2id.items():
-        save_path = f'{output_dir}/{spk_name}/output.wav'
+    # Ensure model was loaded
+    if model is None or model.hps is None:
+        raise RuntimeError("Model could not be loaded. Check paths and configurations.")
+
+    # Ensure spk2id is available
+    if not hasattr(model.hps.data, 'spk2id') or not model.hps.data.spk2id:
+        # Fallback for models that might not have spk2id (e.g., single-speaker models or different config structure)
+        # Defaulting to speaker ID 0 for such cases.
+        # This part might need adjustment based on how speaker IDs are handled in various model configs.
+        # For now, we assume a default speaker if spk2id is missing.
+        print("Warning: spk2id not found in model config. Using default speaker ID 0.")
+        spk_items = {"default_speaker": 0}.items()
+    else:
+        spk_items = model.hps.data.spk2id.items()
+
+    for spk_name, spk_id in spk_items:
+        save_path = f'{output_dir}/{language.lower()}_{spk_name}/output.wav'
         os.makedirs(os.path.dirname(save_path), exist_ok=True)
-        model.tts_to_file(text, spk_id, save_path)
+        print(f"Synthesizing for speaker: {spk_name} (ID: {spk_id})")
+        model.tts_to_file(text, spk_id, save_path, quiet=True)  # quiet=True reduces console noise per synthesis
 
 if __name__ == "__main__":
     main()
diff --git a/requirements.txt b/requirements.txt
index a79f61599..6b0dd0aae 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,3 +27,6 @@ langid==1.1.6
 tqdm
 tensorboard==2.16.2
 loguru==0.7.2
+onnx
+onnxruntime
+soundfile
diff --git a/test/test_onnx_inference.py b/test/test_onnx_inference.py
new file mode 100644
index 000000000..443064e79
--- /dev/null
+++ b/test/test_onnx_inference.py
@@ -0,0 +1,90 @@
+import pytest
+import os
+import soundfile as sf
+import shutil  # For cleaning up directories/files (cleanup is currently left to the user/CI)
+
+# export_onnx.py lives in the repository root; add it to sys.path so that
+# export_model_to_onnx can be imported directly.
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+from export_onnx import export_model_to_onnx
+from melo.api import TTS_ONNX, TTS  # TTS is needed to get a valid speaker_id from hps
+
+# Define a temporary directory for test artifacts
+TEST_ARTIFACTS_DIR = "test_onnx_artifacts"
+DEFAULT_ONNX_MODEL_PATH = os.path.join(TEST_ARTIFACTS_DIR, "test_model.onnx")
+DEFAULT_OUTPUT_AUDIO_PATH = os.path.join(TEST_ARTIFACTS_DIR, "test_output.wav")
+DEFAULT_LANGUAGE = "EN"  # Use a language known to have downloadable models
+
+@pytest.fixture(scope="module")
+def onnx_model_path_and_speaker_id():
+    # Setup: create the artifacts directory and export the ONNX model once per module
+    if not os.path.exists(TEST_ARTIFACTS_DIR):
+        os.makedirs(TEST_ARTIFACTS_DIR)
+
+    print(f"Exporting ONNX model for language {DEFAULT_LANGUAGE} to {DEFAULT_ONNX_MODEL_PATH}...")
+    try:
+        # We need hps from a PyTorch model to get a speaker ID.
+        # This assumes TTS can download the model for DEFAULT_LANGUAGE.
+        pytorch_model_for_hps = TTS(language=DEFAULT_LANGUAGE, device='cpu', use_hf=True)
+        speaker_id = list(pytorch_model_for_hps.hps.data.spk2id.values())[0]  # Get the first speaker ID
+        del pytorch_model_for_hps  # Free up memory
+
+        export_model_to_onnx(
+            language=DEFAULT_LANGUAGE,
+            output_path=DEFAULT_ONNX_MODEL_PATH,
+            device='cpu'
+            # ckpt_path can be None when using the default downloaded models
+        )
+        print("ONNX model exported successfully.")
+        yield DEFAULT_ONNX_MODEL_PATH, speaker_id
+    except Exception as e:
+        print(f"Error during ONNX model export fixture setup: {e}")
+        # Re-raise to fail fast; the test depending on this fixture cannot run without the model.
+        raise
+    finally:
+        # Teardown: the artifacts directory is deliberately not removed, so generated files
+        # can be inspected if tests fail.
+        # print(f"Cleaning up test artifacts directory: {TEST_ARTIFACTS_DIR}")
+        # shutil.rmtree(TEST_ARTIFACTS_DIR, ignore_errors=True)
+        # Cleanup can be done manually or handled by a separate script/CI step.
+        pass
+
+
+def test_onnx_tts_generation(onnx_model_path_and_speaker_id):
+    onnx_model_path, speaker_id = onnx_model_path_and_speaker_id
+
+    if not os.path.exists(onnx_model_path):
+        pytest.fail(f"ONNX model file not found at {onnx_model_path}. Export might have failed.")
+
+    print(f"Initializing TTS_ONNX with model: {onnx_model_path}")
+    onnx_tts_model = TTS_ONNX(
+        language=DEFAULT_LANGUAGE,  # Must match the language the model was exported for
+        onnx_path=onnx_model_path,
+        device='cpu'
+    )
+
+    test_text = "Hello world, this is an ONNX test."
+    print(f"Generating audio for text: '{test_text}' using ONNX...")
+
+    try:
+        onnx_tts_model.tts_to_file(
+            text=test_text,
+            speaker_id=speaker_id,  # Use the speaker_id fetched in the fixture
+            output_path=DEFAULT_OUTPUT_AUDIO_PATH,
+            quiet=True
+        )
+        print(f"Audio generated at: {DEFAULT_OUTPUT_AUDIO_PATH}")
+
+        assert os.path.exists(DEFAULT_OUTPUT_AUDIO_PATH), "Output audio file was not created."
+
+        # Check that the audio file has content and the expected sample rate
+        audio_data, sr = sf.read(DEFAULT_OUTPUT_AUDIO_PATH)
+        assert len(audio_data) > 0, "Output audio file is empty."
+        assert sr == onnx_tts_model.hps.data.sampling_rate, "Output audio sample rate does not match model."
+        print("ONNX TTS generation test passed.")
+
+    except Exception as e:
+        pytest.fail(f"ONNX TTS generation failed: {e}")
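
A side note on the GPU install path mentioned in the README section above: when `onnxruntime-gpu` is installed but the CUDA/cuDNN versions do not match, ONNX Runtime falls back to CPU. A quick, hedged way to check which execution providers are available and which ones a session actually attached is sketched below; `melo_en.onnx` is just the example file name used earlier.

```python
import onnxruntime

# Providers compiled into the installed onnxruntime build.
print(onnxruntime.get_available_providers())

# Ask for CUDA first; ONNX Runtime falls back to CPU if CUDA is unusable.
session = onnxruntime.InferenceSession(
    "melo_en.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
print(session.get_providers())  # The providers actually attached to this session.
```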