Thank you for providing this reproduction!
I have a question about the grouped convolution: in this line, you use grouped convolution to solve the mini-batch training problem.
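For context, here is my rough understanding of what that line does, as a minimal sketch with toy shapes (the names `batch`, `K`, `out_planes`, `in_planes` are my own assumptions, not necessarily the repo's exact variables):

```python
import torch
import torch.nn.functional as F

batch, K, out_planes, in_planes, k = 4, 4, 16, 8, 3
x = torch.randn(batch, in_planes, 32, 32)
weight = torch.randn(K, out_planes, in_planes, k, k)              # K candidate kernels
softmax_attention = torch.softmax(torch.randn(batch, K), dim=1)   # (batch, K), one mixture per sample

# Mix the K kernels into one kernel per sample ...
aggregate_weight = torch.mm(softmax_attention, weight.view(K, -1))
aggregate_weight = aggregate_weight.view(batch * out_planes, in_planes, k, k)

# ... then fold the batch into the channel dimension and run a single
# convolution with groups=batch, so each sample sees its own kernel.
out = F.conv2d(x.view(1, batch * in_planes, 32, 32), aggregate_weight, padding=1, groups=batch)
out = out.view(batch, out_planes, 32, 32)
```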
Could we use torch.Tensor.expand to replace the grouped convolution, like:

```python
weight_prime = weight.expand(K, weight.shape[0], weight.shape[1], weight.shape[2], weight.shape[3])
# torch.mm expects 2-D inputs, so the expanded weight is flattened to (K, -1) before the multiply
weight = torch.mm(softmax_attention, weight_prime.reshape(K, -1)).view(-1, x.shape[1], self.kernel_size, self.kernel_size)
```
In this way, we might aggregate the attention weights and the convolution weights together. However, this may cause another problem with the batch size: we might instead have to take torch.mean(attention_weight, dim=0) or torch.max(attention_weight, dim=0), since the attention weights are calculated within the batch, where their range is very close.
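To make that concrete, here is a minimal sketch of the mean/max alternative I have in mind (the toy shapes and the names `K`, `out_planes`, `in_planes` are my own assumptions):

```python
import torch
import torch.nn.functional as F

batch, K, out_planes, in_planes, k = 4, 4, 16, 8, 3
x = torch.randn(batch, in_planes, 32, 32)
weight = torch.randn(K, out_planes, in_planes, k, k)                   # K candidate kernels
attention_weight = torch.softmax(torch.randn(batch, K), dim=1)         # (batch, K)

# Collapse the per-sample attention to a single vector for the whole batch.
shared_attention = torch.mean(attention_weight, dim=0, keepdim=True)   # (1, K)
# Or: shared_attention = torch.max(attention_weight, dim=0, keepdim=True).values

# Aggregate one kernel shared by all samples; an ordinary convolution
# (no groups trick) can then be applied to the whole batch.
shared_weight = torch.mm(shared_attention, weight.view(K, -1)).view(out_planes, in_planes, k, k)
out = F.conv2d(x, shared_weight, padding=1)                            # (batch, out_planes, 32, 32)
```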
I am not sure whether this calculation is equivalent to that line :)