
Commit 84f751f

yt605155624 and 0x45f authored
[TTS]vits dygraph to static (#2883)
Co-authored-by: 0x45f <wangzhen45@baidu.com>
1 parent 11bc392 commit 84f751f

File tree

5 files changed: +80 -31 lines

examples/csmsc/vits/local/synthesize_e2e.sh

+2 -1
@@ -18,5 +18,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --phones_dict=dump/phone_id_map.txt \
         --output_dir=${train_output_path}/test_e2e \
         --text=${BIN_DIR}/../sentences.txt \
-        --add-blank=${add_blank}
+        --add-blank=${add_blank} #\
+        # --inference_dir=${train_output_path}/inference
 fi

paddlespeech/t2s/exps/syn_utils.py

+15 -5
@@ -452,9 +452,19 @@ def am_to_static(am_inference,
     elif am_name == 'tacotron2':
         am_inference = jit.to_static(
             am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
-
-    paddle.jit.save(am_inference, os.path.join(inference_dir, am))
-    am_inference = paddle.jit.load(os.path.join(inference_dir, am))
+    elif am_name == 'vits':
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
+            am_inference = jit.to_static(
+                am_inference,
+                input_spec=[
+                    InputSpec([-1], dtype=paddle.int64),
+                    InputSpec([1], dtype=paddle.int64),
+                ])
+        else:
+            am_inference = jit.to_static(
+                am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+    jit.save(am_inference, os.path.join(inference_dir, am))
+    am_inference = jit.load(os.path.join(inference_dir, am))
     return am_inference
@@ -465,8 +475,8 @@ def voc_to_static(voc_inference,
         voc_inference, input_spec=[
             InputSpec([-1, 80], dtype=paddle.float32),
         ])
-    paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
-    voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
+    jit.save(voc_inference, os.path.join(inference_dir, voc))
+    voc_inference = jit.load(os.path.join(inference_dir, voc))
     return voc_inference
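For context, the pattern this file now applies to VITS is the standard Paddle export round-trip: trace the layer's forward against an InputSpec, serialize the static graph, and reload it. A minimal self-contained sketch of that round-trip (ToyAM and the output path are hypothetical stand-ins, not classes from this repo):

import os

import paddle
from paddle.static import InputSpec


class ToyAM(paddle.nn.Layer):
    """Hypothetical stand-in for an acoustic model's inference wrapper."""

    def __init__(self):
        super().__init__()
        self.emb = paddle.nn.Embedding(100, 80)

    def forward(self, text):
        # (T,) int64 phone ids -> (T, 80) dummy "features"
        return self.emb(text)


am_inference = ToyAM()
# InputSpec([-1], dtype=paddle.int64) declares a variable-length 1-D input,
# so one traced program accepts any sentence length at inference time
am_inference = paddle.jit.to_static(
    am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(am_inference, os.path.join("inference", "toy_am"))
# the reloaded TranslatedLayer is called like a normal Layer
am_inference = paddle.jit.load(os.path.join("inference", "toy_am"))
feats = am_inference(paddle.to_tensor([1, 2, 3], dtype=paddle.int64))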

paddlespeech/t2s/exps/vits/synthesize_e2e.py

+24 -4
@@ -20,14 +20,15 @@
 from timer import timer
 from yacs.config import CfgNode
 
+from paddlespeech.t2s.exps.syn_utils import am_to_static
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.models.vits import VITS
+from paddlespeech.t2s.models.vits import VITSInference
 from paddlespeech.t2s.utils import str2bool
 
 
 def evaluate(args):
-
     # Init body.
     with open(args.config) as f:
         config = CfgNode(yaml.safe_load(f))
@@ -41,6 +42,9 @@ def evaluate(args):
 
     # frontend
     frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
+    # acoustic model
+    am_name = args.am[:args.am.rindex('_')]
+    am_dataset = args.am[args.am.rindex('_') + 1:]
 
     spk_num = None
     if args.speaker_dict is not None:
@@ -64,6 +68,15 @@ def evaluate(args):
     vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
     vits.eval()
 
+    vits_inference = VITSInference(vits)
+    # whether to convert dygraph to static graph
+    if args.inference_dir:
+        vits_inference = am_to_static(
+            am_inference=vits_inference,
+            am=args.am,
+            inference_dir=args.inference_dir,
+            speaker_dict=args.speaker_dict)
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     merge_sentences = False
@@ -90,10 +103,12 @@ def evaluate(args):
             for i in range(len(phone_ids)):
                 part_phone_ids = phone_ids[i]
                 spk_id = None
-                if spk_num is not None:
+                if am_dataset in {"aishell3",
+                                  "vctk"} and spk_num is not None:
                     spk_id = paddle.to_tensor(args.spk_id)
-                out = vits.inference(text=part_phone_ids, sids=spk_id)
-                wav = out["wav"]
+                    wav = vits_inference(part_phone_ids, spk_id)
+                else:
+                    wav = vits_inference(part_phone_ids)
                 if flags == 0:
                     wav_all = wav
                     flags = 1
@@ -155,6 +170,11 @@ def parse_args():
         type=str2bool,
         default=True,
         help="whether to add blank between phones")
+    parser.add_argument(
+        '--am',
+        type=str,
+        default='vits_csmsc',
+        help='Choose acoustic model type of tts task.')
 
     args = parser.parse_args()
     return args
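Once --inference_dir is passed (see the commented-out line added to synthesize_e2e.sh above), am_to_static saves the traced model, and the exported graph can later be reloaded without any of the model-definition code. A minimal usage sketch, assuming an export directory of inference/ and the default --am=vits_csmsc (the phone ids below are dummies):

import paddle

# load the static graph written by jit.save during synthesis
vits_inference = paddle.jit.load("inference/vits_csmsc")
part_phone_ids = paddle.to_tensor([12, 5, 33, 7], dtype=paddle.int64)
wav = vits_inference(part_phone_ids)  # same call signature as VITSInference.forward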

paddlespeech/t2s/models/vits/transform.py

+28 -21
@@ -35,9 +35,10 @@ def piecewise_rational_quadratic_transform(
         inverse=False,
         tails=None,
         tail_bound=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
     if tails is None:
         spline_fn = rational_quadratic_spline
         spline_kwargs = {}
@@ -74,23 +75,27 @@ def unconstrained_rational_quadratic_spline(
         inverse=False,
         tails="linear",
         tail_bound=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
     inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
     outside_interval_mask = ~inside_interval_mask
-
-    outputs = paddle.zeros(paddle.shape(inputs))
-    logabsdet = paddle.zeros(paddle.shape(inputs))
+    # for dygraph to static
+    # calling zeros on paddle.shape(x) yields a var whose shape is all -1;
+    # using x.shape instead preserves the statically known dimensions
+    outputs = paddle.zeros(inputs.shape)
+    logabsdet = paddle.zeros(inputs.shape)
     if tails == "linear":
         unnormalized_derivatives = F.pad(
             unnormalized_derivatives,
             pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1])
         constant = np.log(np.exp(1 - min_derivative) - 1)
         unnormalized_derivatives[..., 0] = constant
         unnormalized_derivatives[..., -1] = constant
-
-        outputs[outside_interval_mask] = inputs[outside_interval_mask]
+        # for dygraph to static
+        tmp = inputs[outside_interval_mask]
+        outputs[outside_interval_mask] = tmp
         logabsdet[outside_interval_mask] = 0
     else:
         raise RuntimeError("{} tails are not implemented.".format(tails))
@@ -130,18 +135,20 @@ def rational_quadratic_spline(
         right=1.0,
         bottom=0.0,
         top=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
-    if paddle.min(inputs) < left or paddle.max(inputs) > right:
-        raise ValueError("Input to a transform is not within its domain")
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
+    # for dygraph to static
+    # if paddle.min(inputs) < left or paddle.max(inputs) > right:
+    #     raise ValueError("Input to a transform is not within its domain")
 
     num_bins = unnormalized_widths.shape[-1]
-
-    if min_bin_width * num_bins > 1.0:
-        raise ValueError("Minimal bin width too large for the number of bins")
-    if min_bin_height * num_bins > 1.0:
-        raise ValueError("Minimal bin height too large for the number of bins")
+    # for dygraph to static
+    # if min_bin_width * num_bins > 1.0:
+    #     raise ValueError("Minimal bin width too large for the number of bins")
+    # if min_bin_height * num_bins > 1.0:
+    #     raise ValueError("Minimal bin height too large for the number of bins")
 
     widths = F.softmax(unnormalized_widths, axis=-1)
     widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
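The (translated) comment in the hunk above is the key point of this file's changes: under to_static tracing, paddle.shape(x) is evaluated as a runtime tensor, so a buffer built from it carries no static shape information (every dim is -1), while x.shape lets the converter keep the dimensions known at trace time, which the boolean-mask assignments need. A small illustrative sketch of the same pattern (the function and shapes here are made up for the example):

import paddle
from paddle.static import InputSpec


def make_buffer(x):
    # x.shape keeps the statically known trailing dim (80) in the
    # traced program; paddle.shape(x) would leave every dim unknown
    return paddle.zeros(x.shape)


static_fn = paddle.jit.to_static(
    make_buffer, input_spec=[InputSpec([-1, 80], dtype=paddle.float32)])
out = static_fn(paddle.randn([4, 80]))
print(out.shape)  # [4, 80]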

paddlespeech/t2s/models/vits/vits.py

+11
@@ -532,3 +532,14 @@ def _reset_parameters(module):
                 module.weight[module._padding_idx] = 0
 
         self.apply(_reset_parameters)
+
+class VITSInference(nn.Layer):
+    def __init__(self, model):
+        super().__init__()
+        self.acoustic_model = model
+
+    def forward(self, text, sids=None):
+        out = self.acoustic_model.inference(
+            text, sids=sids)
+        wav = out['wav']
+        return wav
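This wrapper exists because jit.to_static traces a Layer's forward, while VITS synthesis lives in the inference method and returns a dict; VITSInference re-exposes inference as forward and returns only the waveform, giving the tracer a plain tensor-in/tensor-out interface. Typical use, mirroring synthesize_e2e.py above (assumes vits is a checkpoint-loaded VITS instance and part_phone_ids a 1-D int64 tensor):

vits_inference = VITSInference(vits)
vits_inference.eval()
wav = vits_inference(part_phone_ids)            # single speaker
# wav = vits_inference(part_phone_ids, spk_id)  # aishell3 / vctk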
