Skip to content

Commit 0031a42

Browse files
authored
Fix dynamic argmax/min when select_last_index is set (#827)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent b4a7461 commit 0031a42

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

onnx2trt_utils.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,21 +108,25 @@ NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::
108108
// We don't care about the TopK values, just the indices.
109109
nvinfer1::ITensor* indices = layer->getOutput(1);
110110
indices->setType(nvinfer1::DataType::kINT32);
111+
112+
// If selectLastIndex is true, the TopK operation was performed on reversed data on the provided axis.
113+
// Convert reversed indices back to forward indices by calculating the following:
114+
// indices = shape(tensor)[axis] - indices - 1
111115
if (selectLastIndex)
112116
{
113-
nvinfer1::Dims dims = tensor.getDimensions();
114-
int dimOnAxis = dims.d[axis];
115-
nvinfer1::Dims resultShape(dims);
116-
resultShape.d[axis] = 1;
117-
ShapedWeights shapeWeights = ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, resultShape);
118-
std::vector<int> tempData(shapeWeights.count(), dimOnAxis);
119-
std::memcpy(shapeWeights.values, tempData.data(), shapeWeights.count() * sizeof(int));
117+
// Use shapeTensor semantics to support dynamic shapes
118+
auto const dims = shapeOf(tensor);
119+
auto const indicesDims = shapeOf(*indices);
120+
auto const axisTensor = shapeVector(axis);
121+
auto const dimOnAxis = gather(ctx, dims, axisTensor);
122+
123+
// Create constant of shape indicesDims with values tensor.shape[axis]
124+
auto const tensorDimOnAxis = constantOfShape(ctx, node, &dimOnAxis.tensor(ctx), &indicesDims.tensor(ctx));
120125

121-
ShapedWeights weightOfOnes = ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, resultShape);
122-
std::vector<int> ones(shapeWeights.count(), 1);
123-
std::memcpy(weightOfOnes.values, ones.data(), weightOfOnes.count() * sizeof(int));
126+
// Create constant of shape indicesDims with values of 1
127+
auto const ones = constantOfShape(ctx, node, &shapeVector(1).tensor(ctx), &indicesDims.tensor(ctx));
124128

125-
std::vector<TensorOrWeights> newInputs{shapeWeights, indices, weightOfOnes};
129+
std::vector<TensorOrWeights> newInputs{tensorDimOnAxis, indices, ones};
126130
indices = &elementwiseHelper(ctx, node, newInputs, nvinfer1::ElementWiseOperation::kSUB).value().at(0).tensor();
127131
}
128132
if (keepdims)

0 commit comments

Comments
 (0)