@@ -152,3 +152,128 @@ def build_viz_dataset(cfg):
152
152
norm_stats = cfg .norm_stats if 'norm_stats' in cfg else None
153
153
ds = SpectrogramDataset (folder = cfg .data_path , files = files , crop_frames = cfg .input_size [1 ], tfms = None , norm_stats = norm_stats )
154
154
return ds , files
155
+
156
+
157
+ # Mixed dataset
158
+
159
+ def log_mixup_exp (xa , xb , alpha ):
160
+ xa = xa .exp ()
161
+ xb = xb .exp ()
162
+ x = alpha * xa + (1. - alpha ) * xb
163
+ return torch .log (torch .max (x , torch .finfo (x .dtype ).eps * torch .ones_like (x )))
164
+
165
+
166
+ class MixedSpecDataset (torch .utils .data .Dataset ):
167
+ def __init__ (self , base_folder , files_main , files_bg_noise , crop_size , noise_ratio = 0.0 ,
168
+ random_crop = True , n_norm_calc = 10000 ) -> None :
169
+ super ().__init__ ()
170
+
171
+ self .ds1 = SpectrogramDataset (folder = base_folder , files = files_main , crop_frames = crop_size [1 ],
172
+ random_crop = random_crop , norm_stats = None ,
173
+ n_norm_calc = n_norm_calc // 2 )
174
+ self .norm_stats = self .ds1 .norm_stats # for compatibility with SpectrogramDataset
175
+ # disable normalizion scaling in the ds1
176
+ self .norm_std = self .ds1 .norm_stats [1 ]
177
+ self .ds1 .norm_stats = (self .ds1 .norm_stats [0 ], 1.0 )
178
+
179
+ if noise_ratio > 0.0 :
180
+ self .ds2 = SpectrogramDataset (folder = base_folder , files = files_bg_noise , crop_frames = crop_size [1 ],
181
+ random_crop = random_crop , norm_stats = None , n_norm_calc = n_norm_calc // 2 , repeat_short = True )
182
+ self .ds2 .norm_stats = (self .ds2 .norm_stats [0 ], 1.0 ) # disable normalizion scaling in the ds2
183
+
184
+ self .noise_ratio = noise_ratio
185
+ self .bg_index = []
186
+
187
+ def __len__ (self ):
188
+ return len (self .ds1 )
189
+
190
+ def __getitem__ (self , index , fixed_noise = False ):
191
+ # load index sample
192
+ clean = self .ds1 [index ]
193
+ if self .noise_ratio > 0.0 :
194
+ # load random noise sample ### , while making noise floor zero
195
+ noise = self .ds2 [index if fixed_noise else self .get_next_bgidx ()]
196
+ # mix
197
+ mixed = log_mixup_exp (noise , clean , self .noise_ratio ) if self .noise_ratio < 1.0 else noise
198
+ else :
199
+ mixed = clean .clone ()
200
+ # finish normalization. clean and noise were averaged to zero. the following will scale to 1.0 using ds1 std.
201
+ clean = clean / self .norm_std
202
+ mixed = mixed / self .norm_std
203
+ return clean , mixed
204
+
205
+
206
+ def get_next_bgidx (self ):
207
+ if len (self .bg_index ) == 0 :
208
+ self .bg_index = torch .randperm (len (self .ds2 )).tolist ()
209
+ # print(f'Refreshed the bg index list with {len(self.bg_index)} items: {self.bg_index[:5]}...')
210
+ return self .bg_index .pop (0 )
211
+
212
+ def __repr__ (self ):
213
+ format_string = self .__class__ .__name__ + f'(crop_frames={ self .ds1 .crop_frames } , '
214
+ format_string += f'folder_sp={ self .ds1 .df .file_name .values [0 ].split ("/" )[0 ]} , '
215
+ if self .noise_ratio > 0. : format_string += f'folder_bg={ self .ds2 .df .file_name .values [0 ].split ("/" )[0 ]} , '
216
+ return format_string
217
+
218
+
219
+ def inflate_files (files , desired_size ):
220
+ if len (files ) == 0 :
221
+ return files
222
+ files = list (files ) # make sure `files`` is a list
223
+ while len (files ) < desired_size :
224
+ files = (files + files )[:desired_size ]
225
+ return files
226
+
227
+
228
+ def build_mixed_dataset (cfg ):
229
+ """The followings configure the training dataset details.
230
+ - data_path: Root folder of the training dataset.
231
+ - dataset: The _name_ of the training dataset, an stem name of a `.csv` training data list.
232
+ - norm_stats: Normalization statistics, a list of [mean, std].
233
+ - input_size: Input size, a list of [# of freq. bins, # of time frames].
234
+ """
235
+
236
+ # get files and inflate the number of files (by repeating the list) if needed
237
+ files_main = get_files (cfg .csv_main )
238
+ files_bg = get_files (cfg .csv_bg_noise ) if cfg .noise_ratio > 0. else []
239
+ desired_min_size = 0
240
+ if 'min_ds_size' in cfg and cfg .min_ds_size > 0 :
241
+ desired_min_size = cfg .min_ds_size
242
+ if desired_min_size > 0 :
243
+ old_sizes = len (files_main ), len (files_bg )
244
+ files_main , files_bg = inflate_files (files_main , desired_min_size ), inflate_files (files_bg , desired_min_size )
245
+ print ('The numbers of data files are increased from' , old_sizes , 'to' , (len (files_main ), len (files_bg )))
246
+
247
+ ds = MixedSpecDataset (
248
+ base_folder = cfg .data_path , files_main = files_main ,
249
+ files_bg_noise = files_bg ,
250
+ crop_size = cfg .input_size ,
251
+ noise_ratio = cfg .noise_ratio ,
252
+ random_crop = True )
253
+ if 'weighted' in cfg and cfg .weighted :
254
+ assert desired_min_size == 0
255
+ ds .weight = pd .read_csv (cfg .csv_main ).weight .values
256
+
257
+ val_ds = SpectrogramDataset (folder = cfg .data_path , files = get_files (cfg .csv_val ), crop_frames = cfg .input_size [1 ], random_crop = True ) \
258
+ if cfg .csv_val else None
259
+
260
+ return ds , val_ds
261
+
262
+
263
+ def build_mixed_viz_dataset (cfg ):
264
+ files = [str (f ).replace (str (cfg .data_path ) + '/' , '' ) for f in sorted (Path (cfg .data_path ).glob ('vis_samples/*.npy' ))]
265
+ if len (files ) == 0 :
266
+ return None , []
267
+ norm_stats = cfg .norm_stats if 'norm_stats' in cfg else None
268
+ ds = SpectrogramDataset (folder = cfg .data_path , files = files , crop_frames = cfg .input_size [1 ], tfms = None , norm_stats = norm_stats )
269
+ return ds , files
270
+
271
+
272
+ if __name__ == '__main__' :
273
+ # Test
274
+ ds = MixedSpecDataset (base_folder = 'data' , files_main = get_files ('data/files_gtzan.csv' ),
275
+ files_bg_noise = get_files ('data/files_audioset.csv' ),
276
+ crop_size = [80 , 608 ], noise_ratio = 0.2 , random_crop = True , n_norm_calc = 10 )
277
+ for i in range (0 , 10 ):
278
+ clean , mixed = ds [i ]
279
+ print (clean .shape , mixed .shape )
0 commit comments