@@ -108,21 +108,25 @@ NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::
108
108
// We don't care about the TopK values, just the indices.
109
109
nvinfer1::ITensor* indices = layer->getOutput (1 );
110
110
indices->setType (nvinfer1::DataType::kINT32 );
111
+
112
+ // If selectLastIndex is true, the TopK operation was performed on reversed data on the provided axis.
113
+ // Convert reversed indices back to forward indices by calculating the following:
114
+ // indices = shape(tensor)[axis] - indices - 1
111
115
if (selectLastIndex)
112
116
{
113
- nvinfer1::Dims dims = tensor.getDimensions ();
114
- int dimOnAxis = dims.d [axis];
115
- nvinfer1::Dims resultShape (dims);
116
- resultShape.d [axis] = 1 ;
117
- ShapedWeights shapeWeights = ctx->createTempWeights (::ONNX_NAMESPACE::TensorProto::INT32, resultShape);
118
- std::vector<int > tempData (shapeWeights.count (), dimOnAxis);
119
- std::memcpy (shapeWeights.values , tempData.data (), shapeWeights.count () * sizeof (int ));
117
+ // Use shapeTensor semantics to support dynamic shapes
118
+ auto const dims = shapeOf (tensor);
119
+ auto const indicesDims = shapeOf (*indices);
120
+ auto const axisTensor = shapeVector (axis);
121
+ auto const dimOnAxis = gather (ctx, dims, axisTensor);
122
+
123
+ // Create constant of shape indicesDims with values tensor.shape[axis]
124
+ auto const tensorDimOnAxis = constantOfShape (ctx, node, &dimOnAxis.tensor (ctx), &indicesDims.tensor (ctx));
120
125
121
- ShapedWeights weightOfOnes = ctx->createTempWeights (::ONNX_NAMESPACE::TensorProto::INT32, resultShape);
122
- std::vector<int > ones (shapeWeights.count (), 1 );
123
- std::memcpy (weightOfOnes.values , ones.data (), weightOfOnes.count () * sizeof (int ));
126
+ // Create constant of shape indicesDims with values of 1
127
+ auto const ones = constantOfShape (ctx, node, &shapeVector (1 ).tensor (ctx), &indicesDims.tensor (ctx));
124
128
125
- std::vector<TensorOrWeights> newInputs{shapeWeights , indices, weightOfOnes };
129
+ std::vector<TensorOrWeights> newInputs{tensorDimOnAxis , indices, ones };
126
130
indices = &elementwiseHelper (ctx, node, newInputs, nvinfer1::ElementWiseOperation::kSUB ).value ().at (0 ).tensor ();
127
131
}
128
132
if (keepdims)
0 commit comments