diff --git a/data.py b/data.py
index d9b24f3..090580f 100644
--- a/data.py
+++ b/data.py
@@ -8,11 +8,6 @@ import random
 from librosa import resample
 from librosa.effects import split
-from scipy.stats import special_ortho_group
-from scipy.io import wavfile
-import scipy
-import pandas as pd
-import noisereduce as nr
 from tqdm import tqdm
 
diff --git a/main.py b/main.py
index 03c7acf..691473f 100644
--- a/main.py
+++ b/main.py
@@ -4,12 +4,11 @@ from wav2pos import wav2pos
 import torch.optim as optim
 from data import LibriSpeechLocations, DelaySimulatorDataset, remove_silence
-import scipy
 import torch.utils.data as data_utils
 import importlib
 import argparse
 import os
-from timm.optim import optim_factory
+from timm.optim import param_groups_weight_decay
 from tqdm import tqdm
 from warmup_scheduler import GradualWarmupScheduler
 import datetime
@@ -115,6 +114,15 @@ def log_print(*kwargs):
 mic_locs = np.stack([mic1, mic2, mic3, mic4, mic5, mic6]).transpose(1, 2, 0)
 
+# # For more than 6 microphones (`num_mics` in `cfg.py`), you will need to add
+# # more microphone positions. For example, you can add:
+# mic7 = np.random.uniform(
+#     [0, 0, room_len_z/2], [room_len_x, room_len_y, room_len_z/2], coord_size)
+# mic8 = np.random.uniform(
+#     [room_len_x/2, 0, 0], [room_len_x/2, room_len_y, room_len_z], coord_size)
+
+# # And update `mic_locs`:
+# mic_locs = np.stack([mic1, mic2, mic3, mic4, mic5, mic6, mic7, mic8]).transpose(1, 2, 0)
 
 log_print("Data prep started...")
 data_set = LibriSpeechLocations(source_locs, mic_locs, split="test-clean", random_source_pos=True,
@@ -223,7 +231,7 @@ def log_print(*kwargs):
 # Create optimizer
 no_weight_decay_list = {'norm', 'enc_audio_modality', 'enc_loc_modality',
                         'dec_audio_modality', 'dec_loc_modality', 'pos_embed', 'decoder_pos_embed', 'mask_token'}
-param_groups = optim_factory.param_groups_weight_decay(
+param_groups = param_groups_weight_decay(
     model, cfg.wd, no_weight_decay_list)
 optimizer = optim.AdamW(param_groups, lr=cfg.lr)
 scheduler = optim.lr_scheduler.CosineAnnealingLR(
@@ -336,12 +344,12 @@ def log_print(*kwargs):
     scheduler.step()
 
     outstr = 'Train epoch, %d, audio loss, %.6f, loc loss, %.6f, tdoa loss, %.6f, loc MAE [cm], %.6f, loc acc, %.6f, lr, %.6f' % (e,
-                                                                                  curr_loss_audio,
-                                                                                  curr_loss_locs,
-                                                                                  curr_loss_tdoas,
-                                                                                  curr_mae * 100.0,
-                                                                                  curr_acc,
-                                                                                  optimizer.param_groups[0]['lr'])
+        curr_loss_audio,
+        curr_loss_locs,
+        curr_loss_tdoas,
+        curr_mae * 100.0,
+        curr_acc,
+        optimizer.param_groups[0]['lr'])
 
     log_string(outstr+'\n')
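Note on the timm import change above: `param_groups_weight_decay` used to live in `timm.optim.optim_factory` and is exposed directly from `timm.optim` in newer timm releases, which is why the `optim_factory.param_groups_weight_decay(...)` call further down is updated as well. If the code has to run against both old and new timm versions, a guarded import is one option (a sketch under that assumption, not part of the patch):

    try:
        from timm.optim import param_groups_weight_decay  # newer timm
    except ImportError:
        # older timm releases only expose the helper via the optim_factory module
        from timm.optim.optim_factory import param_groups_weight_decay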
diff --git a/ngcc/dnn_models.py b/ngcc/dnn_models.py
index 1879cb8..2ace327 100644
--- a/ngcc/dnn_models.py
+++ b/ngcc/dnn_models.py
@@ -1,8 +1,8 @@
-'''
+"""
 This file contains the implementation of SincNet, by Mirco Ravanelli and Yoshua Bengio
 Circular padding has been added before each convolution.
 Source: https://github.com/mravanelli/SincNet
-'''
+"""
 
 import numpy as np
 import torch
@@ -18,13 +18,18 @@ def flip(x, dim):
     dim = x.dim() + dim if dim < 0 else dim
     x = x.contiguous()
     x = x.view(-1, *xsize[dim:])
-    x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1,
-                                                                 -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :]
+    x = x.view(x.size(0), x.size(1), -1)[
+        :,
+        getattr(
+            torch.arange(x.size(1) - 1, -1, -1), ("cpu", "cuda")[x.is_cuda]
+        )().long(),
+        :,
+    ]
     return x.view(xsize)
 
 
 def sinc(band, t_right):
-    y_right = torch.sin(2*math.pi*band*t_right)/(2*math.pi*band*t_right)
+    y_right = torch.sin(2 * math.pi * band * t_right) / (2 * math.pi * band * t_right)
     y_left = flip(y_right, 0)
 
     y = torch.cat([y_left, Variable(torch.ones(1)).cuda(), y_right])
@@ -62,16 +67,29 @@ def to_mel(hz):
     def to_hz(mel):
         return 700 * (10 ** (mel / 2595) - 1)
 
-    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
-                 stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50):
-
+    def __init__(
+        self,
+        out_channels,
+        kernel_size,
+        sample_rate=16000,
+        in_channels=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=False,
+        groups=1,
+        min_low_hz=50,
+        min_band_hz=50,
+    ):
        super(SincConv_fast, self).__init__()

        if in_channels != 1:
-            #msg = (f'SincConv only support one input channel '
+            # msg = (f'SincConv only support one input channel '
            #       f'(here, in_channels = {in_channels:d}).')
-            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
-                in_channels)
+            msg = (
+                "SincConv only supports one input channel (here, in_channels = %i)"
+                % (in_channels)
+            )
            raise ValueError(msg)

        self.out_channels = out_channels
@@ -79,16 +97,16 @@ def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
 
         # Forcing the filters to be odd (i.e, perfectly symmetrics)
         if kernel_size % 2 == 0:
-            self.kernel_size = self.kernel_size+1
+            self.kernel_size = self.kernel_size + 1
 
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
 
         if bias:
-            raise ValueError('SincConv does not support bias.')
+            raise ValueError("SincConv does not support bias.")
         if groups > 1:
-            raise ValueError('SincConv does not support groups.')
+            raise ValueError("SincConv does not support groups.")
 
         self.sample_rate = sample_rate
         self.min_low_hz = min_low_hz
@@ -98,9 +116,9 @@ def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
         low_hz = 30
         high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
 
-        mel = np.linspace(self.to_mel(low_hz),
-                          self.to_mel(high_hz),
-                          self.out_channels + 1)
+        mel = np.linspace(
+            self.to_mel(low_hz), self.to_mel(high_hz), self.out_channels + 1
+        )
         hz = self.to_hz(mel)
 
         # filter lower frequency (out_channels, 1)
@@ -110,16 +128,17 @@ def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
         self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
 
         # Hamming window
-        #self.window_ = torch.hamming_window(self.kernel_size)
+        # self.window_ = torch.hamming_window(self.kernel_size)
 
         # computing only half of the window
-        n_lin = torch.linspace(0, (self.kernel_size/2)-1,
-                               steps=int((self.kernel_size/2)))
-        self.window_ = 0.54-0.46*torch.cos(2*math.pi*n_lin/self.kernel_size)
+        n_lin = torch.linspace(
+            0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
+        )
+        self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)
 
         # (1, kernel_size/2)
         n = (self.kernel_size - 1) / 2.0
         # Due to symmetry, I only need half of the time axes
-        self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate
+        self.n_ = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
 
     def forward(self, waveforms):
         """
@@ -139,91 +158,102 @@ def forward(self, waveforms):
 
         low = self.min_low_hz + torch.abs(self.low_hz_)
 
-        high = torch.clamp(low + self.min_band_hz
-                           + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate/2)
-        band = (high-low)[:, 0]
+        high = torch.clamp(
+            low + self.min_band_hz + torch.abs(self.band_hz_),
+            self.min_low_hz,
+            self.sample_rate / 2,
+        )
+        band = (high - low)[:, 0]
 
         f_times_t_low = torch.matmul(low, self.n_)
         f_times_t_high = torch.matmul(high, self.n_)
 
         # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations.
-        band_pass_left = ((torch.sin(f_times_t_high)
-                           - torch.sin(f_times_t_low))/(self.n_/2))*self.window_
-        band_pass_center = 2*band.view(-1, 1)
+        band_pass_left = (
+            (torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self.n_ / 2)
+        ) * self.window_
+        band_pass_center = 2 * band.view(-1, 1)
         band_pass_right = torch.flip(band_pass_left, dims=[1])
 
         band_pass = torch.cat(
-            [band_pass_left, band_pass_center, band_pass_right], dim=1)
+            [band_pass_left, band_pass_center, band_pass_right], dim=1
+        )
 
-        band_pass = band_pass / (2*band[:, None])
+        band_pass = band_pass / (2 * band[:, None])
 
-        self.filters = (band_pass).view(
-            self.out_channels, 1, self.kernel_size)
+        self.filters = (band_pass).view(self.out_channels, 1, self.kernel_size)
 
-        return F.conv1d(waveforms, self.filters, stride=self.stride,
-                        padding=self.padding, dilation=self.dilation,
-                        bias=None, groups=1)
+        return F.conv1d(
+            waveforms,
+            self.filters,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            bias=None,
+            groups=1,
+        )
 
 
 class sinc_conv(nn.Module):
-
     def __init__(self, N_filt, Filt_dim, fs):
         super(sinc_conv, self).__init__()
 
         # Mel Initialization of the filterbanks
         low_freq_mel = 80
-        high_freq_mel = (2595 * np.log10(1 + (fs / 2) / 700)
-                         )  # Convert Hz to Mel
+        high_freq_mel = 2595 * np.log10(1 + (fs / 2) / 700)  # Convert Hz to Mel
         # Equally spaced in Mel scale
         mel_points = np.linspace(low_freq_mel, high_freq_mel, N_filt)
-        f_cos = (700 * (10**(mel_points / 2595) - 1))  # Convert Mel to Hz
+        f_cos = 700 * (10 ** (mel_points / 2595) - 1)  # Convert Mel to Hz
         b1 = np.roll(f_cos, 1)
         b2 = np.roll(f_cos, -1)
         b1[0] = 30
-        b2[-1] = (fs/2)-100
+        b2[-1] = (fs / 2) - 100
 
-        self.freq_scale = fs*1.0
-        self.filt_b1 = nn.Parameter(torch.from_numpy(b1/self.freq_scale))
-        self.filt_band = nn.Parameter(
-            torch.from_numpy((b2-b1)/self.freq_scale))
+        self.freq_scale = fs * 1.0
+        self.filt_b1 = nn.Parameter(torch.from_numpy(b1 / self.freq_scale))
+        self.filt_band = nn.Parameter(torch.from_numpy((b2 - b1) / self.freq_scale))
 
         self.N_filt = N_filt
         self.Filt_dim = Filt_dim
         self.fs = fs
 
     def forward(self, x):
-
         filters = Variable(torch.zeros((self.N_filt, self.Filt_dim))).cuda()
         N = self.Filt_dim
-        t_right = Variable(torch.linspace(
-            1, (N-1)/2, steps=int((N-1)/2))/self.fs).cuda()
+        t_right = Variable(
+            torch.linspace(1, (N - 1) / 2, steps=int((N - 1) / 2)) / self.fs
+        ).cuda()
 
         min_freq = 50.0
         min_band = 50.0
 
-        filt_beg_freq = torch.abs(self.filt_b1)+min_freq/self.freq_scale
-        filt_end_freq = filt_beg_freq + \
-            (torch.abs(self.filt_band)+min_band/self.freq_scale)
+        filt_beg_freq = torch.abs(self.filt_b1) + min_freq / self.freq_scale
+        filt_end_freq = filt_beg_freq + (
+            torch.abs(self.filt_band) + min_band / self.freq_scale
+        )
 
         n = torch.linspace(0, N, steps=N)
 
         # Filter window (hamming)
-        window = 0.54-0.46*torch.cos(2*math.pi*n/N)
+        window = 0.54 - 0.46 * torch.cos(2 * math.pi * n / N)
         window = Variable(window.float().cuda())
 
         for i in range(self.N_filt):
-
-            low_pass1 = 2 * \
-                filt_beg_freq[i].float()*sinc(filt_beg_freq[i].float()
-                                              * self.freq_scale, t_right)
-            low_pass2 = 2 * \
-                filt_end_freq[i].float()*sinc(filt_end_freq[i].float()
-                                              * self.freq_scale, t_right)
-            band_pass = (low_pass2-low_pass1)
-
-            band_pass = band_pass/torch.max(band_pass)
-
-            filters[i, :] = band_pass.cuda()*window
+            low_pass1 = (
+                2
+                * filt_beg_freq[i].float()
+                * sinc(filt_beg_freq[i].float() * self.freq_scale, t_right)
+            )
+            low_pass2 = (
+                2
+                * filt_end_freq[i].float()
+                * sinc(filt_end_freq[i].float() * self.freq_scale, t_right)
+            )
+            band_pass = low_pass2 - low_pass1
+
+            band_pass = band_pass / torch.max(band_pass)
+
+            filters[i, :] = band_pass.cuda() * window
 
         out = F.conv1d(x, filters.view(self.N_filt, 1, self.Filt_dim))
 
@@ -231,31 +261,29 @@ def forward(self, x):
 
 
 def act_fun(act_type):
+    if act_type == "relu":
+        return nn.ReLU()
 
-    if act_type == "relu":
-        return nn.ReLU()
+    if act_type == "tanh":
+        return nn.Tanh()
 
-    if act_type == "tanh":
-        return nn.Tanh()
+    if act_type == "sigmoid":
+        return nn.Sigmoid()
 
-    if act_type == "sigmoid":
-        return nn.Sigmoid()
+    if act_type == "leaky_relu":
+        return nn.LeakyReLU(0.2)
 
-    if act_type == "leaky_relu":
-        return nn.LeakyReLU(0.2)
+    if act_type == "elu":
+        return nn.ELU()
 
-    if act_type == "elu":
-        return nn.ELU()
+    if act_type == "softmax":
+        return nn.LogSoftmax(dim=1)
 
-    if act_type == "softmax":
-        return nn.LogSoftmax(dim=1)
-
-    if act_type == "linear":
-        return nn.LeakyReLU(1)  # initializzed like this, but not used in forward!
+    if act_type == "linear":
+        return nn.LeakyReLU(1)  # initialized like this, but not used in forward!
 
 
 class LayerNorm(nn.Module):
-
     def __init__(self, features, eps=1e-6):
         super(LayerNorm, self).__init__()
         self.gamma = nn.Parameter(torch.ones(features))
@@ -272,14 +300,14 @@ class MLP(nn.Module):
     def __init__(self, options):
         super(MLP, self).__init__()
 
-        self.input_dim = int(options['input_dim'])
-        self.fc_lay = options['fc_lay']
-        self.fc_drop = options['fc_drop']
-        self.fc_use_batchnorm = options['fc_use_batchnorm']
-        self.fc_use_laynorm = options['fc_use_laynorm']
-        self.fc_use_laynorm_inp = options['fc_use_laynorm_inp']
-        self.fc_use_batchnorm_inp = options['fc_use_batchnorm_inp']
-        self.fc_act = options['fc_act']
+        self.input_dim = int(options["input_dim"])
+        self.fc_lay = options["fc_lay"]
+        self.fc_drop = options["fc_drop"]
+        self.fc_use_batchnorm = options["fc_use_batchnorm"]
+        self.fc_use_laynorm = options["fc_use_laynorm"]
+        self.fc_use_laynorm_inp = options["fc_use_laynorm_inp"]
+        self.fc_use_batchnorm_inp = options["fc_use_batchnorm_inp"]
+        self.fc_act = options["fc_act"]
 
         self.wx = nn.ModuleList([])
         self.bn = nn.ModuleList([])
@@ -289,11 +317,11 @@ def __init__(self, options):
 
         # input layer normalization
         if self.fc_use_laynorm_inp:
-            self.ln0 = LayerNorm(self.input_dim)
+            self.ln0 = LayerNorm(self.input_dim)
 
         # input batch normalization
         if self.fc_use_batchnorm_inp:
-            self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05)
+            self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05)
 
         self.N_fc_lay = len(self.fc_lay)
 
@@ -302,173 +330,194 @@ def __init__(self, options):
 
         # Initialization of hidden layers
         for i in range(self.N_fc_lay):
+            # dropout
+            self.drop.append(nn.Dropout(p=self.fc_drop[i]))
 
-            # dropout
-            self.drop.append(nn.Dropout(p=self.fc_drop[i]))
+            # activation
+            self.act.append(act_fun(self.fc_act[i]))
 
-            # activation
-            self.act.append(act_fun(self.fc_act[i]))
+            add_bias = True
 
-            add_bias = True
+            # layer norm initialization
+            self.ln.append(LayerNorm(self.fc_lay[i]))
+            self.bn.append(nn.BatchNorm1d(self.fc_lay[i], momentum=0.05))
 
-            # layer norm initialization
-            self.ln.append(LayerNorm(self.fc_lay[i]))
-            self.bn.append(nn.BatchNorm1d(self.fc_lay[i], momentum=0.05))
+            if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]:
+                add_bias = False
 
-            if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]:
-                add_bias = False
+            # Linear operations
+            self.wx.append(nn.Linear(current_input, self.fc_lay[i], bias=add_bias))
 
-            # Linear operations
-            self.wx.append(
-                nn.Linear(current_input, self.fc_lay[i], bias=add_bias))
+            # weight initialization
+            self.wx[i].weight = torch.nn.Parameter(
+                torch.Tensor(self.fc_lay[i], current_input).uniform_(
+                    -np.sqrt(0.01 / (current_input + self.fc_lay[i])),
+                    np.sqrt(0.01 / (current_input + self.fc_lay[i])),
+                )
+            )
+            self.wx[i].bias = torch.nn.Parameter(torch.zeros(self.fc_lay[i]))
 
-            # weight initialization
-            self.wx[i].weight = torch.nn.Parameter(torch.Tensor(self.fc_lay[i], current_input).uniform_(
-                -np.sqrt(0.01/(current_input+self.fc_lay[i])), np.sqrt(0.01/(current_input+self.fc_lay[i]))))
-            self.wx[i].bias = torch.nn.Parameter(torch.zeros(self.fc_lay[i]))
-
-            current_input = self.fc_lay[i]
+            current_input = self.fc_lay[i]
 
     def forward(self, x):
+        # Applying Layer/Batch Norm
+        if bool(self.fc_use_laynorm_inp):
+            x = self.ln0((x))
 
-        # Applying Layer/Batch Norm
-        if bool(self.fc_use_laynorm_inp):
-            x = self.ln0((x))
-
-        if bool(self.fc_use_batchnorm_inp):
-            x = self.bn0((x))
-
-        for i in range(self.N_fc_lay):
-
-            if self.fc_act[i] != 'linear':
+        if bool(self.fc_use_batchnorm_inp):
+            x = self.bn0((x))
 
-                if self.fc_use_laynorm[i]:
-                    x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x))))
+        for i in range(self.N_fc_lay):
+            if self.fc_act[i] != "linear":
+                if self.fc_use_laynorm[i]:
+                    x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x))))
 
-                if self.fc_use_batchnorm[i]:
-                    x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x))))
+                if self.fc_use_batchnorm[i]:
+                    x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x))))
 
-                if self.fc_use_batchnorm[i] == False and self.fc_use_laynorm[i] == False:
-                    x = self.drop[i](self.act[i](self.wx[i](x)))
+                if not self.fc_use_batchnorm[i] and not self.fc_use_laynorm[i]:
+                    x = self.drop[i](self.act[i](self.wx[i](x)))
 
-            else:
-                if self.fc_use_laynorm[i]:
-                    x = self.drop[i](self.ln[i](self.wx[i](x)))
+            else:
+                if self.fc_use_laynorm[i]:
+                    x = self.drop[i](self.ln[i](self.wx[i](x)))
 
-                if self.fc_use_batchnorm[i]:
-                    x = self.drop[i](self.bn[i](self.wx[i](x)))
+                if self.fc_use_batchnorm[i]:
+                    x = self.drop[i](self.bn[i](self.wx[i](x)))
 
-                if self.fc_use_batchnorm[i] == False and self.fc_use_laynorm[i] == False:
-                    x = self.drop[i](self.wx[i](x))
+                if not self.fc_use_batchnorm[i] and not self.fc_use_laynorm[i]:
+                    x = self.drop[i](self.wx[i](x))
 
-        return x
+        return x
 
 
 class SincNet(nn.Module):
-
     def __init__(self, options):
-        super(SincNet, self).__init__()
-
-        self.cnn_N_filt = options['cnn_N_filt']
-        self.cnn_len_filt = options['cnn_len_filt']
-        self.cnn_max_pool_len = options['cnn_max_pool_len']
+        super(SincNet, self).__init__()
 
-        self.cnn_act = options['cnn_act']
-        self.cnn_drop = options['cnn_drop']
+        self.cnn_N_filt = options["cnn_N_filt"]
+        self.cnn_len_filt = options["cnn_len_filt"]
+        self.cnn_max_pool_len = options["cnn_max_pool_len"]
 
-        self.cnn_use_laynorm = options['cnn_use_laynorm']
-        self.cnn_use_batchnorm = options['cnn_use_batchnorm']
-        self.cnn_use_laynorm_inp = options['cnn_use_laynorm_inp']
-        self.cnn_use_batchnorm_inp = options['cnn_use_batchnorm_inp']
+        self.cnn_act = options["cnn_act"]
+        self.cnn_drop = options["cnn_drop"]
 
-        self.input_dim = int(options['input_dim'])
+        self.cnn_use_laynorm = options["cnn_use_laynorm"]
+        self.cnn_use_batchnorm = options["cnn_use_batchnorm"]
+        self.cnn_use_laynorm_inp = options["cnn_use_laynorm_inp"]
+        self.cnn_use_batchnorm_inp = options["cnn_use_batchnorm_inp"]
 
-        self.fs = options['fs']
+        self.input_dim = int(options["input_dim"])
 
-        self.N_cnn_lay = len(options['cnn_N_filt'])
-        self.conv = nn.ModuleList([])
-        self.bn = nn.ModuleList([])
-        self.ln = nn.ModuleList([])
-        self.act = nn.ModuleList([])
-        self.drop = nn.ModuleList([])
-        self.use_sinc = options['use_sinc']
+        self.fs = options["fs"]
 
-        if self.cnn_use_laynorm_inp:
-            self.ln0 = LayerNorm(self.input_dim)
+        self.N_cnn_lay = len(options["cnn_N_filt"])
+        self.conv = nn.ModuleList([])
+        self.bn = nn.ModuleList([])
+        self.ln = nn.ModuleList([])
+        self.act = nn.ModuleList([])
+        self.drop = nn.ModuleList([])
+        self.use_sinc = options["use_sinc"]
 
-        if self.cnn_use_batchnorm_inp:
-            self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05)
+        if self.cnn_use_laynorm_inp:
+            self.ln0 = LayerNorm(self.input_dim)
 
-        current_input = self.input_dim
+        if self.cnn_use_batchnorm_inp:
+            self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05)
 
-        for i in range(self.N_cnn_lay):
+        current_input = self.input_dim
 
-            N_filt = int(self.cnn_N_filt[i])
-            len_filt = int(self.cnn_len_filt[i])
+        for i in range(self.N_cnn_lay):
+            N_filt = int(self.cnn_N_filt[i])
+            # len_filt = int(self.cnn_len_filt[i])
 
-            # dropout
-            self.drop.append(nn.Dropout(p=self.cnn_drop[i]))
+            # dropout
+            self.drop.append(nn.Dropout(p=self.cnn_drop[i]))
 
-            # activation
-            self.act.append(act_fun(self.cnn_act[i]))
+            # activation
+            self.act.append(act_fun(self.cnn_act[i]))
 
-            # layer norm initialization
-            #self.ln.append(LayerNorm([N_filt,int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i])]))
+            # layer norm initialization
+            # self.ln.append(LayerNorm([N_filt,int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i])]))
 
-            self.bn.append(nn.BatchNorm1d(N_filt, momentum=0.05))
+            self.bn.append(nn.BatchNorm1d(N_filt, momentum=0.05))
 
-            if i == 0:
-                if self.use_sinc:
-                    self.conv.append(SincConv_fast(
-                        self.cnn_N_filt[0], self.cnn_len_filt[0], self.fs))
-                else:
-                    self.conv.append(
-                        nn.Conv1d(1, self.cnn_N_filt[i], self.cnn_len_filt[i]))
+            if i == 0:
+                if self.use_sinc:
+                    self.conv.append(
+                        SincConv_fast(self.cnn_N_filt[0], self.cnn_len_filt[0], self.fs)
+                    )
+                else:
+                    self.conv.append(
+                        nn.Conv1d(1, self.cnn_N_filt[i], self.cnn_len_filt[i])
+                    )
 
-            else:
-                self.conv.append(
-                    nn.Conv1d(self.cnn_N_filt[i-1], self.cnn_N_filt[i], self.cnn_len_filt[i]))
+            else:
+                self.conv.append(
+                    nn.Conv1d(
+                        self.cnn_N_filt[i - 1], self.cnn_N_filt[i], self.cnn_len_filt[i]
+                    )
+                )
 
-            current_input = int(
-                (current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i])
+            current_input = int(
+                (current_input - self.cnn_len_filt[i] + 1) / self.cnn_max_pool_len[i]
+            )
 
-        self.out_dim = current_input*N_filt
+        self.out_dim = current_input * N_filt
 
     def forward(self, x):
-        batch = x.shape[0]
-        seq_len = x.shape[-1]
-
-        if bool(self.cnn_use_laynorm_inp):
-            x = self.ln0((x))
-
-        if bool(self.cnn_use_batchnorm_inp):
-            x = self.bn0((x))
-
-        x = x.view(batch, 1, seq_len)
-
-        for i in range(self.N_cnn_lay):
-
-            s = x.shape[2]
-            padding = get_pad(
-                size=s, kernel_size=self.cnn_len_filt[i], stride=1, dilation=1)
-            x = F.pad(x, pad=padding, mode='circular')
-
-            if self.cnn_use_laynorm[i]:
-                if i == 0:
-                    x = self.drop[i](self.act[i](self.ln[i](F.max_pool1d(
-                        torch.abs(self.conv[i](x)), self.cnn_max_pool_len[i]))))
-                else:
-                    x = self.drop[i](self.act[i](self.ln[i](
-                        F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))))
-
-            if self.cnn_use_batchnorm[i]:
-                x = self.drop[i](self.act[i](self.bn[i](
-                    F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))))
-
-            if self.cnn_use_batchnorm[i] == False and self.cnn_use_laynorm[i] == False:
-                x = self.drop[i](self.act[i](F.max_pool1d(
-                    self.conv[i](x), self.cnn_max_pool_len[i])))
-
-        #x = x.view(batch,-1)
-
-        return x
+        batch = x.shape[0]
+        seq_len = x.shape[-1]
+
+        if bool(self.cnn_use_laynorm_inp):
+            x = self.ln0((x))
+
+        if bool(self.cnn_use_batchnorm_inp):
+            x = self.bn0((x))
+
+        x = x.view(batch, 1, seq_len)
+
+        for i in range(self.N_cnn_lay):
+            s = x.shape[2]
+            padding = get_pad(
+                size=s, kernel_size=self.cnn_len_filt[i], stride=1, dilation=1
+            )
+            x = F.pad(x, pad=padding, mode="circular")
+
+            if self.cnn_use_laynorm[i]:
+                if i == 0:
+                    x = self.drop[i](
+                        self.act[i](
+                            self.ln[i](
+                                F.max_pool1d(
+                                    torch.abs(self.conv[i](x)), self.cnn_max_pool_len[i]
+                                )
+                            )
+                        )
+                    )
+                else:
+                    x = self.drop[i](
+                        self.act[i](
+                            self.ln[i](
+                                F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])
+                            )
+                        )
+                    )
+
+            if self.cnn_use_batchnorm[i]:
+                x = self.drop[i](
+                    self.act[i](
+                        self.bn[i](
+                            F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])
+                        )
+                    )
+                )
+
+            if not self.cnn_use_batchnorm[i] and not self.cnn_use_laynorm[i]:
+                x = self.drop[i](
+                    self.act[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]))
+                )
+
+        # x = x.view(batch,-1)
+
+        return x
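Note: the reformatted SincNet keeps its contract: input (batch, sig_len), output (batch, cnn_N_filt[-1], sig_len), since every conv uses circular same-padding and all pool lengths are 1 in the configuration used by ngcc/model.py. A minimal CPU shape check, assuming the sincnet_params values from NGCCPHAT below (illustrative, not part of the patch):

    import torch
    from ngcc.dnn_models import SincNet

    opts = {
        "input_dim": 2048, "fs": 16000,
        "cnn_N_filt": [128, 128, 128, 128],
        "cnn_len_filt": [1023, 11, 9, 7],
        "cnn_max_pool_len": [1, 1, 1, 1],
        "cnn_use_laynorm_inp": False, "cnn_use_batchnorm_inp": False,
        "cnn_use_laynorm": [False] * 4, "cnn_use_batchnorm": [True] * 4,
        "cnn_act": ["leaky_relu"] * 3 + ["linear"],
        "cnn_drop": [0.0] * 4,
        "use_sinc": True,
    }
    net = SincNet(opts).eval()  # eval() so BatchNorm uses running stats
    with torch.no_grad():
        out = net(torch.randn(2, 2048))
    print(out.shape)  # expected: torch.Size([2, 128, 2048])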
diff --git a/ngcc/model.py b/ngcc/model.py
index 0f8913a..695f6c0 100644
--- a/ngcc/model.py
+++ b/ngcc/model.py
@@ -9,12 +9,12 @@ class GCC(nn.Module):
-    def __init__(self, max_tau=None, dim=2, filt='phat', epsilon=0.001, beta=None):
+    def __init__(self, max_tau=None, dim=2, filt="phat", epsilon=0.001, beta=None):
         super().__init__()
 
-        ''' GCC implementation based on Knapp and Carter,
+        """ GCC implementation based on Knapp and Carter,
         "The Generalized Correlation Method for Estimation of Time Delay",
-        IEEE Trans. Acoust., Speech, Signal Processing, August, 1976 '''
+        IEEE Trans. Acoust., Speech, Signal Processing, August, 1976 """
 
         self.max_tau = max_tau
         self.dim = dim
@@ -23,10 +23,9 @@ def __init__(self, max_tau=None, dim=2, filt='phat', epsilon=0.001, beta=None):
         self.beta = beta
 
     def forward(self, x, y, window=None):
-
         n = x.shape[-1] + y.shape[-1]
 
-        if window == 'hann':
+        if window == "hann":
             window = torch.hann_window(x.shape[-1], device=x.device)
             x = x * window
             y = y * window
@@ -36,17 +35,16 @@ def forward(self, x, y, window=None):
         Y = torch.fft.rfft(y, n=n)
         Gxy = X * torch.conj(Y)
 
-        if self.filt == 'phat':
+        if self.filt == "phat":
             phi = 1 / (torch.abs(Gxy) + self.epsilon)
         else:
-            raise ValueError('Unsupported filter function')
+            raise ValueError("Unsupported filter function")
 
         if self.beta is not None:
             cc = []
             for i in range(self.beta.shape[0]):
-                cc.append(torch.fft.irfft(
-                    Gxy * torch.pow(phi, self.beta[i]), n))
+                cc.append(torch.fft.irfft(Gxy * torch.pow(phi, self.beta[i]), n))
 
             cc = torch.cat(cc, dim=1)
@@ -58,20 +56,26 @@ def forward(self, x, y, window=None):
             max_shift = np.minimum(self.max_tau, int(max_shift))
 
         if self.dim == 2:
-            cc = torch.cat((cc[:, -max_shift:], cc[:, :max_shift+1]), dim=-1)
+            cc = torch.cat((cc[:, -max_shift:], cc[:, : max_shift + 1]), dim=-1)
         elif self.dim == 3:
-            cc = torch.cat(
-                (cc[:, :, -max_shift:], cc[:, :, :max_shift+1]), dim=-1)
+            cc = torch.cat((cc[:, :, -max_shift:], cc[:, :, : max_shift + 1]), dim=-1)
 
         return cc
 
 
 class NGCCPHAT(nn.Module):
-    def __init__(self, max_tau=42, head='classifier', use_sinc=True,
-                 sig_len=2048, num_channels=128, fs=16000):
+    def __init__(
+        self,
+        max_tau=42,
+        head="classifier",
+        use_sinc=True,
+        sig_len=2048,
+        num_channels=128,
+        fs=16000,
+    ):
         super().__init__()
-        '''
+        """
         Neural GCC-PHAT with SincNet backbone
 
         arguments:
@@ -81,48 +85,55 @@ def __init__(self, max_tau=42, head='classifier', use_sinc=True,
         sig_len - length of input signal
         n_channel - number of gcc correlation channels to use
         fs - sampling frequency
-        '''
+        """
 
         self.max_tau = max_tau
         self.head = head
 
-        sincnet_params = {'input_dim': sig_len,
-                          'fs': fs,
-                          'cnn_N_filt': [128, 128, 128, num_channels],
-                          'cnn_len_filt': [1023, 11, 9, 7],
-                          'cnn_max_pool_len': [1, 1, 1, 1],
-                          'cnn_use_laynorm_inp': False,
-                          'cnn_use_batchnorm_inp': False,
-                          'cnn_use_laynorm': [False, False, False, False],
-                          'cnn_use_batchnorm': [True, True, True, True],
-                          'cnn_act': ['leaky_relu', 'leaky_relu', 'leaky_relu', 'linear'],
-                          'cnn_drop': [0.0, 0.0, 0.0, 0.0],
-                          'use_sinc': use_sinc,
-                          }
+        sincnet_params = {
+            "input_dim": sig_len,
+            "fs": fs,
+            "cnn_N_filt": [128, 128, 128, num_channels],
+            "cnn_len_filt": [1023, 11, 9, 7],
+            "cnn_max_pool_len": [1, 1, 1, 1],
+            "cnn_use_laynorm_inp": False,
+            "cnn_use_batchnorm_inp": False,
+            "cnn_use_laynorm": [False, False, False, False],
+            "cnn_use_batchnorm": [True, True, True, True],
+            "cnn_act": ["leaky_relu", "leaky_relu", "leaky_relu", "linear"],
+            "cnn_drop": [0.0, 0.0, 0.0, 0.0],
+            "use_sinc": use_sinc,
+        }
 
         self.backbone = SincNet(sincnet_params)
 
         self.mlp_kernels = [11, 9, 7]
         self.channels = [num_channels, 128, 128, 128]
         self.final_kernel = [5]
 
-        self.gcc = GCC(max_tau=self.max_tau, dim=3, filt='phat')
+        self.gcc = GCC(max_tau=self.max_tau, dim=3, filt="phat")
 
-        self.mlp = nn.ModuleList([nn.Sequential(
-            nn.Conv1d(self.channels[i], self.channels[i+1], kernel_size=k),
-            nn.BatchNorm1d(self.channels[i+1]),
-            nn.LeakyReLU(0.2),
-            nn.Dropout(0.5)) for i, k in enumerate(self.mlp_kernels)])
+        self.mlp = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv1d(self.channels[i], self.channels[i + 1], kernel_size=k),
+                    nn.BatchNorm1d(self.channels[i + 1]),
+                    nn.LeakyReLU(0.2),
+                    nn.Dropout(0.5),
+                )
+                for i, k in enumerate(self.mlp_kernels)
+            ]
+        )
 
         self.final_conv = nn.Conv1d(128, 1, kernel_size=self.final_kernel)
 
-        if head == 'regression':
+        if head == "regression":
             self.reg = nn.Sequential(
-                nn.BatchNorm1d(2 * self.max_tau + 1),
-                nn.LeakyReLU(0.2),
-                nn.Linear(2 * self.max_tau + 1, 1))
+                nn.BatchNorm1d(2 * self.max_tau + 1),
+                nn.LeakyReLU(0.2),
+                nn.Linear(2 * self.max_tau + 1, 1),
+            )
 
     def forward(self, x1, x2):
-
         batch_size = x1.shape[0]
 
         y1 = self.backbone(x1)
@@ -133,63 +144,79 @@ def forward(self, x1, x2):
         for k, layer in enumerate(self.mlp):
             s = cc.shape[2]
             padding = get_pad(
-                size=s, kernel_size=self.mlp_kernels[k], stride=1, dilation=1)
-            cc = F.pad(cc, pad=padding, mode='constant')
+                size=s, kernel_size=self.mlp_kernels[k], stride=1, dilation=1
+            )
+            cc = F.pad(cc, pad=padding, mode="constant")
             cc = layer(cc)
 
         s = cc.shape[2]
-        padding = get_pad(
-            size=s, kernel_size=self.final_kernel, stride=1, dilation=1)
-        cc = F.pad(cc, pad=padding, mode='constant')
+        padding = get_pad(size=s, kernel_size=self.final_kernel, stride=1, dilation=1)
+        cc = F.pad(cc, pad=padding, mode="constant")
         cc = self.final_conv(cc).reshape([batch_size, -1])
 
-        if self.head == 'regression':
+        if self.head == "regression":
             cc = self.reg(cc).squeeze()
 
         return cc
-
+
 
 class masked_NGCCPHAT(nn.Module):
-    def __init__(self, snr_interval, max_tau, num_mics, head='regression', use_sinc=True,
-                 sig_len=2048, num_channels=128, fs=16000):
+    def __init__(
+        self,
+        snr_interval,
+        max_tau,
+        num_mics,
+        head="regression",
+        use_sinc=True,
+        sig_len=2048,
+        num_channels=128,
+        fs=16000,
+    ):
         super().__init__()
 
-        self.ngcc = NGCCPHAT(max_tau=max_tau, head=head, use_sinc=use_sinc,
-                             sig_len=sig_len, num_channels=num_channels, fs=fs)
-
+        self.ngcc = NGCCPHAT(
+            max_tau=max_tau,
+            head=head,
+            use_sinc=use_sinc,
+            sig_len=sig_len,
+            num_channels=num_channels,
+            fs=fs,
+        )
+
         self.gcc = GCC(max_tau=max_tau)
-
+
         self.max_tau = max_tau
         self.num_mics = num_mics
         self.head = head
         self.c = 343
 
-        self.transform = AddColoredNoise(p=1.0, min_snr_in_db=snr_interval[0],
-                                         max_snr_in_db=snr_interval[1],
-                                         sample_rate=fs, mode="per_channel",
-                                         p_mode="per_channel")
+        self.transform = AddColoredNoise(
+            p=1.0,
+            min_snr_in_db=snr_interval[0],
+            max_snr_in_db=snr_interval[1],
+            sample_rate=fs,
+            mode="per_channel",
+            p_mode="per_channel",
+        )
 
-        if head == 'regression':
+        if head == "regression":
             self.loss_fn = nn.MSELoss()
-        elif head == 'classifier':
+        elif head == "classifier":
             self.loss_fn = nn.CrossEntropyLoss()
         else:
-            raise ValueError('Please select a valid model head')
-
+            raise ValueError("Please select a valid model head")
+
     def masking(self, x, ids_keep):
-
         N, L, D = x.shape  # batch, length, dim
         mask = torch.ones([N, L], device=x.device)
         replace = torch.zeros(ids_keep.size(), device=x.device)
         mask = mask.scatter(dim=1, index=ids_keep, src=replace)
 
-        x_masked = torch.gather(
-            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
 
         return x_masked, mask
 
-    def forward_loss(self, tdoa_pred, tdoa_label):
 
+    def forward_loss(self, tdoa_pred, tdoa_label):
         return self.loss_fn(tdoa_pred, tdoa_label)
-
+
     def random_select(self, audio, tdoas):
         # randomly select two microphones
         bs = audio.shape[0]
@@ -209,7 +236,7 @@ def random_select(self, audio, tdoas):
             tdoa = tdoas[b, m1, m2]
 
             # convert tdoa to categorical label
-            if self.head == 'classifier':
+            if self.head == "classifier":
                 tdoa = tdoa + self.max_tau
                 tdoa = tdoa.long()
             else:
@@ -225,19 +252,18 @@ def random_select(self, audio, tdoas):
 
         return x1, x2, labels
 
-    def forward(self, audio, tdoas, mode='test'):
-
+    def forward(self, audio, tdoas, mode="test"):
         audio = audio.squeeze(1)
-
-        if mode == 'train':
+
+        if mode == "train":
             audio = self.transform(audio)
-
+
         x1, x2, label = self.random_select(audio, tdoas)
 
         y = self.ngcc(x1, x2)
         loss_tdoa = self.forward_loss(y, label)
 
-        if self.head == 'regression':
+        if self.head == "regression":
             pred_tdoa = y
         else:
             shift_gcc = torch.argmax(y, dim=-1)
@@ -245,9 +271,7 @@ def forward(self, audio, tdoas, mode='test'):
 
         return loss_tdoa, pred_tdoa
 
-
     def get_features(self, audio, ids_keep=None, normalize=False):
-
         cc = []
         audio = audio.squeeze(1)
         B, N, L = audio.shape
@@ -255,15 +279,14 @@ def get_features(self, audio, ids_keep=None, normalize=False):
         x = self.ngcc.backbone(x)
         _, C, _ = x.shape
         x = x.view(B, N, C, L)
-
+
         if ids_keep == "all":
             idx_start = 0
         else:
             idx_start = 1
 
         for m1 in range(idx_start, N):
-            for m2 in range(m1+1, N):
-
+            for m2 in range(m1 + 1, N):
                 y1 = x[:, m1, :, :]
                 y2 = x[:, m2, :, :]
                 cc1 = self.ngcc.gcc(y1, y2)
@@ -279,14 +302,16 @@ def get_features(self, audio, ids_keep=None, normalize=False):
         for k, layer in enumerate(self.ngcc.mlp):
             s = cc.shape[2]
             padding = get_pad(
-                size=s, kernel_size=self.ngcc.mlp_kernels[k], stride=1, dilation=1)
-            cc = F.pad(cc, pad=padding, mode='constant')
+                size=s, kernel_size=self.ngcc.mlp_kernels[k], stride=1, dilation=1
+            )
+            cc = F.pad(cc, pad=padding, mode="constant")
             cc = layer(cc)
 
         s = cc.shape[2]
         padding = get_pad(
-            size=s, kernel_size=self.ngcc.final_kernel, stride=1, dilation=1)
-        cc = F.pad(cc, pad=padding, mode='constant')
+            size=s, kernel_size=self.ngcc.final_kernel, stride=1, dilation=1
+        )
+        cc = F.pad(cc, pad=padding, mode="constant")
         cc = self.ngcc.final_conv(cc)
 
         _, C, L = cc.shape
@@ -295,11 +320,10 @@ def get_features(self, audio, ids_keep=None, normalize=False):
         if normalize:
             cc /= cc.max(dim=-1, keepdims=True)[0]
 
-        features = cc.squeeze(2) # [B, N, L] (C=1)
+        features = cc.squeeze(2)  # [B, N, L] (C=1)
         return features
 
     def get_gccphat_features(self, audio, ids_keep=None, normalize=False):
-
         cc = []
         audio = audio.squeeze(1)
         B, N, L = audio.shape
@@ -313,11 +337,12 @@ def get_gccphat_features(self, audio, ids_keep=None, normalize=False):
             idx_start = 1
 
         for m1 in range(idx_start, N):
-            for m2 in range(m1+1, N):
-
+            for m2 in range(m1 + 1, N):
                 y1 = x[:, m1, :, :]
                 y2 = x[:, m2, :, :]
-                cc1 = self.ngcc.gcc(y1, y2, window=None) # hann window leads to instability
+                cc1 = self.ngcc.gcc(
+                    y1, y2, window=None
+                )  # hann window leads to instability
                 cc2 = torch.flip(cc1, dims=[-1])
                 cc.append(cc1)
                 cc.append(cc2)
@@ -334,11 +359,10 @@ def get_gccphat_features(self, audio, ids_keep=None, normalize=False):
         if normalize:
             cc /= cc.max(dim=-1, keepdims=True)[0]
 
-        features = cc.squeeze(2) # [B, N, L] (C=1)
+        features = cc.squeeze(2)  # [B, N, L] (C=1)
         return features
 
     def get_one_feature(self, audio, i, j):
-
         cc = []
         audio = audio.squeeze(1)
         B, N, L = audio.shape
@@ -346,7 +370,7 @@ def get_one_feature(self, audio, i, j):
         x = self.ngcc.backbone(x)
         _, C, _ = x.shape
         x = x.view(B, N, C, L)
-
+
         y1 = x[:, i, :, :]
         y2 = x[:, j, :, :]
         cc.append(self.ngcc.gcc(y1, y2))
@@ -359,18 +383,20 @@ def get_one_feature(self, audio, i, j):
         for k, layer in enumerate(self.ngcc.mlp):
             s = cc.shape[2]
             padding = get_pad(
-                size=s, kernel_size=self.ngcc.mlp_kernels[k], stride=1, dilation=1)
-            cc = F.pad(cc, pad=padding, mode='constant')
+                size=s, kernel_size=self.ngcc.mlp_kernels[k], stride=1, dilation=1
+            )
+            cc = F.pad(cc, pad=padding, mode="constant")
             cc = layer(cc)
 
         s = cc.shape[2]
         padding = get_pad(
-            size=s, kernel_size=self.ngcc.final_kernel, stride=1, dilation=1)
-        cc = F.pad(cc, pad=padding, mode='constant')
+            size=s, kernel_size=self.ngcc.final_kernel, stride=1, dilation=1
+        )
+        cc = F.pad(cc, pad=padding, mode="constant")
         cc = self.ngcc.final_conv(cc)
 
         _, C, L = cc.shape
         cc = cc.reshape(B, N, C, L)
-        features = cc.squeeze(2) # [B, N, L] (C=1)
+        features = cc.squeeze(2)  # [B, N, L] (C=1)
 
         return features
diff --git a/ngcc/torch_same_pad.py b/ngcc/torch_same_pad.py
index 7721b53..954ec1e 100644
--- a/ngcc/torch_same_pad.py
+++ b/ngcc/torch_same_pad.py
@@ -4,13 +4,15 @@
 import torch
 import torch.nn.functional as F
 
-__all__ = ['get_pad', 'pad']
+__all__ = ["get_pad", "pad"]
 
 
-def _calc_pad(size: int,
-              kernel_size: int = 3,
-              stride: int = 1,
-              dilation: int = 1):
+def _calc_pad(
+    size: int,
+    kernel_size: int = 3,
+    stride: int = 1,
+    dilation: int = 1,
+):
     pad = (((size + stride - 1) // stride - 1) * stride + kernel_size - size) * dilation
     return pad // 2, pad - pad // 2
 
@@ -21,25 +23,34 @@ def _get_compressed(item: Union[int, Sequence[int]], index: int):
     return item
 
 
-def get_pad(size: Union[int, Sequence[int]],
-            kernel_size: Union[int, Sequence[int]] = 3,
-            stride: Union[int, Sequence[int]] = 1,
-            dilation: Union[int, Sequence[int]] = 1):
+def get_pad(
+    size: Union[int, Sequence[int]],
+    kernel_size: Union[int, Sequence[int]] = 3,
+    stride: Union[int, Sequence[int]] = 1,
+    dilation: Union[int, Sequence[int]] = 1,
+):
     len_size = 1
     if isinstance(size, collections.Sequence):
         len_size = len(size)
     pad = ()
     for i in range(len_size):
-        pad = _calc_pad(size=_get_compressed(size, i),
-                        kernel_size=_get_compressed(kernel_size, i),
-                        stride=_get_compressed(stride, i),
-                        dilation=_get_compressed(dilation, i)) + pad
+        pad = (
+            _calc_pad(
+                size=_get_compressed(size, i),
+                kernel_size=_get_compressed(kernel_size, i),
+                stride=_get_compressed(stride, i),
+                dilation=_get_compressed(dilation, i),
+            )
+            + pad
+        )
     return pad
 
 
-def pad(x: torch.Tensor,
-        size: Union[int, Sequence[int]],
-        kernel_size: Union[int, Sequence[int]] = 3,
-        stride: Union[int, Sequence[int]] = 1,
-        dilation: Union[int, Sequence[int]] = 1):
-    return F.pad(x, get_pad(size, kernel_size, stride, dilation))
\ No newline at end of file
+def pad(
+    x: torch.Tensor,
+    size: Union[int, Sequence[int]],
+    kernel_size: Union[int, Sequence[int]] = 3,
+    stride: Union[int, Sequence[int]] = 1,
+    dilation: Union[int, Sequence[int]] = 1,
+):
+    return F.pad(x, get_pad(size, kernel_size, stride, dilation))
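For reference, the GCC module reformatted above returns the PHAT-weighted cross-correlation restricted to lags in [-max_tau, max_tau], so a delay estimate is the argmax position minus max_tau. A small self-contained check (illustrative, not part of the patch):

    import torch
    from ngcc.model import GCC

    torch.manual_seed(0)
    ref = torch.randn(1, 2048)
    x = torch.roll(ref, shifts=17, dims=-1)   # x is ref delayed by 17 samples
    gcc = GCC(max_tau=42, dim=2, filt="phat")
    cc = gcc(x, ref)                          # shape [1, 2 * 42 + 1]
    tau_hat = int(torch.argmax(cc, dim=-1)) - 42
    print(tau_hat)  # expected: 17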
diff --git a/utils.py b/utils.py
index 1cbc670..fe6e87c 100644
--- a/utils.py
+++ b/utils.py
@@ -1,78 +1,96 @@
 import numpy as np
 import torch
 
+
 def create_mask(ids_keep, cfg, bs, device, n_patch_per_mic):
-    if cfg.random_masking == 'random':
+    if cfg.random_masking == "random":
         # randomly remove some mics, always remove mic 0
         all_mics = np.expand_dims(np.arange(1, cfg.num_mics), axis=0).repeat(bs, axis=0)
 
         # randomly remove some audio, always remove audio from mic 0 for tdoa, never remove audio from mics that have been removed
-        if cfg.toa_prob > 0.:
+        if cfg.toa_prob > 0.0:
             start_idx = cfg.num_mics
             num_audio = cfg.num_mics
         else:
             start_idx = cfg.num_mics + n_patch_per_mic
             num_audio = cfg.num_mics - 1
 
-        all_audio = np.expand_dims(np.arange(start_idx, cfg.num_mics
-                                             + cfg.num_mics * n_patch_per_mic), axis=0).repeat(bs, axis=0)
+        all_audio = np.expand_dims(
+            np.arange(start_idx, cfg.num_mics + cfg.num_mics * n_patch_per_mic), axis=0
+        ).repeat(bs, axis=0)
         all_audio = all_audio.reshape(bs, n_patch_per_mic, num_audio)
-
-        if cfg.toa_prob > 0.:
+
+        if cfg.toa_prob > 0.0:
             # toa_idx = []
             all_audio_new = np.zeros_like(all_audio)[:, :, :-1]
             for b in range(bs):
                 if np.random.rand() > cfg.toa_prob:
-                    all_audio_new[b] = all_audio[b, :, 1:] # remove source audio
+                    all_audio_new[b] = all_audio[b, :, 1:]  # remove source audio
                 else:
-                    rand_mic = np.random.randint(low=2, high=cfg.num_mics-1)
-                    all_audio_new[b] = np.concatenate((all_audio[b, :, :rand_mic], all_audio[b, :, rand_mic+1:]), axis=-1) # remove another microphone
+                    rand_mic = np.random.randint(low=2, high=cfg.num_mics - 1)
+                    all_audio_new[b] = np.concatenate(
+                        (all_audio[b, :, :rand_mic], all_audio[b, :, rand_mic + 1 :]),
+                        axis=-1,
+                    )  # remove another microphone
             all_audio = all_audio_new
 
         # random masking
         for b in range(bs):
-            perm = np.random.permutation(cfg.num_mics-1)
+            perm = np.random.permutation(cfg.num_mics - 1)
             all_mics[b] = all_mics[b, perm]
-            all_audio[b] = all_audio[b, :, perm].transpose(1,0)
-
+            all_audio[b] = all_audio[b, :, perm].transpose(1, 0)
+
         # sample random n_keep
-        n_mic_keep = np.random.randint(low=cfg.n_mic_keep[0], high=cfg.n_mic_keep[1]+1)
-        n_audio_keep = np.random.randint(low=cfg.n_audio_keep[0], high=cfg.n_audio_keep[1]+1)
-
+        n_mic_keep = np.random.randint(
+            low=cfg.n_mic_keep[0], high=cfg.n_mic_keep[1] + 1
+        )
+        n_audio_keep = np.random.randint(
+            low=cfg.n_audio_keep[0], high=cfg.n_audio_keep[1] + 1
+        )
+
         mics_keep = torch.LongTensor(all_mics[:, :n_mic_keep]).to(device)
-        audio_keep = all_audio[:, :, -n_audio_keep:].reshape(bs, n_patch_per_mic * n_audio_keep)
+        audio_keep = all_audio[:, :, -n_audio_keep:].reshape(
+            bs, n_patch_per_mic * n_audio_keep
+        )
         audio_keep = torch.LongTensor(audio_keep).to(device)
-
+
         this_ids_keep = torch.cat((mics_keep, audio_keep), dim=1)
         this_ids_keep, _ = torch.sort(this_ids_keep, dim=1)
 
-    elif cfg.random_masking == 'fixed_number' or cfg.random_masking == 'random_same':
+    elif cfg.random_masking == "fixed_number" or cfg.random_masking == "random_same":
         all_mics = np.expand_dims(np.arange(1, cfg.num_mics), axis=0).repeat(bs, axis=0)
-        all_audio = np.expand_dims(np.arange(cfg.num_mics + n_patch_per_mic, cfg.num_mics
-                                             + cfg.num_mics * n_patch_per_mic), axis=0).repeat(bs, axis=0)
-        all_audio = all_audio.reshape(bs, n_patch_per_mic, cfg.num_mics-1)
+        all_audio = np.expand_dims(
+            np.arange(
+                cfg.num_mics + n_patch_per_mic,
+                cfg.num_mics + cfg.num_mics * n_patch_per_mic,
+            ),
+            axis=0,
+        ).repeat(bs, axis=0)
+        all_audio = all_audio.reshape(bs, n_patch_per_mic, cfg.num_mics - 1)
 
         # randomly remove some audio, and remove the same microphone coordinates
         for b in range(bs):
-            perm = np.random.permutation(cfg.num_mics-1)
+            perm = np.random.permutation(cfg.num_mics - 1)
             all_mics[b] = all_mics[b, perm]
-            all_audio[b] = all_audio[b, :, perm].transpose(1,0)
+            all_audio[b] = all_audio[b, :, perm].transpose(1, 0)
 
-        if cfg.random_masking == 'random_same':
-            n_keep = np.random.randint(low=cfg.n_mic_keep[0], high=cfg.n_mic_keep[1]+1)
+        if cfg.random_masking == "random_same":
+            n_keep = np.random.randint(
+                low=cfg.n_mic_keep[0], high=cfg.n_mic_keep[1] + 1
+            )
         else:
             n_keep = cfg.n_mic_keep
 
         mics_keep = torch.LongTensor(all_mics[:, :n_keep]).to(device)
         audio_keep = all_audio[:, :, :n_keep].reshape(bs, n_patch_per_mic * n_keep)
         audio_keep = torch.LongTensor(audio_keep).to(device)
-
+
         this_ids_keep = torch.cat((mics_keep, audio_keep), dim=1)
         this_ids_keep, _ = torch.sort(this_ids_keep, dim=1)
 
     elif not cfg.random_masking:
         this_ids_keep = ids_keep.repeat(bs, 1)
     else:
-        raise ValueError('Select a valid masking strategy')
+        raise ValueError("Select a valid masking strategy")
 
-    return this_ids_keep
\ No newline at end of file
+    return this_ids_keep
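The `this_ids_keep` produced here feeds the `mask`/`masking` helpers in wav2pos.py and ngcc/model.py, which share one convention: the returned mask is 1 for dropped tokens and 0 for kept ones, and the kept tokens are gathered in index order. A toy illustration of that convention (not part of the patch):

    import torch

    x = torch.arange(12.0).reshape(1, 6, 2)   # [N=1, L=6 tokens, D=2]
    ids_keep = torch.tensor([[0, 2, 5]])      # keep tokens 0, 2 and 5
    mask = torch.ones(1, 6).scatter(1, ids_keep, torch.zeros(1, 3))
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, 2))
    print(mask)      # tensor([[0., 1., 0., 1., 1., 0.]])
    print(x_masked)  # rows 0, 2 and 5 of x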
diff --git a/wav2pos.py b/wav2pos.py
index 46df7a8..b56b16c 100644
--- a/wav2pos.py
+++ b/wav2pos.py
@@ -15,11 +15,20 @@ def forward(self, x):
         x = self.bn(x.permute(0, 2, 1))
         return x.permute(0, 2, 1)
 
-class PatchEmbedAudio(nn.Module):
-    """ Audio to Patch Embedding
-    """
-    def __init__(self, audio_len=2048, patch_size=16, num_mics=1, embed_dim=768, decoder_embed_dim=768, norm_layer=None, flatten=True):
+
+class PatchEmbedAudio(nn.Module):
+    """Audio to Patch Embedding"""
+
+    def __init__(
+        self,
+        audio_len=2048,
+        patch_size=16,
+        num_mics=1,
+        embed_dim=768,
+        decoder_embed_dim=768,
+        norm_layer=None,
+        flatten=True,
+    ):
         super().__init__()
         self.audio_len = audio_len
         self.patch_size = patch_size
@@ -29,9 +38,9 @@ def __init__(self, audio_len=2048, patch_size=16, num_mics=1, embed_dim=768, dec
         self.flatten = flatten
         self.embed_dim = embed_dim
-
-        self.proj = nn.Conv2d(1, embed_dim, kernel_size=(
-            1, patch_size), stride=(1, patch_size))
+        self.proj = nn.Conv2d(
+            1, embed_dim, kernel_size=(1, patch_size), stride=(1, patch_size)
+        )
         self.projT = nn.Linear(decoder_embed_dim, patch_size)
 
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@@ -39,15 +48,16 @@ def __init__(self, audio_len=2048, patch_size=16, num_mics=1, embed_dim=768, dec
     def forward(self, x):
         x = x.unsqueeze(1)
         B, _, N, L = x.shape
-        assert L == self.audio_len, \
-            f"Input audio length ({L}) doesn't match model ({self.audio_len})."
-
+        assert L == self.audio_len, (
+            f"Input audio length ({L}) doesn't match model ({self.audio_len})."
+        )
+
         x = self.proj(x)
         x = x.flatten(2)  # BCNL-> BCM
         x = x.transpose(1, 2)  # BCM-> BMC
         x = self.norm(x)
         return x
-
+
     def forwardT(self, x):
         x = self.projT(x)
         x = x.squeeze()
@@ -55,31 +65,52 @@ def forwardT(self, x):
 
 
 class wav2pos(nn.Module):
-    """ Masked Autoencoder for audio and positions
-    """
-
-    def __init__(self, audio_len=2048, patch_size=16, num_mics=3,
-                 embed_dim=512, depth=4, num_heads=4,
-                 decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=4,
-                 mlp_ratio=4., norm_layer=nn.LayerNorm, drop=0.0, attn_drop=0.0,
-                 pos_dim=3, snr_interval=[5, 30], all_patch_loss=True,
-                 use_ngcc=False, ngcc_path=None, use_maxpool=True, use_posenc=True, max_tau=314):
+    """Masked Autoencoder for audio and positions"""
+
+    def __init__(
+        self,
+        audio_len=2048,
+        patch_size=16,
+        num_mics=3,
+        embed_dim=512,
+        depth=4,
+        num_heads=4,
+        decoder_embed_dim=256,
+        decoder_depth=4,
+        decoder_num_heads=4,
+        mlp_ratio=4.0,
+        norm_layer=nn.LayerNorm,
+        drop=0.0,
+        attn_drop=0.0,
+        pos_dim=3,
+        snr_interval=[5, 30],
+        all_patch_loss=True,
+        use_ngcc=False,
+        ngcc_path=None,
+        use_maxpool=True,
+        use_posenc=True,
+        max_tau=314,
+    ):
         super().__init__()
 
         self.use_ngcc = use_ngcc
         if self.use_ngcc:
-            self.max_tau = max_tau
-            self.ngcc = masked_NGCCPHAT(max_tau=self.max_tau, snr_interval=[1000, 1000],
-                                        num_mics=num_mics, head='classifier')
-            if self.use_ngcc == 'pre-trained':
+            self.max_tau = max_tau
+            self.ngcc = masked_NGCCPHAT(
+                max_tau=self.max_tau,
+                snr_interval=[1000, 1000],
+                num_mics=num_mics,
+                head="classifier",
+            )
+            if self.use_ngcc == "pre-trained":
                 self.ngcc.eval()
-                print('loading ngcc pre-trained weights from ' + ngcc_path)
+                print("loading ngcc pre-trained weights from " + ngcc_path)
                 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                 self.ngcc.load_state_dict(torch.load(ngcc_path, map_location=device))
-
         self.patch_embed = PatchEmbedAudio(
-            audio_len, patch_size, num_mics, embed_dim, decoder_embed_dim)
+            audio_len, patch_size, num_mics, embed_dim, decoder_embed_dim
+        )
         self.num_mics = num_mics
         self.drop = drop
         self.attn_drop = attn_drop
@@ -104,18 +135,34 @@ def __init__(self, audio_len=2048, patch_size=16, num_mics=3,
         self.enc_loc_modality = nn.Parameter(torch.zeros(1, 1, embed_dim))
 
         # pair-wise positional encoding layers
-        self.get_decoder_mic_feature = nn.Sequential(nn.Linear(decoder_embed_dim, decoder_embed_dim),
-                                                     nn.LayerNorm(decoder_embed_dim),
-                                                     nn.GELU())
-        self.decoder_fproj = nn.Sequential(nn.Linear(decoder_embed_dim, decoder_embed_dim),
-                                           nn.LayerNorm(decoder_embed_dim))
-        self.get_decoder_audio_features = nn.Sequential(nn.Linear(decoder_embed_dim, decoder_embed_dim),
-                                                        nn.LayerNorm(decoder_embed_dim))
-
-        self.blocks = nn.ModuleList([
-            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True,
-                  norm_layer=norm_layer, proj_drop=self.drop, attn_drop=self.attn_drop)
-            for i in range(depth)])
+        self.get_decoder_mic_feature = nn.Sequential(
+            nn.Linear(decoder_embed_dim, decoder_embed_dim),
+            nn.LayerNorm(decoder_embed_dim),
+            nn.GELU(),
+        )
+        self.decoder_fproj = nn.Sequential(
+            nn.Linear(decoder_embed_dim, decoder_embed_dim),
+            nn.LayerNorm(decoder_embed_dim),
+        )
+        self.get_decoder_audio_features = nn.Sequential(
+            nn.Linear(decoder_embed_dim, decoder_embed_dim),
+            nn.LayerNorm(decoder_embed_dim),
+        )
+
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    embed_dim,
+                    num_heads,
+                    mlp_ratio,
+                    qkv_bias=True,
+                    norm_layer=norm_layer,
+                    proj_drop=self.drop,
+                    attn_drop=self.attn_drop,
+                )
+                for i in range(depth)
+            ]
+        )
 
         self.patch_norm = norm_layer(patch_size)
         self.norm = norm_layer(embed_dim)
@@ -125,29 +172,36 @@ def __init__(self, audio_len=2048, patch_size=16, num_mics=3,
         self.mask_token_source = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
         self.mask_token_mic = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
         self.mask_token_audio = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
-
-        # decoder modality token:
-        self.dec_audio_modality = nn.Parameter(
-            torch.zeros(1, 1, decoder_embed_dim))
-        self.dec_loc_modality = nn.Parameter(
-            torch.zeros(1, 1, decoder_embed_dim))
 
-        self.decoder_blocks = nn.ModuleList([
-            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,
-                  norm_layer=norm_layer, proj_drop=self.drop, attn_drop=self.attn_drop)
-            for i in range(decoder_depth)])
+        # decoder modality token:
+        self.dec_audio_modality = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+        self.dec_loc_modality = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+
+        self.decoder_blocks = nn.ModuleList(
+            [
+                Block(
+                    decoder_embed_dim,
+                    decoder_num_heads,
+                    mlp_ratio,
+                    qkv_bias=True,
+                    norm_layer=norm_layer,
+                    proj_drop=self.drop,
+                    attn_drop=self.attn_drop,
+                )
+                for i in range(decoder_depth)
+            ]
+        )
 
         self.decoder_norm = norm_layer(decoder_embed_dim)
 
         self.decoder_max = nn.Linear(decoder_embed_dim, decoder_embed_dim)
 
         if self.use_ngcc is not False and self.use_maxpool:
-
             self.loc_mlp = nn.Sequential(
-                nn.Linear(2 * decoder_embed_dim + 2 * self.max_tau + 1, 512),
+                nn.Linear(2 * decoder_embed_dim + 2 * self.max_tau + 1, 512),
                 PBatchNorm1d(512),
                 nn.GELU(),
-                nn.Linear(512, 512)
+                nn.Linear(512, 512),
             )
 
             self.loc_proj = nn.Linear(512, decoder_embed_dim)
@@ -158,23 +212,23 @@ def __init__(self, audio_len=2048, patch_size=16, num_mics=3,
             n_feat = 1
 
         self.decoder_pred_source = nn.Sequential(
-            nn.Linear(n_feat * decoder_embed_dim, 512),
+            nn.Linear(n_feat * decoder_embed_dim, 512),
             PBatchNorm1d(512),
             nn.GELU(),
             nn.Linear(512, 256),
             PBatchNorm1d(256),
             nn.GELU(),
-            nn.Linear(256, self.pos_dim)
+            nn.Linear(256, self.pos_dim),
         )
 
         self.decoder_pred_locs = nn.Sequential(
-            nn.Linear(n_feat * decoder_embed_dim, 512),
+            nn.Linear(n_feat * decoder_embed_dim, 512),
             PBatchNorm1d(512),
             nn.GELU(),
             nn.Linear(512, 256),
             PBatchNorm1d(256),
             nn.GELU(),
-            nn.Linear(256, self.pos_dim)
+            nn.Linear(256, self.pos_dim),
         )
 
         # --------------------------------------------------------------------------
@@ -184,36 +238,40 @@ def __init__(self, audio_len=2048, patch_size=16, num_mics=3,
 
         fs = int(16e3)
 
-        self.transform = AddColoredNoise(p=1.0, min_snr_in_db=snr_interval[0], max_snr_in_db=snr_interval[1], sample_rate=fs, mode="per_channel", p_mode="per_channel")
-
-
+        self.transform = AddColoredNoise(
+            p=1.0,
+            min_snr_in_db=snr_interval[0],
+            max_snr_in_db=snr_interval[1],
+            sample_rate=fs,
+            mode="per_channel",
+            p_mode="per_channel",
+        )
 
     def initialize_weights(self):
         # initialization
-        torch.nn.init.normal_(self.mask_token_mic, std=.02)
-        torch.nn.init.normal_(self.mask_token_audio, std=.02)
+        torch.nn.init.normal_(self.mask_token_mic, std=0.02)
+        torch.nn.init.normal_(self.mask_token_audio, std=0.02)
         torch.nn.init.normal_(self.mask_token_source, std=0.02)
 
         # modality encoders
-        torch.nn.init.normal_(self.enc_audio_modality, std=.02)
-        torch.nn.init.normal_(self.enc_loc_modality, std=.02)
-        torch.nn.init.normal_(self.dec_audio_modality, std=.02)
-        torch.nn.init.normal_(self.dec_loc_modality, std=.02)
+        torch.nn.init.normal_(self.enc_audio_modality, std=0.02)
+        torch.nn.init.normal_(self.enc_loc_modality, std=0.02)
+        torch.nn.init.normal_(self.dec_audio_modality, std=0.02)
+        torch.nn.init.normal_(self.dec_loc_modality, std=0.02)
 
         # initialize nn.Linear and nn.LayerNorm
         self.apply(self._init_normal)
 
     def _init_normal(self, m):
         if isinstance(m, nn.Linear):
-            torch.nn.init.normal_(m.weight, std=.02)
+            torch.nn.init.normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
 
-
     def patchify(self, signals):
         """
         signals: (N, 1, n_mics, L)
@@ -225,8 +283,7 @@ def patchify(self, signals):
         assert signals.shape[3] % p == 0
 
         x = signals.squeeze(1)
-        x = x.reshape(
-            shape=(signals.shape[0], num_mics, num_patches // num_mics, p))
+        x = x.reshape(shape=(signals.shape[0], num_mics, num_patches // num_mics, p))
         x = x.flatten(1, 2)
         return x
@@ -241,24 +298,20 @@ def unpatchify(self, x):
         assert x.shape[2] % p == 0
 
         x = x.unflatten(1, (num_mics, num_patches // num_mics))
-        x = x.reshape(shape=(x.shape[0], num_mics,
-                             num_patches * p // num_mics))
+        x = x.reshape(shape=(x.shape[0], num_mics, num_patches * p // num_mics))
         signals = x.unsqueeze(1)
         return signals
 
     def mask(self, x, ids_keep):
-
         N, L, D = x.shape
         mask = torch.ones([N, L], device=x.device)
         replace = torch.zeros(ids_keep.size(), device=x.device)
         mask = mask.scatter(dim=1, index=ids_keep, src=replace)
 
-        x_masked = torch.gather(
-            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
 
         return x_masked, mask
 
     def forward_encoder(self, x, locations, ids_keep):
-
         # embed patches
         x = self.patchify(x)
         mu = torch.mean(x, -1, keepdim=True)
@@ -272,11 +325,10 @@ def forward_encoder(self, x, locations, ids_keep):
         x = self.patch_embed.norm(x)
 
         b, n, d = x.shape
-
-        #modality tokens
+
+        # modality tokens
         mod_loc = self.enc_loc_modality.repeat(b, self.patch_embed.num_mics, 1)
-        mod_audio = self.enc_audio_modality.repeat(
-            b, self.patch_embed.num_patches, 1)
+        mod_audio = self.enc_audio_modality.repeat(b, self.patch_embed.num_patches, 1)
         mod_token = torch.cat((mod_loc, mod_audio), dim=1)
 
         x = x + mod_token
@@ -295,54 +347,65 @@ def forward_decoder(self, x_masked, mask, feature=None, ids_keep=None):
         x_masked = self.decoder_embed(x_masked)
 
         # insert mask tokens
-        num_masked_mic_patches = int(mask[0, :self.patch_embed.num_mics].sum())# x_masked[:, :self.patch_embed.num_mics]
-        num_masked_audio_patches = int(mask[0, self.patch_embed.num_mics:].sum())# x_masked[:, self.patch_embed.num_mics:]
+        num_masked_mic_patches = int(
+            mask[0, : self.patch_embed.num_mics].sum()
+        )  # x_masked[:, :self.patch_embed.num_mics]
+        num_masked_audio_patches = int(
+            mask[0, self.patch_embed.num_mics :].sum()
+        )  # x_masked[:, self.patch_embed.num_mics:]
 
         N, _, D = x_masked.shape
         _, L = mask.shape
 
         mask_tokens_source = self.mask_token_source.repeat(N, 1, 1)
-        mask_tokens_mic = self.mask_token_mic.repeat(N, num_masked_mic_patches-1, 1)
+        mask_tokens_mic = self.mask_token_mic.repeat(N, num_masked_mic_patches - 1, 1)
         mask_tokens_audio = self.mask_token_audio.repeat(N, num_masked_audio_patches, 1)
-        mask_tokens = torch.cat((mask_tokens_source, mask_tokens_mic, mask_tokens_audio), dim=1)
-
+        mask_tokens = torch.cat(
+            (mask_tokens_source, mask_tokens_mic, mask_tokens_audio), dim=1
+        )
+
         non_masked = (mask == 0).nonzero()[:, 1].reshape([N, -1])
         masked = mask.nonzero()[:, 1].reshape([N, -1])
 
-        x = torch.zeros(N, L, D, device=x_masked.device,
-                        requires_grad=True).clone()
+        x = torch.zeros(N, L, D, device=x_masked.device, requires_grad=True).clone()
         x = x.scatter(
-            dim=1, index=non_masked.unsqueeze(-1).repeat(1, 1, D), src=x_masked)
+            dim=1, index=non_masked.unsqueeze(-1).repeat(1, 1, D), src=x_masked
+        )
         x = x.scatter(
-            dim=1, index=masked.unsqueeze(-1).repeat(1, 1, D), src=mask_tokens)
+            dim=1, index=masked.unsqueeze(-1).repeat(1, 1, D), src=mask_tokens
+        )
 
-        # add pair-wise position embedding
         # add microphone feature from audio
-        audio_features = x[:, self.patch_embed.num_mics:].reshape([N,
-                                                                   self.patch_embed.num_patches // self.patch_embed.num_mics,
-                                                                   self.patch_embed.num_mics,
-                                                                   D])
+        audio_features = x[:, self.patch_embed.num_mics :].reshape(
+            [
+                N,
+                self.patch_embed.num_patches // self.patch_embed.num_mics,
+                self.patch_embed.num_mics,
+                D,
+            ]
+        )
 
-        mic_features = x[:, :self.patch_embed.num_mics]
+        mic_features = x[:, : self.patch_embed.num_mics]
 
         f_mic = self.get_decoder_mic_feature(audio_features)
         f_mic, _ = torch.max(f_mic, dim=1, keepdim=False)
         f_mic = self.decoder_fproj(f_mic)
 
         # audio feature from microphones
-        f_audio = self.get_decoder_audio_features(mic_features).repeat(1, self.patch_embed.num_patches // self.patch_embed.num_mics, 1)
-
+        f_audio = self.get_decoder_audio_features(mic_features).repeat(
+            1, self.patch_embed.num_patches // self.patch_embed.num_mics, 1
+        )
+
         # add position embedding
         pos_enc = torch.cat((f_mic, f_audio), dim=1)
         if self.use_posenc:
             x = x + pos_enc
 
-        #modality tokens
+        # modality tokens
         b, _, _ = x.shape
         mod_loc = self.dec_loc_modality.repeat(b, self.patch_embed.num_mics, 1)
-        mod_audio = self.dec_audio_modality.repeat(
-            b, self.patch_embed.num_patches, 1)
+        mod_audio = self.dec_audio_modality.repeat(b, self.patch_embed.num_patches, 1)
         mod_token = torch.cat((mod_loc, mod_audio), dim=1)
 
         x = x + mod_token
@@ -352,20 +415,24 @@ def forward_decoder(self, x_masked, mask, feature=None, ids_keep=None):
         x = self.decoder_norm(x)
 
         # split audio/localization
-        locs = x[:, :self.patch_embed.num_mics]
-        audio = x[:, self.patch_embed.num_mics:]
+        locs = x[:, : self.patch_embed.num_mics]
+        audio = x[:, self.patch_embed.num_mics :]
 
         # global feature for locations
         x_max, _ = torch.max(self.decoder_max(x_masked), dim=1, keepdim=True)
         x_max = x_max.repeat(1, self.patch_embed.num_mics, 1)
-
+
         if self.use_ngcc is not None and self.use_maxpool:
             loc_features = []
-            num_non_masked_mic_patches = self.patch_embed.num_mics - int(mask[0, :self.patch_embed.num_mics].sum())
-            ids_keep_audio = ids_keep[:, num_non_masked_mic_patches:] - self.patch_embed.num_mics
+            num_non_masked_mic_patches = self.patch_embed.num_mics - int(
+                mask[0, : self.patch_embed.num_mics].sum()
+            )
+            ids_keep_audio = (
+                ids_keep[:, num_non_masked_mic_patches:] - self.patch_embed.num_mics
+            )
             locs_masked, _ = self.mask(locs, ids_keep=ids_keep_audio)
             for i in range(0, locs_masked.shape[1]):
-                for j in range(i+1, locs_masked.shape[1]):
+                for j in range(i + 1, locs_masked.shape[1]):
                     p1 = locs_masked[:, i, :]
                     p2 = locs_masked[:, j, :]
                     p_both1 = torch.cat((p1, p2), dim=1)
@@ -377,7 +444,9 @@ def forward_decoder(self, x_masked, mask, feature=None, ids_keep=None):
             loc_features = torch.cat((loc_features, feature), dim=2)
             loc_features = self.loc_mlp(loc_features)
             loc_features, _ = torch.max(loc_features, dim=1, keepdim=True)
-            loc_features = self.loc_proj(loc_features).repeat(1, self.patch_embed.num_mics, 1)
+            loc_features = self.loc_proj(loc_features).repeat(
+                1, self.patch_embed.num_mics, 1
+            )
 
             locs = torch.cat((locs, x_max, loc_features), dim=-1)
         elif self.use_maxpool:
@@ -394,7 +463,6 @@ def forward_decoder(self, x_masked, mask, feature=None, ids_keep=None):
 
         locs = torch.cat((source, locs), dim=1)
 
-
         return audio, locs
 
     def forward_loss(self, imgs, pred, pred_locs, locs, mask):
@@ -410,18 +478,20 @@ def forward_loss(self, imgs, pred, pred_locs, locs, mask):
 
         target = (target - mu) / sigma
 
-        mask_locs = mask[:, 1:self.patch_embed.num_mics]
-        mask_audio = mask[:, self.patch_embed.num_mics:]
+        mask_locs = mask[:, 1 : self.patch_embed.num_mics]
+        mask_audio = mask[:, self.patch_embed.num_mics :]
 
         loss_audio = (pred - target) ** 2
         loss_audio = loss_audio.mean(dim=-1)  # [N, L], mean loss per patch
-
-        loss_audio = (loss_audio * (1.0 - mask_audio)).sum() / (1.0 - mask_audio).sum() # mean loss on non-masked patches
+
+        loss_audio = (loss_audio * (1.0 - mask_audio)).sum() / (
+            1.0 - mask_audio
+        ).sum()  # mean loss on non-masked patches
 
         loss_source = (pred_locs[:, 0] - locs[:, 0]) ** 2
         loss_source = loss_source.mean()
 
-        loss_locs = (pred_locs[:,1:] - locs[:, 1:]) ** 2
+        loss_locs = (pred_locs[:, 1:] - locs[:, 1:]) ** 2
         loss_locs = loss_locs.mean(dim=-1)
         if self.all_patch_loss:
             loss_locs = loss_locs.mean()
@@ -429,40 +499,44 @@ def forward_loss(self, imgs, pred, pred_locs, locs, mask):
             if mask_locs.sum() > 0:
                 loss_locs = (loss_locs * mask_locs).sum() / mask_locs.sum()
             else:
-                loss_locs = 0.
+                loss_locs = 0.0
 
         loss_locs = loss_locs + loss_source
 
         return loss_audio, loss_locs
 
-    def forward(self, audio, locations, ids_keep, mode='test'):
-        if mode == 'train':
+    def forward(self, audio, locations, ids_keep, mode="test"):
+        if mode == "train":
             x = self.transform(audio.squeeze(1)).unsqueeze(1)
         else:
             x = audio
-
+
         target = audio
 
-        latent, mask, mu, sigma = self.forward_encoder(
-            x, locations, ids_keep)
-
-        num_non_masked_mic_patches = self.patch_embed.num_mics - int(mask[0, :self.patch_embed.num_mics].sum())
-        ids_keep_audio = ids_keep[:, num_non_masked_mic_patches:] - self.patch_embed.num_mics
+        latent, mask, mu, sigma = self.forward_encoder(x, locations, ids_keep)
+
+        num_non_masked_mic_patches = self.patch_embed.num_mics - int(
+            mask[0, : self.patch_embed.num_mics].sum()
+        )
+        ids_keep_audio = (
+            ids_keep[:, num_non_masked_mic_patches:] - self.patch_embed.num_mics
+        )
         x_masked, _ = self.mask(x.squeeze(1), ids_keep=ids_keep_audio)
 
-        if self.use_ngcc == 'gccphat':
+        if self.use_ngcc == "gccphat":
             with torch.no_grad():
                 feature = self.ngcc.get_gccphat_features(x_masked, ids_keep="all")
-        elif self.use_ngcc == 'pre-trained':
+        elif self.use_ngcc == "pre-trained":
             with torch.no_grad():
                 feature = self.ngcc.get_features(x_masked, ids_keep="all")
         elif not self.use_ngcc:
             feature = None
         else:
-            raise ValueError('select valid ngcc format')
+            raise ValueError("select valid ngcc format")
 
-        pred, pred_locs = self.forward_decoder(latent, mask, feature, ids_keep) # [N, L, p*p*3]
+        pred, pred_locs = self.forward_decoder(
+            latent, mask, feature, ids_keep
+        )  # [N, L, p*p*3]
         loss_audio, loss_locs = self.forward_loss(
-            target, pred, pred_locs, locations, mask)
+            target, pred, pred_locs, locations, mask
+        )
 
         return loss_audio, loss_locs, pred, pred_locs, mask, mu, sigma
-
-
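As a closing sanity check on the reformatted patchify/unpatchify pair: each microphone's waveform is cut into audio_len / patch_size patches, the token axis is ordered mic-major (all patches of mic 0, then mic 1, and so on), and the two methods are exact inverses. A round-trip sketch with illustrative sizes (not part of the patch):

    import torch

    N, num_mics, L, p = 2, 3, 2048, 16
    signals = torch.randn(N, 1, num_mics, L)
    x = signals.squeeze(1).reshape(N, num_mics, L // p, p).flatten(1, 2)
    print(x.shape)  # torch.Size([2, 384, 16]); 384 = num_mics * L / p
    back = x.unflatten(1, (num_mics, L // p)).reshape(N, num_mics, L).unsqueeze(1)
    print(torch.equal(back, signals))  # True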