@@ -214,8 +214,9 @@ def forward(self, input_r, input_i):
214214 mean = torch .stack ((mean_r ,mean_i ),dim = 1 )
215215
216216 # update running mean
217- self .running_mean = exponential_average_factor * mean \
218- + (1 - exponential_average_factor ) * self .running_mean
217+ with torch .no_grad ():
218+ self .running_mean = exponential_average_factor * mean \
219+ + (1 - exponential_average_factor ) * self .running_mean
219220
220221 input_r = input_r - mean_r [None , :, None , None ]
221222 input_i = input_i - mean_i [None , :, None , None ]
@@ -226,14 +227,15 @@ def forward(self, input_r, input_i):
226227 Cii = 1. / n * input_i .pow (2 ).sum (dim = [0 ,2 ,3 ])+ self .eps
227228 Cri = (input_r .mul (input_i )).mean (dim = [0 ,2 ,3 ])
228229
229- self .running_covar [:,0 ] = exponential_average_factor * Crr * n / (n - 1 )\
230- + (1 - exponential_average_factor ) * self .running_covar [:,0 ]
230+ with torch .no_grad ():
231+ self .running_covar [:,0 ] = exponential_average_factor * Crr * n / (n - 1 )\
232+ + (1 - exponential_average_factor ) * self .running_covar [:,0 ]
231233
232- self .running_covar [:,1 ] = exponential_average_factor * Cii * n / (n - 1 )\
233- + (1 - exponential_average_factor ) * self .running_covar [:,1 ]
234+ self .running_covar [:,1 ] = exponential_average_factor * Cii * n / (n - 1 )\
235+ + (1 - exponential_average_factor ) * self .running_covar [:,1 ]
234236
235- self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
236- + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
237+ self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
238+ + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
237239
238240 else :
239241 mean = self .running_mean
@@ -291,8 +293,9 @@ def forward(self, input_r, input_i):
291293 mean = torch .stack ((mean_r ,mean_i ),dim = 1 )
292294
293295 # update running mean
294- self .running_mean = exponential_average_factor * mean \
295- + (1 - exponential_average_factor ) * self .running_mean
296+ with torch .no_grad ():
297+ self .running_mean = exponential_average_factor * mean \
298+ + (1 - exponential_average_factor ) * self .running_mean
296299
297300 # zero mean values
298301 input_r = input_r - mean_r [None , :]
@@ -305,14 +308,15 @@ def forward(self, input_r, input_i):
305308 Cii = input_i .var (dim = 0 ,unbiased = False )+ self .eps
306309 Cri = (input_r .mul (input_i )).mean (dim = 0 )
307310
308- self .running_covar [:,0 ] = exponential_average_factor * Crr * n / (n - 1 )\
309- + (1 - exponential_average_factor ) * self .running_covar [:,0 ]
311+ with torch .no_grad ():
312+ self .running_covar [:,0 ] = exponential_average_factor * Crr * n / (n - 1 )\
313+ + (1 - exponential_average_factor ) * self .running_covar [:,0 ]
310314
311- self .running_covar [:,1 ] = exponential_average_factor * Cii * n / (n - 1 )\
312- + (1 - exponential_average_factor ) * self .running_covar [:,1 ]
315+ self .running_covar [:,1 ] = exponential_average_factor * Cii * n / (n - 1 )\
316+ + (1 - exponential_average_factor ) * self .running_covar [:,1 ]
313317
314- self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
315- + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
318+ self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
319+ + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
316320
317321 else :
318322 mean = self .running_mean
0 commit comments