Commit 5f27e45

Fix GEMM importer (#828)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent (0031a42) · commit 5f27e45
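
Background: the ONNX Gemm operator handled by this importer computes Y = alpha * op(A) * op(B) + beta * C, where op() transposes its argument when the corresponding transA/transB attribute is set; alpha, beta, and both transpose attributes all appear in the diff below. A minimal numeric sketch of those semantics, independent of TensorRT (the 2x2 values are illustrative only):

#include <cstdio>

int main()
{
    // Y = alpha * op(A) * op(B) + beta * C for a 2x2 case with
    // transA = transB = 0, so op() is the identity on both inputs.
    const float alpha = 1.f, beta = 1.f;
    const float A[2][2] = {{1, 2}, {3, 4}};
    const float B[2][2] = {{5, 6}, {7, 8}};
    const float C[2] = {10, 20}; // 1-D bias, broadcast across rows of A*B
    float Y[2][2];
    for (int i = 0; i < 2; ++i)
    {
        for (int j = 0; j < 2; ++j)
        {
            float acc = 0.f;
            for (int k = 0; k < 2; ++k)
            {
                acc += A[i][k] * B[k][j];
            }
            Y[i][j] = alpha * acc + beta * C[j];
        }
    }
    // A*B = {{19, 22}, {43, 50}}, so Y = {{29, 42}, {53, 70}}.
    std::printf("%g %g\n%g %g\n", Y[0][0], Y[0][1], Y[1][0], Y[1][1]);
    return 0;
}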

File tree

2 files changed: +7, -114 lines

builtin_op_importers.cpp

Lines changed: 7 additions & 89 deletions
@@ -1523,87 +1523,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
     bool transA = attrs.get("transA", false);
     bool transB = attrs.get("transB", false);
     nvinfer1::ITensor& inputA = convertToTensor(inputs.at(0), ctx);
+    nvinfer1::ITensor& inputB = convertToTensor(inputs.at(1), ctx);
     // Validate inputs
-    ASSERT(inputs.at(0).shape().nbDims == 2 && inputs.at(1).shape().nbDims == 2 && "GEMM must have 2D inputs!", ErrorCode::kINVALID_NODE);
+    ASSERT(inputA.getDimensions().nbDims == 2 && inputB.getDimensions().nbDims == 2 && "GEMM must have 2D inputs!", ErrorCode::kINVALID_NODE);
     // TRT does not support INT32 input types for this node
-    ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32()
-        && "TensorRT doesn't support INT32 inputs for GEMM!", ErrorCode::kUNSUPPORTED_NODE);
-    // Use FC if it is likely to be faster - which is usually when no Shuffles are required.
-    bool canUseFC = inputs.at(0).is_tensor() && inputs.at(1).is_weights() && alpha == 1.f
-        && beta == 1.f && inputs.at(0).tensor().getDimensions().nbDims == 2 && inputs.at(1).weights().shape.nbDims == 2;
-    canUseFC &= inputs.size() < 3 || (inputs.at(2).is_weights() && inputs.at(2).weights().shape.nbDims == 1);
-    if (canUseFC)
-    {
-        LOG_VERBOSE("GEMM: using FC layer instead of MM because all criteria were met.");
-        const std::vector<int> axesInput{2, 3};
-        nvinfer1::ITensor* inputAExtendDim = unsqueezeTensor(ctx, node, inputA, axesInput);
-
-        ShapedWeights weights = inputs.at(1).weights();
-        if (!transB)
-        {
-            auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape);
-            ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights, ctx), ErrorCode::kUNSUPPORTED_NODE);
-            weights = transposedWeights;
-        }
-        ShapedWeights biases{};
-        if (inputs.size() > 2)
-        {
-            biases = inputs.at(2).weights();
-        }
-        nvinfer1::IFullyConnectedLayer* fc = ctx->network()->addFullyConnected(*inputAExtendDim, biases.shape.d[0], weights, biases);
-        // Register layer, along with refittable kernel weights and bias weights (if any)
-        ctx->registerLayer(fc, getNodeName(node));
-        ctx->network()->setWeightsName(weights, weights.getName());
-        if (inputs.size() == 3)
-        {
-            ctx->network()->setWeightsName(biases, inputs.at(2).weights().getName());
-        }
-        const std::vector<int> axesOutput{2, 3};
-        return {{squeezeTensor(ctx, node, *fc->getOutput(0), axesOutput)}};
-    }
-
-    nvinfer1::ITensor* inputB {nullptr};
-
-    // If input B is a constant, we transpose at parse time if necessary,
-    // because in some cases, A * Bt is much slower than A * B.
-    if (inputs.at(1).is_weights())
-    {
-        ShapedWeights weights = inputs.at(1).weights();
-        if (transB)
-        {
-            auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape);
-            ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights, ctx) && "Failed to transpose input tensor B.", ErrorCode::kUNSUPPORTED_NODE);
-            weights = transposedWeights;
-            // Since we've already transposed now, we can set transpose to false.
-            transB = false;
-        }
-        nvinfer1::IConstantLayer* weightsLayer
-            = ctx->network()->addConstant(weights.shape, static_cast<nvinfer1::Weights>(weights));
-        // Map the constant layer to the weights name.
-        ctx->registerLayer(weightsLayer, node.input(1));
-        ctx->network()->setWeightsName(weights, weights.getName());
-        inputB = weightsLayer->getOutput(0);
-    }
-    else
-    {
-        inputB = &inputs.at(1).tensor();
-    }
-
-    nvinfer1::ITensor* inputASqueezed = &inputA;
-    nvinfer1::Dims newDims = squeeze_trailing_dims(inputA.getDimensions());
-    // When A has more than 2 dimensions, it needs to be flattened.
-    if (newDims.nbDims > 2)
-    {
-        newDims = nvinfer1::Dims{1, {-1}};
-    }
-    // Due to other TRT layers, inputA may sometimes have trailing 1s that need to be removed.
-    if (newDims.nbDims < inputA.getDimensions().nbDims)
-    {
-        nvinfer1::IShuffleLayer* squeeze = ctx->network()->addShuffle(inputA);
-        squeeze->setReshapeDimensions(newDims);
-        squeeze->setZeroIsPlaceholder(false);
-        inputASqueezed = squeeze->getOutput(0);
-    }
+    ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32() && "TensorRT doesn't support INT32 inputs for GEMM!",
+        ErrorCode::kUNSUPPORTED_NODE);
 
     const auto getMatrixOp = [](const nvinfer1::ITensor& input, bool transpose) {
         if (input.getDimensions().nbDims == 1)
@@ -1617,13 +1542,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
         return nvinfer1::MatrixOperation::kNONE;
     };
 
-    nvinfer1::MatrixOperation opA = getMatrixOp(*inputASqueezed, transA);
-    nvinfer1::MatrixOperation opB = getMatrixOp(*inputB, transB);
+    nvinfer1::MatrixOperation opA = getMatrixOp(inputA, transA);
+    nvinfer1::MatrixOperation opB = getMatrixOp(inputB, transB);
 
     LOG_VERBOSE("Using opA: " << static_cast<int>(opA) << " opB: " << static_cast<int>(opB));
-    LOG_VERBOSE("GEMM: A, after squeezing: " << inputASqueezed->getDimensions());
 
-    nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(*inputASqueezed, opA, *inputB, opB);
+    nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network()->addMatrixMultiply(inputA, opA, inputB, opB);
     ctx->registerLayer(matmul, getNodeName(node));
     nvinfer1::ITensor* matmulTensor = matmul->getOutput(0);
 
@@ -1655,12 +1579,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
             *betaConstantTensor, *biasTensor, nvinfer1::ElementWiseOperation::kPROD);
         biasTensor = scaledBias->getOutput(0);
     }
-    // A*B may be lower rank than C in TRT, so need to squeeze C.
-    if (ctx->getOpsetVersion() < 7 && !attrs.get("broadcast", false))
-    {
-        nvinfer1::Dims squeezeDims = squeeze_leading_dims(biasTensor->getDimensions());
-        biasTensor = reshapeTensor(ctx, *biasTensor, squeezeDims);
-    }
     CHECK(broadcastTensors(ctx, matmulTensor, biasTensor));
     nvinfer1::IElementWiseLayer* biasAdd
         = ctx->network()->addElementWise(*matmulTensor, *biasTensor, nvinfer1::ElementWiseOperation::kSUM);
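
The hunks above show only the first line and the final return of the getMatrixOp lambda; its middle is elided by the diff. The sketch below assumes the conventional mapping onto nvinfer1::MatrixOperation (a 1-D input becomes kVECTOR, an input with its transpose attribute set becomes kTRANSPOSE, anything else kNONE), with the enum stubbed locally so it compiles without TensorRT headers:

#include <cassert>

// Stand-in for nvinfer1::MatrixOperation (assumption: local stub, not the real header).
enum class MatrixOperation
{
    kNONE,
    kTRANSPOSE,
    kVECTOR
};

// Assumed shape of the lambda: map input rank and transpose flag to a matrix op.
MatrixOperation getMatrixOp(int nbDims, bool transpose)
{
    if (nbDims == 1)
    {
        return MatrixOperation::kVECTOR; // 1-D inputs are multiplied as vectors
    }
    if (transpose)
    {
        return MatrixOperation::kTRANSPOSE; // fold transA/transB into the layer
    }
    return MatrixOperation::kNONE;
}

int main()
{
    // A Gemm node with transB = 1 no longer needs B transposed at parse time;
    // the transpose is expressed through opB instead.
    assert(getMatrixOp(2, true) == MatrixOperation::kTRANSPOSE);
    assert(getMatrixOp(2, false) == MatrixOperation::kNONE);
    assert(getMatrixOp(1, false) == MatrixOperation::kVECTOR);
    return 0;
}

Expressing the transposes through matrix ops is what makes the deleted branches unnecessary: the FC fast path, the parse-time weight transposes, and the shuffle that squeezed trailing 1s off A all existed only to feed addFullyConnected or addMatrixMultiply pre-shaped operands.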

trt_utils.hpp

Lines changed: 0 additions & 25 deletions
@@ -102,31 +102,6 @@ inline nvinfer1::Permutation remove_first_dim(nvinfer1::Permutation const& perm)
     return new_perm;
 }
 
-inline nvinfer1::Dims squeeze_trailing_dims(nvinfer1::Dims const& dims)
-{
-    nvinfer1::Dims new_dims = dims;
-    // Note: TRT requires at least one dimension, so we don't squeeze [1]->[]
-    while (new_dims.nbDims > 1 && new_dims.d[new_dims.nbDims - 1] == 1)
-    {
-        --new_dims.nbDims;
-    }
-    return new_dims;
-}
-
-inline nvinfer1::Dims squeeze_leading_dims(const nvinfer1::Dims& dims)
-{
-    nvinfer1::Dims newDims;
-    // Copy dims only if a non-1 has been seen already.
-    bool non1Seen{false};
-    newDims.nbDims = std::copy_if(dims.d, dims.d + dims.nbDims, newDims.d,
-                         [&non1Seen](int x) {
-                             non1Seen = (x != 1) ? true : non1Seen;
-                             return non1Seen;
-                         })
-        - newDims.d;
-    return newDims;
-}
-
 inline nvinfer1::DimsHW operator-(nvinfer1::DimsHW dims)
 {
     return nvinfer1::DimsHW(-dims.h(), -dims.w());
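
With the importer no longer squeezing its operands, both helpers deleted above lose their last caller. For reference, a standalone sketch of what they computed, using std::vector<int> in place of nvinfer1::Dims so it runs without TensorRT:

#include <iostream>
#include <vector>

// squeeze_trailing_dims: drop trailing 1s, keeping at least one dimension
// (TRT requires nbDims >= 1, so {1} is not squeezed to {}).
std::vector<int> squeezeTrailing(std::vector<int> dims)
{
    while (dims.size() > 1 && dims.back() == 1)
    {
        dims.pop_back();
    }
    return dims;
}

// squeeze_leading_dims: copy dimensions only once a non-1 has been seen,
// mirroring the copy_if + non1Seen flag in the deleted code.
std::vector<int> squeezeLeading(const std::vector<int>& dims)
{
    std::vector<int> out;
    bool non1Seen = false;
    for (int d : dims)
    {
        non1Seen = non1Seen || (d != 1);
        if (non1Seen)
        {
            out.push_back(d);
        }
    }
    return out;
}

int main()
{
    // {2, 3, 1, 1} -> {2, 3}
    for (int d : squeezeTrailing({2, 3, 1, 1})) std::cout << d << ' ';
    std::cout << '\n';
    // {1, 1, 2, 3} -> {2, 3}
    for (int d : squeezeLeading({1, 1, 2, 3})) std::cout << d << ' ';
    std::cout << '\n';
    return 0;
}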
