@@ -13,16 +13,16 @@ def complex_relu(input_r,input_i):
1313def complex_max_pool2d (input_r ,input_i ,kernel_size , stride = None , padding = 0 ,
1414 dilation = 1 , ceil_mode = False , return_indices = False ):
1515
16- return max_pool2d (input_r , kernel_size , stride , padding , dilation ,
16+ return max_pool2d (input_r , kernel_size , stride , padding , dilation ,
1717 ceil_mode , return_indices ), \
18- max_pool2d (input_i , kernel_size , stride , padding , dilation ,
18+ max_pool2d (input_i , kernel_size , stride , padding , dilation ,
1919 ceil_mode , return_indices )
2020
2121def complex_dropout (input_r ,input_i , p = 0.5 , training = True , inplace = False ):
22- return complex_dropout (input_r , p , training , inplace ), \
23- complex_dropout ( input_r , p , training , inplace )
22+ return dropout (input_r , p , training , inplace ), \
23+ dropout ( input_i , p , training , inplace )
2424
2525
2626def complex_dropout2d (input_r ,input_i , p = 0.5 , training = True , inplace = False ):
27- return complex_dropout2d (input_r , p , training , inplace ), \
28- complex_dropout2d ( input_r , p , training , inplace )
27+ return dropout2d (input_r , p , training , inplace ), \
28+ dropout2d ( input_i , p , training , inplace )
0 commit comments