|
12 | 12 | import torch |
13 | 13 | from torch.nn import Module, Parameter, init, Sequential |
14 | 14 | from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d |
| 15 | +from torch.nn import ConvTranspose2d |
15 | 16 | from complexFunctions import complex_relu, complex_max_pool2d |
16 | 17 | from complexFunctions import complex_dropout, complex_dropout2d |
17 | 18 |
|
@@ -62,6 +63,22 @@ class ComplexReLU(Module): |
62 | 63 | def forward(self,input_r,input_i): |
63 | 64 | return complex_relu(input_r,input_i) |
64 | 65 |
|
class ComplexConvTranspose2d(Module):
    """Transposed 2D convolution over complex inputs given as (real, imag) pairs.

    Holds two real-valued ``ConvTranspose2d`` layers acting as the real and
    imaginary parts of a complex weight W = Wr + i*Wi, and applies
    (Wr + i*Wi)(xr + i*xi) = (Wr*xr - Wi*xi) + i*(Wr*xi + Wi*xr).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros'):
        super(ComplexConvTranspose2d, self).__init__()
        # Both component convolutions share every hyperparameter; pass them
        # by keyword so the pairing with ConvTranspose2d's signature is explicit.
        shared_kwargs = dict(stride=stride, padding=padding,
                             output_padding=output_padding, groups=groups,
                             bias=bias, dilation=dilation,
                             padding_mode=padding_mode)
        self.conv_tran_r = ConvTranspose2d(in_channels, out_channels,
                                           kernel_size, **shared_kwargs)
        self.conv_tran_i = ConvTranspose2d(in_channels, out_channels,
                                           kernel_size, **shared_kwargs)

    def forward(self, input_r, input_i):
        """Return the (real, imag) pair of the complex transposed convolution."""
        out_real = self.conv_tran_r(input_r) - self.conv_tran_i(input_i)
        out_imag = self.conv_tran_r(input_i) + self.conv_tran_i(input_r)
        return out_real, out_imag
65 | 82 |
|
66 | 83 | class ComplexConv2d(Module): |
67 | 84 |
|
|
0 commit comments