Commit f049e2d

AudioLDM2 model: reproduce forward inference (PaddlePaddle#366)

Task: PaddlePaddle/PaddleMIX#250 - text-to-audio inference runs end to end

1 parent 825fdfd commit f049e2d

28 files changed: +7917 −0 lines
+27
@@ -0,0 +1,27 @@
# AudioLDM2

## 1. Model Introduction

This model is a Paddle implementation of [AudioLDM2](https://arxiv.org/abs/2308.05734).


## 2. Demo

### 2.1 Install Dependencies

- Make sure ppdiffusers is installed ([instructions](https://github.com/PaddlePaddle/PaddleMIX/blob/develop/README.md?plain=1#L62))

- Install the remaining dependencies:

```bash
cd /paddlemix/models/audioldm2
pip install -r requirement.txt
```

### 2.2 Dynamic Graph Inference

```bash
python run_predict.py \
    --text "Musical constellations twinkling in the night sky, forming a cosmic melody." \
    --model_name_or_path "/my_model_path" \
    --seed 1001
```
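
The same pipeline can also be driven from Python. Below is a minimal sketch that reuses the `text_to_audio` and `save_wave` helpers defined in `run_predict.py` (shown in this commit); the checkpoint path is the same placeholder as above:

```python
import os

import paddle

from paddlemix.models.audioldm2.modeling import AudioLDM2Model
from run_predict import save_wave, text_to_audio  # helpers from the script in this commit

paddle.set_device("gpu")
model = AudioLDM2Model.from_pretrained("/my_model_path")  # placeholder checkpoint path

# Generate one 10-second sample; see text_to_audio for the full parameter list.
waveform = text_to_audio(
    model,
    text="Musical constellations twinkling in the night sky, forming a cosmic melody.",
    seed=1001,
)

os.makedirs("./output", exist_ok=True)
save_wave(waveform, "./output", name="demo")  # written as a 16 kHz .wav file
```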
+302
@@ -0,0 +1,302 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import re
import time
from dataclasses import dataclass, field

import numpy as np
import paddle
import soundfile as sf
from paddlenlp.trainer import PdArgumentParser

from paddlemix.models.audioldm2.encoders.phoneme_encoder import text as text
from paddlemix.models.audioldm2.modeling import AudioLDM2Model

def seed_everything(seed):
    # Seed every RNG the pipeline touches so generation is reproducible.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


def text2phoneme(data):
    # Strip HTML-like tags, then convert the transcription to a phoneme string.
    return text._clean_text(re.sub(r"<.*?>", "", data), ["english_cleaners2"])


def text_to_filename(text):
    return text.replace(" ", "_").replace("'", "_").replace('"', "_")

CACHE = {
    "get_vits_phoneme_ids": {
        "PAD_LENGTH": 310,
        "_pad": '_',
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠"
    }
}

# Full symbol inventory and the symbol -> id lookup used for phoneme encoding.
CACHE["get_vits_phoneme_ids"]["symbols"] = (
    [CACHE["get_vits_phoneme_ids"]["_pad"]]
    + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
    + list(CACHE["get_vits_phoneme_ids"]["_special"])
)
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
    s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
}

def get_vits_phoneme_ids_no_padding(phonemes):
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
    batchsize = len(phonemes)

    # Append "⚠" so the model does not ignore the last vocabulary item.
    clean_text = phonemes[0] + "⚠"
    sequence = []

    for symbol in clean_text:
        if symbol not in _symbol_to_id:
            print("%s is not in the vocabulary. %s" % (symbol, clean_text))
            symbol = "_"
        sequence.append(_symbol_to_id[symbol])

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    sequence = sequence[:pad_length]

    return {
        "phoneme_idx": paddle.to_tensor(_pad_phonemes(sequence), dtype="int64")
        .unsqueeze(0)
        .expand([batchsize, -1])
    }

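# Illustrative example (hypothetical input): calling
#   get_vits_phoneme_ids_no_padding(["ðɪs ɪz ə tɛst"])
# returns {"phoneme_idx": int64 tensor of shape [1, 310]} -- the symbol ids of
# the text plus the trailing "⚠", right-padded with pad_token_id (0).
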
def make_batch_for_text_to_audio(text, transcription="", waveform=None, fbank=None, batchsize=1):
78+
text = [text] * batchsize
79+
if(transcription):
80+
transcription = text2phoneme(transcription)
81+
transcription = [transcription] * batchsize
82+
83+
if batchsize < 1:
84+
print("Warning: Batchsize must be at least 1. Batchsize is set to .")
85+
86+
if fbank is None:
87+
fbank = paddle.zeros(
88+
(batchsize, 1024, 64)
89+
) # Not used, here to keep the code format
90+
else:
91+
fbank = paddle.to_tensor(fbank, dtype="float32")
92+
fbank = fbank.expand([batchsize, 1024, 64])
93+
assert fbank.shape[0] == batchsize
94+
95+
stft = paddle.zeros((batchsize, 1024, 512)) # Not used
96+
phonemes = get_vits_phoneme_ids_no_padding(transcription)
97+
98+
waveform = paddle.zeros((batchsize, 160000)) # Not used
99+
ta_kaldi_fbank = paddle.zeros((batchsize, 1024, 128))
100+
101+
batch = {
102+
"text": text, # list
103+
"fname": [text_to_filename(t) for t in text], # list
104+
"waveform": waveform,
105+
"stft": stft,
106+
"log_mel_spec": fbank,
107+
"ta_kaldi_fbank": ta_kaldi_fbank,
108+
}
109+
batch.update(phonemes)
110+
return batch
111+
112+
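# Illustrative tensor shapes for batchsize=1 (the zero tensors are placeholders
# kept only to preserve the expected batch format):
#   "log_mel_spec":   [1, 1024, 64]
#   "stft":           [1, 1024, 512]
#   "waveform":       [1, 160000]
#   "ta_kaldi_fbank": [1, 1024, 128]
#   "phoneme_idx":    [1, 310]
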
def get_time():
    t = time.localtime()
    return time.strftime("%d_%m_%Y_%H_%M_%S", t)


def save_wave(waveform, savepath, name="outwav", samplerate=16000):
    if type(name) is not list:
        name = [name] * waveform.shape[0]

    for i in range(waveform.shape[0]):
        # Build "<name>_<i>.wav" for batches, "<name>.wav" otherwise,
        # without duplicating an existing ".wav" suffix.
        basename = os.path.basename(name[i])
        if ".wav" in basename:
            basename = basename.split(".")[0]
        if waveform.shape[0] > 1:
            fname = "%s_%s.wav" % (basename, i)
        else:
            fname = "%s.wav" % basename
        # Avoid a file name too long to be saved
        if len(fname) > 255:
            fname = f"{hex(hash(fname))}.wav"

        path = os.path.join(savepath, fname)
        print("Save audio to %s" % path)
        sf.write(path, waveform[i, 0], samplerate=samplerate)


def read_list(fname):
    result = []
    with open(fname, "r", encoding="utf-8") as f:
        for each in f.readlines():
            each = each.strip("\n")
            result.append(each)
    return result

def text_to_audio(
    model,
    text,
    transcription="",
    seed=42,
    ddim_steps=200,
    duration=10,
    batchsize=1,
    guidance_scale=3.5,
    n_candidate_gen_per_text=3,
    latent_t_per_second=25.6,
):
    seed_everything(int(seed))
    waveform = None

    batch = make_batch_for_text_to_audio(text, transcription=transcription, waveform=waveform, batchsize=batchsize)

    # The latent time dimension scales linearly with the requested duration.
    model.latent_t_size = int(duration * latent_t_per_second)

    waveform = model(
        batch,
        unconditional_guidance_scale=guidance_scale,
        ddim_steps=ddim_steps,
        n_gen=n_candidate_gen_per_text,
        duration=duration,
    )

    return waveform

@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for audio generation.
    Using `PdArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    text: str = field(default="", metadata={"help": "Text prompt to the model for audio generation."})
    transcription: str = field(default="", metadata={"help": "Transcription for text-to-speech."})
    text_list: str = field(default="", metadata={"help": "A UTF-8 encoded file that contains text prompts, one per line, for audio generation."})

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to use for inference.
    """

    model_name_or_path: str = field(
        default="audioldm2-full",
        metadata={"help": "Path to pretrained model or model identifier"},
    )
    save_path: str = field(
        default="./output",
        metadata={"help": "The path to save model output."},
    )
    device: str = field(
        default="gpu",
        metadata={"help": "The device for computation. Defaults to gpu."},
    )
    batchsize: int = field(
        default=1,
        metadata={"help": "How many samples to generate at the same time."},
    )
    ddim_steps: int = field(
        default=200,
        metadata={"help": "The sampling step for DDIM."},
    )
    guidance_scale: float = field(
        default=3.5,
        metadata={"help": "Guidance scale (large => better quality and relevance to the text; small => better diversity)."},
    )
    duration: float = field(
        default=10.0,
        metadata={"help": "The duration of the samples, in seconds."},
    )
    n_candidate_gen_per_text: int = field(
        default=3,
        metadata={"help": "Automatic quality control. This controls the number of candidates (e.g., generate three audios and choose the best). A larger value usually leads to better quality at the cost of heavier computation."},
    )
    seed: int = field(
        default=42,
        metadata={"help": "Changing this value (any integer) leads to a different generation result."},
    )

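# Example invocation (the prompt and model path are placeholders), exercising
# the arguments defined above:
#   python run_predict.py \
#       --text "A dog barking in the distance" \
#       --model_name_or_path "/my_model_path" \
#       --ddim_steps 200 --guidance_scale 3.5 --seed 42
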
def main():
    parser = PdArgumentParser((ModelArguments, DataArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()

    # process args
    text = data_args.text
    transcription = data_args.transcription
    text_list = data_args.text_list

    save_path = os.path.join(model_args.save_path, get_time())
    random_seed = model_args.seed
    duration = model_args.duration
    sample_rate = 16000
    latent_t_per_second = 25.6

    print("Warning: For AudioLDM2 we currently only support 10s of generation. Please use audioldm_48k or audioldm_16k_crossattn_t5 if you want a different duration.")
    duration = 10

    guidance_scale = model_args.guidance_scale
    n_candidate_gen_per_text = model_args.n_candidate_gen_per_text

    if transcription:
        if "speech" not in model_args.model_name_or_path:
            print("Warning: You chose to perform text-to-speech by providing a transcription, but did not choose a matching model name (audioldm2-speech-gigaspeech or audioldm2-speech-ljspeech).")
            print("Warning: We will use audioldm2-speech-gigaspeech by default.")
            model_args.model_name_or_path = "audioldm2-speech-gigaspeech"
        if not text:
            print("Warning: You should provide text as input to describe the speaker. Using the default (A female reporter is speaking full of emotion).")
            text = "A female reporter is speaking full of emotion"

    if text_list:
        print("Generate audio based on the text prompts in %s" % text_list)
        prompt_todo = read_list(text_list)
    else:
        prompt_todo = [text]

    # build audioldm2 model
    paddle.set_device(model_args.device)
    audioldm2 = AudioLDM2Model.from_pretrained(model_args.model_name_or_path)

    # predict
    os.makedirs(save_path, exist_ok=True)
    for text in prompt_todo:
        # A prompt line may carry an explicit output name after a "|" separator.
        if "|" in text:
            text, name = text.split("|")
        else:
            name = text[:128]

        if transcription:
            name += "-TTS-%s" % transcription

        waveform = text_to_audio(
            audioldm2,
            text,
            transcription=transcription,
            seed=random_seed,
            duration=duration,
            guidance_scale=guidance_scale,
            ddim_steps=model_args.ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            batchsize=model_args.batchsize,
            latent_t_per_second=latent_t_per_second,
        )

        save_wave(waveform, save_path, name=name, samplerate=sample_rate)


if __name__ == "__main__":
    main()

paddlemix/models/__init__.py

+2
@@ -21,3 +21,5 @@
from .qwen_vl import *
from .visualglm.configuration import *
from .visualglm.modeling import *
from .audioldm2.modeling import *
from .audioldm2.configuration import *
@@ -0,0 +1,13 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
