Skip to content

Commit 431f2d6

Browse files
author
Sebastien Popoff
committed
add ConvTranpose2d
1 parent bfbdf8e commit 431f2d6

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

complexLayers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from torch.nn import Module, Parameter, init, Sequential
1414
from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d
15+
from torch.nn import ConvTranspose2d
1516
from complexFunctions import complex_relu, complex_max_pool2d
1617
from complexFunctions import complex_dropout, complex_dropout2d
1718

@@ -62,6 +63,22 @@ class ComplexReLU(Module):
6263
def forward(self,input_r,input_i):
6364
return complex_relu(input_r,input_i)
6465

66+
class ComplexConvTranspose2d(Module):
67+
68+
def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0,
69+
output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros'):
70+
71+
super(ComplexConvTranspose2d, self).__init__()
72+
73+
self.conv_tran_r = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
74+
output_padding, groups, bias, dilation, padding_mode)
75+
self.conv_tran_i = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
76+
output_padding, groups, bias, dilation, padding_mode)
77+
78+
79+
def forward(self,input_r,input_i):
80+
return self.conv_tran_r(input_r)-self.conv_tran_i(input_i), \
81+
self.conv_tran_r(input_i)+self.conv_tran_i(input_r)
6582

6683
class ComplexConv2d(Module):
6784

0 commit comments

Comments
 (0)