Commit 2a8a753

Update for the future release

1 parent 1813207 commit 2a8a753

File tree: 9 files changed, +1383 −252 lines

README.md

Lines changed: 2 additions & 29 deletions
@@ -49,6 +49,7 @@ curl -o util/lars.py https://raw.githubusercontent.com/facebookresearch/mae/efb2
 curl -o util/lr_decay.py https://raw.githubusercontent.com/facebookresearch/mae/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/lr_decay.py
 curl -o util/lr_sched.py https://raw.githubusercontent.com/facebookresearch/mae/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/lr_sched.py
 curl -o util/misc.py https://raw.githubusercontent.com/facebookresearch/mae/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/misc.py
+curl -o util/analyze_repr.py https://raw.githubusercontent.com/daisukelab/general-learning/master/SSL/analyze_repr.py
 curl -o m2d/pos_embed.py https://raw.githubusercontent.com/facebookresearch/mae/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py
 curl -o train_audio.py https://raw.githubusercontent.com/facebookresearch/mae/efb2a8062c206524e35e47d04501ed4f544c0ae8/main_pretrain.py
 curl -o train_speech.py https://raw.githubusercontent.com/facebookresearch/mae/efb2a8062c206524e35e47d04501ed4f544c0ae8/main_pretrain.py
@@ -75,19 +76,16 @@ We have a runtime model utility, RuntimeM2D, that helps you to load a pr
 ```python
 from m2d.runtime_audio import RuntimeM2D
 
-device = torch.device('cpu') # set 'cuda' if you run on a GPU
-
 # Prepare your batch of audios. This is a dummy example of three 10s waves.
 batch_audio = 2 * torch.rand((3, 10 * 16000)) - 1.0 # input range = [-1., 1.]
 batch_audio = batch_audio.to(device)
 
 # Create a model with pretrained weights.
 runtime = RuntimeM2D(weight_file='m2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth')
-runtime = runtime.to(device)
 
 # Encode raw audio into features. `encode()` will do the following automatically:
 # 1. Convert the input `batch_audio` to log-mel spectrograms (LMS).
-# 2. Normalize the batch LMS with the mean and std calculated from the batch.
+# 2. Normalize the batch LMS with the mean and std used in the pre-training.
 # 3. Encode the batch LMS to features.
 frame_level = runtime.encode(batch_audio)
 
@@ -101,31 +99,6 @@ clip_level = torch.mean(frame_level, dim=1)
 print(clip_level.shape)
 ```
 
-To get the best features, you can normalize your audio with normalization statistics computed over your entire input data and use them in your pipeline.
-
-```python
-# Calculate statistics in advance. This is an example with 10 random waves.
-means, stds = [], []
-for _ in range(10):
-    lms = runtime.to_feature(torch.rand((10 * 16000)).to(device))
-    means.append(lms.mean())
-    stds.append(lms.std())
-
-dataset_mean, dataset_std = torch.mean(torch.stack(means)), torch.mean(torch.stack(stds))
-# These can be numbers [-5.4919195, 5.0389895], for example.
-
-# The following is an example pipeline.
-
-# Convert your batch audios into LMS.
-batch_lms = runtime.to_feature(batch_audio)
-# Normalize them.
-batch_lms = (batch_lms - dataset_mean) / (dataset_std + torch.finfo().eps)
-# Encode them to frame-level features.
-frame_level = runtime.encode_lms(batch_lms)
-# Calculate clip-level features if needed.
-clip_level = torch.mean(frame_level, dim=1)
-```
-
 To get features per layer, you can add `return_layers=True`.
 
 ```python
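For reference, a minimal sketch (not part of the diff) of the updated `encode()` flow this hunk describes, including the `return_layers=True` option the README mentions. The checkpoint path is the one used throughout this commit; that `return_layers` is passed directly to `encode()` and what shape the per-layer result takes are assumptions based on the README text, not confirmed by this diff.

```python
# Sketch of the updated README pipeline, under the assumptions stated above.
import torch
from m2d.runtime_audio import RuntimeM2D

runtime = RuntimeM2D(weight_file='m2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth')
batch_audio = 2 * torch.rand((3, 10 * 16000)) - 1.0   # three 10-s dummy waves in [-1., 1.]

frame_level = runtime.encode(batch_audio)             # [batch, frames, dim]
clip_level = torch.mean(frame_level, dim=1)           # pool frames -> [batch, dim]

# Assumption: encode() accepts return_layers=True as the README states.
per_layer = runtime.encode(batch_audio, return_layers=True)
```

Note the change in step 2 of the comment: `encode()` now normalizes with the pre-training statistics instead of per-batch statistics, which is why the removed "compute your own dataset statistics" recipe is no longer needed.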

all_eval.sh

Lines changed: 18 additions & 9 deletions
@@ -1,11 +1,20 @@
 cd evar
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml cremad batch_size=16,weight_file=$1
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml gtzan batch_size=16,weight_file=$1
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml spcv2 batch_size=64,weight_file=$1
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml esc50 batch_size=64,weight_file=$1
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml us8k batch_size=64,weight_file=$1
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml vc1 batch_size=64,weight_file=$1
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml voxforge batch_size=64,weight_file=$1
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml nsynth batch_size=64,weight_file=$1
-CUDA_VISIBLE_DEVICES=0 python 2pass_lineareval.py config/m2d.yaml surge batch_size=64,weight_file=$1
+GPU=0
+
+if [[ "$1" == *'p32k-'* ]]; then
+    cfg='config/m2d_32k.yaml'
+else
+    cfg='config/m2d.yaml'
+fi
+
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg cremad batch_size=16,weight_file=$1
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg gtzan batch_size=16,weight_file=$1
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg spcv2 batch_size=64,weight_file=$1
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg esc50 batch_size=64,weight_file=$1
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg us8k batch_size=64,weight_file=$1
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg vc1 batch_size=64,weight_file=$1
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg voxforge batch_size=64,weight_file=$1
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg nsynth batch_size=64,weight_file=$1
+CUDA_VISIBLE_DEVICES=$GPU python 2pass_lineareval.py $cfg surge batch_size=64,weight_file=$1
+
 python summarize.py $1
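The script now selects the EVAR config from the checkpoint filename: weights whose name contains `p32k-` (32 kHz models) use `config/m2d_32k.yaml`, everything else uses `config/m2d.yaml`. A sketch of the same dispatch in Python, for readers driving the evaluation from Python rather than bash (the checkpoint names in the asserts are hypothetical, for illustration only):

```python
# Mirrors the bash pattern test above: [[ "$1" == *'p32k-'* ]]
def choose_config(weight_file: str) -> str:
    # 32 kHz checkpoints carry a 'p32k-' tag in their name.
    return 'config/m2d_32k.yaml' if 'p32k-' in weight_file else 'config/m2d.yaml'

# Hypothetical checkpoint names, for illustration only.
assert choose_config('m2d_vit_base-80x608p32k-230501/checkpoint-300.pth') == 'config/m2d_32k.yaml'
assert choose_config('m2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth') == 'config/m2d.yaml'
```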

audio_dataset.py

Lines changed: 125 additions & 0 deletions
@@ -152,3 +152,128 @@ def build_viz_dataset(cfg):
     norm_stats = cfg.norm_stats if 'norm_stats' in cfg else None
     ds = SpectrogramDataset(folder=cfg.data_path, files=files, crop_frames=cfg.input_size[1], tfms=None, norm_stats=norm_stats)
     return ds, files
+
+
+# Mixed dataset
+
+def log_mixup_exp(xa, xb, alpha):
+    xa = xa.exp()
+    xb = xb.exp()
+    x = alpha * xa + (1. - alpha) * xb
+    return torch.log(torch.max(x, torch.finfo(x.dtype).eps * torch.ones_like(x)))
+
+
+class MixedSpecDataset(torch.utils.data.Dataset):
+    def __init__(self, base_folder, files_main, files_bg_noise, crop_size, noise_ratio=0.0,
+                 random_crop=True, n_norm_calc=10000) -> None:
+        super().__init__()
+
+        self.ds1 = SpectrogramDataset(folder=base_folder, files=files_main, crop_frames=crop_size[1],
+                                      random_crop=random_crop, norm_stats=None,
+                                      n_norm_calc=n_norm_calc//2)
+        self.norm_stats = self.ds1.norm_stats  # for compatibility with SpectrogramDataset
+        # disable normalization scaling in ds1
+        self.norm_std = self.ds1.norm_stats[1]
+        self.ds1.norm_stats = (self.ds1.norm_stats[0], 1.0)
+
+        if noise_ratio > 0.0:
+            self.ds2 = SpectrogramDataset(folder=base_folder, files=files_bg_noise, crop_frames=crop_size[1],
+                                          random_crop=random_crop, norm_stats=None, n_norm_calc=n_norm_calc//2, repeat_short=True)
+            self.ds2.norm_stats = (self.ds2.norm_stats[0], 1.0)  # disable normalization scaling in ds2
+
+        self.noise_ratio = noise_ratio
+        self.bg_index = []
+
+    def __len__(self):
+        return len(self.ds1)
+
+    def __getitem__(self, index, fixed_noise=False):
+        # load the indexed sample
+        clean = self.ds1[index]
+        if self.noise_ratio > 0.0:
+            # load a random background-noise sample
+            noise = self.ds2[index if fixed_noise else self.get_next_bgidx()]
+            # mix
+            mixed = log_mixup_exp(noise, clean, self.noise_ratio) if self.noise_ratio < 1.0 else noise
+        else:
+            mixed = clean.clone()
+        # finish normalization. clean and noise were averaged to zero; the following scales them using the ds1 std.
+        clean = clean / self.norm_std
+        mixed = mixed / self.norm_std
+        return clean, mixed
+
+
+    def get_next_bgidx(self):
+        if len(self.bg_index) == 0:
+            self.bg_index = torch.randperm(len(self.ds2)).tolist()
+            # print(f'Refreshed the bg index list with {len(self.bg_index)} items: {self.bg_index[:5]}...')
+        return self.bg_index.pop(0)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + f'(crop_frames={self.ds1.crop_frames}, '
+        format_string += f'folder_sp={self.ds1.df.file_name.values[0].split("/")[0]}, '
+        if self.noise_ratio > 0.: format_string += f'folder_bg={self.ds2.df.file_name.values[0].split("/")[0]}, '
+        return format_string
+
+
+def inflate_files(files, desired_size):
+    if len(files) == 0:
+        return files
+    files = list(files)  # make sure `files` is a list
+    while len(files) < desired_size:
+        files = (files + files)[:desired_size]
+    return files
+
+
+def build_mixed_dataset(cfg):
+    """The following configure the training dataset details.
+    - data_path: Root folder of the training dataset.
+    - dataset: The _name_ of the training dataset, a stem name of a `.csv` training data list.
+    - norm_stats: Normalization statistics, a list of [mean, std].
+    - input_size: Input size, a list of [# of freq. bins, # of time frames].
+    """
+
+    # get files and inflate the number of files (by repeating the list) if needed
+    files_main = get_files(cfg.csv_main)
+    files_bg = get_files(cfg.csv_bg_noise) if cfg.noise_ratio > 0. else []
+    desired_min_size = 0
+    if 'min_ds_size' in cfg and cfg.min_ds_size > 0:
+        desired_min_size = cfg.min_ds_size
+    if desired_min_size > 0:
+        old_sizes = len(files_main), len(files_bg)
+        files_main, files_bg = inflate_files(files_main, desired_min_size), inflate_files(files_bg, desired_min_size)
+        print('The numbers of data files are increased from', old_sizes, 'to', (len(files_main), len(files_bg)))
+
+    ds = MixedSpecDataset(
+        base_folder=cfg.data_path, files_main=files_main,
+        files_bg_noise=files_bg,
+        crop_size=cfg.input_size,
+        noise_ratio=cfg.noise_ratio,
+        random_crop=True)
+    if 'weighted' in cfg and cfg.weighted:
+        assert desired_min_size == 0
+        ds.weight = pd.read_csv(cfg.csv_main).weight.values
+
+    val_ds = SpectrogramDataset(folder=cfg.data_path, files=get_files(cfg.csv_val), crop_frames=cfg.input_size[1], random_crop=True) \
+        if cfg.csv_val else None
+
+    return ds, val_ds
+
+
+def build_mixed_viz_dataset(cfg):
+    files = [str(f).replace(str(cfg.data_path) + '/', '') for f in sorted(Path(cfg.data_path).glob('vis_samples/*.npy'))]
+    if len(files) == 0:
+        return None, []
+    norm_stats = cfg.norm_stats if 'norm_stats' in cfg else None
+    ds = SpectrogramDataset(folder=cfg.data_path, files=files, crop_frames=cfg.input_size[1], tfms=None, norm_stats=norm_stats)
+    return ds, files
+
+
+if __name__ == '__main__':
+    # Test
+    ds = MixedSpecDataset(base_folder='data', files_main=get_files('data/files_gtzan.csv'),
+                          files_bg_noise=get_files('data/files_audioset.csv'),
+                          crop_size=[80, 608], noise_ratio=0.2, random_crop=True, n_norm_calc=10)
+    for i in range(0, 10):
+        clean, mixed = ds[i]
+        print(clean.shape, mixed.shape)
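The key piece in this new file is `log_mixup_exp`: because the dataset stores log-mel spectrograms, mixup must happen in the linear domain, so the function exponentiates both inputs, mixes them with weight `alpha`, and maps the result back to log space (clamped at machine epsilon to avoid `log(0)`). A tiny self-contained numeric check of the function copied from the diff above:

```python
import torch

def log_mixup_exp(xa, xb, alpha):
    # mix log-domain inputs in the linear (exp) domain, then return to log
    xa = xa.exp()
    xb = xb.exp()
    x = alpha * xa + (1. - alpha) * xb
    return torch.log(torch.max(x, torch.finfo(x.dtype).eps * torch.ones_like(x)))

xa = torch.tensor([0.0])   # exp -> 1.0
xb = torch.tensor([-2.0])  # exp -> ~0.135
# 0.2 * 1.0 + 0.8 * 0.135 = 0.308; log(0.308) ~= -1.177
print(log_mixup_exp(xa, xb, 0.2))  # tensor([-1.1768])
```

In `__getitem__`, `alpha` is the `noise_ratio`, so the returned `mixed` tensor is the clean sample contaminated by a background-noise sample at that ratio, while `clean` is returned unmixed for use as a target.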

examples/Example_1.ipynb

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Short example"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import warnings; warnings.simplefilter('ignore')\n",
+    "import logging\n",
+    "logging.basicConfig(level=logging.INFO)\n",
+    "import sys\n",
+    "sys.path.append('..')\n",
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:root:<All keys matched successfully>\n",
+      "INFO:root:Model input size: [80, 608]\n",
+      "INFO:root:Using weights: m2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth\n",
+      "INFO:root:Feature dimension: 3840\n",
+      "INFO:root:Norm stats: -7.1, 4.2\n",
+      "INFO:root:Runtime MelSpectrogram(16000, 400, 400, 160, 80, 50, 8000):\n",
+      "INFO:root:MelSpectrogram(\n",
+      " Mel filter banks size = (80, 201), trainable_mel=False\n",
+      " (stft): STFT(n_fft=400, Fourier Kernel size=(201, 1, 400), iSTFT=False, trainable=False)\n",
+      ")\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " using 150 parameters, while dropped 250 out of 400 parameters from m2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth\n",
+      " (dropped: ['mask_token', 'decoder_pos_embed', 'decoder_embed.weight', 'decoder_embed.bias', 'decoder_blocks.0.norm1.weight'] ...)\n",
+      "<All keys matched successfully>\n"
+     ]
+    }
+   ],
+   "source": [
+    "from portable_m2d import PortableM2D\n",
+    "weight = 'm2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth'\n",
+    "model = PortableM2D(weight_file=weight)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([1, 63, 3840])\n"
+     ]
+    }
+   ],
+   "source": [
+    "# A single sample of random waveform\n",
+    "wav = torch.rand(1, 16000 * 10)\n",
+    "\n",
+    "# Encode with M2D\n",
+    "with torch.no_grad():\n",
+    "    embeddings = model(wav)\n",
+    "\n",
+    "# The output embeddings have a shape of [Batch, Frame, Dimension]\n",
+    "print(embeddings.shape) # --> torch.Size([1, 63, 3840])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ar",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
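The notebook stops at frame-level embeddings of shape [Batch, Frame, Dimension]. If a single clip-level vector is needed, the README's recipe of averaging over frames applies to the `PortableM2D` output as well; a sketch (the mean pooling follows the README, not something this notebook itself shows):

```python
import torch
from portable_m2d import PortableM2D

model = PortableM2D(weight_file='m2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth')
wav = torch.rand(1, 16000 * 10)      # one 10-s random waveform

with torch.no_grad():
    frame_level = model(wav)         # torch.Size([1, 63, 3840])

clip_level = frame_level.mean(dim=1) # pool over frames
print(clip_level.shape)              # torch.Size([1, 3840])
```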
