A Question in Batch Size and torch.Tensor.expand #23

@TerenceLiu98

Description

Thank you for providing this reproduction!

I have a question about the grouped convolution: in this line you use a grouped convolution to solve the mini-batch training problem.
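For context, the grouped-convolution trick folds the batch dimension into the channel dimension so that each sample is convolved with its own aggregated kernel in a single `F.conv2d` call. A minimal sketch (shapes `B`, `C_in`, `C_out`, etc. are hypothetical, chosen just for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C_in, C_out, H, W, k = 2, 3, 8, 5, 5, 3

x = torch.randn(B, C_in, H, W)
# One aggregated kernel per sample, as in dynamic convolution:
# sample i owns rows [i*C_out, (i+1)*C_out) of `weight`.
weight = torch.randn(B * C_out, C_in, k, k)

# Grouped-convolution trick: fold the batch into the channel dimension
# and set groups=B, so each sample sees only its own kernel.
out = F.conv2d(x.view(1, B * C_in, H, W), weight, groups=B, padding=1)
out = out.view(B, C_out, H, W)

# Reference: loop over the batch and convolve each sample separately.
ref = torch.cat([
    F.conv2d(x[i:i + 1], weight[i * C_out:(i + 1) * C_out], padding=1)
    for i in range(B)
])
print(torch.allclose(out, ref, atol=1e-5))
```

The grouped call and the per-sample loop produce the same result; the trick just does it in one kernel launch.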

Could we use `torch.Tensor.expand` to replace the grouped convolution, like:

```python
# Expand the weight to K copies so it can be mixed by the attention.
weight_prime = weight.expand(K, weight.shape[0], weight.shape[1], weight.shape[2], weight.shape[3])
# torch.mm needs 2-D inputs, and the expanded tensor is non-contiguous,
# so flatten each copy with reshape before mixing.
weight = torch.mm(softmax_attention, weight_prime.reshape(K, -1)).view(-1, x.shape[1], self.kernel_size, self.kernel_size)
```

In this way, we might aggregate the attention weights and the convolution weights together. However, this raises another problem: if the batch size ($\mathcal{B}$) is larger than 1, the attention weight becomes a $\mathcal{B} \times K$ matrix. I think we could reduce it with `torch.mean(attention_weight, dim=0)` or `torch.max(attention_weight, dim=0).values`, since the attention values are computed within the same batch and their ranges are very close.
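To make the trade-off concrete, here is a small sketch (all shapes hypothetical) comparing per-sample aggregation against mixing once with the batch-mean attention. By linearity of `torch.mm`, mixing with the mean attention equals the mean of the per-sample mixed kernels, but it is not equal to each sample's own kernel, so the reduction is an approximation rather than an exact equivalence:

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes: K experts, batch B, C_out/C_in channels, k x k kernels.
K, B, C_out, C_in, k = 4, 2, 8, 3, 3

weight = torch.randn(K, C_out, C_in, k, k)                 # per-expert kernels
softmax_attention = torch.softmax(torch.randn(B, K), dim=1)  # (B, K)

# Per-sample aggregation: each row of the attention mixes the K experts.
per_sample = torch.mm(softmax_attention, weight.view(K, -1))
per_sample = per_sample.view(B, C_out, C_in, k, k)

# Batch-mean reduction: collapse the attention to (1, K) first, mix once.
mean_attention = softmax_attention.mean(dim=0, keepdim=True)
shared = torch.mm(mean_attention, weight.view(K, -1)).view(C_out, C_in, k, k)

# Linearity: averaging the per-sample kernels recovers the shared kernel...
print(torch.allclose(per_sample.mean(dim=0), shared, atol=1e-6))
# ...but each individual sample's kernel generally differs from it.
print(torch.allclose(per_sample[0], shared, atol=1e-6))
```

So the mean reduction loses the per-sample specialization that the grouped-convolution path preserves; whether that matters depends on how spread out the attention rows are within a batch.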

I am not sure whether this calculation is equivalent to the original line :)
