diff --git a/dorn/modules/decoders/OrdinalRegression.py b/dorn/modules/decoders/OrdinalRegression.py index 85e0ae8..8702926 100644 --- a/dorn/modules/decoders/OrdinalRegression.py +++ b/dorn/modules/decoders/OrdinalRegression.py @@ -53,6 +53,6 @@ def forward(self, x): prob = F.log_softmax(x, dim=1).view(N, C, H, W) return prob - ord_prob = F.softmax(x, dim=1)[:, 0, :, :, :] + ord_prob = F.softmax(x, dim=1).view(N, C, H, W) ord_label = torch.sum((ord_prob > 0.5), dim=1) return ord_prob, ord_label