@@ -1523,87 +1523,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
1523
1523
bool transA = attrs.get (" transA" , false );
1524
1524
bool transB = attrs.get (" transB" , false );
1525
1525
nvinfer1::ITensor& inputA = convertToTensor (inputs.at (0 ), ctx);
1526
+ nvinfer1::ITensor& inputB = convertToTensor (inputs.at (1 ), ctx);
1526
1527
// Validate inputs
1527
- ASSERT (inputs. at ( 0 ). shape (). nbDims == 2 && inputs. at ( 1 ). shape ().nbDims == 2 && " GEMM must have 2D inputs!" , ErrorCode::kINVALID_NODE );
1528
+ ASSERT (inputA. getDimensions (). nbDims == 2 && inputB. getDimensions ().nbDims == 2 && " GEMM must have 2D inputs!" , ErrorCode::kINVALID_NODE );
1528
1529
// TRT does not support INT32 input types for this node
1529
- ASSERT (!inputs.at (0 ).isInt32 () && !inputs.at (1 ).isInt32 ()
1530
- && " TensorRT doesn't support INT32 inputs for GEMM!" , ErrorCode::kUNSUPPORTED_NODE );
1531
- // Use FC if it is likely to be faster - which is usually when no Shuffles are required.
1532
- bool canUseFC = inputs.at (0 ).is_tensor () && inputs.at (1 ).is_weights () && alpha == 1 .f
1533
- && beta == 1 .f && inputs.at (0 ).tensor ().getDimensions ().nbDims == 2 && inputs.at (1 ).weights ().shape .nbDims == 2 ;
1534
- canUseFC &= inputs.size () < 3 || (inputs.at (2 ).is_weights () && inputs.at (2 ).weights ().shape .nbDims == 1 );
1535
- if (canUseFC)
1536
- {
1537
- LOG_VERBOSE (" GEMM: using FC layer instead of MM because all criteria were met." );
1538
- const std::vector<int > axesInput{2 , 3 };
1539
- nvinfer1::ITensor* inputAExtendDim = unsqueezeTensor (ctx, node, inputA, axesInput);
1540
-
1541
- ShapedWeights weights = inputs.at (1 ).weights ();
1542
- if (!transB)
1543
- {
1544
- auto transposedWeights = ctx->createTempWeights (weights.type , weights.shape );
1545
- ASSERT (transposeWeights (weights, {1 , 0 }, &transposedWeights, ctx), ErrorCode::kUNSUPPORTED_NODE );
1546
- weights = transposedWeights;
1547
- }
1548
- ShapedWeights biases{};
1549
- if (inputs.size () > 2 )
1550
- {
1551
- biases = inputs.at (2 ).weights ();
1552
- }
1553
- nvinfer1::IFullyConnectedLayer* fc = ctx->network ()->addFullyConnected (*inputAExtendDim, biases.shape .d [0 ], weights, biases);
1554
- // Register layer, along with refittable kernel weights and bias weights (if any)
1555
- ctx->registerLayer (fc, getNodeName (node));
1556
- ctx->network ()->setWeightsName (weights, weights.getName ());
1557
- if (inputs.size () == 3 )
1558
- {
1559
- ctx->network ()->setWeightsName (biases, inputs.at (2 ).weights ().getName ());
1560
- }
1561
- const std::vector<int > axesOutput{2 , 3 };
1562
- return {{squeezeTensor (ctx, node, *fc->getOutput (0 ), axesOutput)}};
1563
- }
1564
-
1565
- nvinfer1::ITensor* inputB {nullptr };
1566
-
1567
- // If input B is a constant, we transpose at parse time if necessary,
1568
- // because In some cases, A * Bt is much slower than A * B.
1569
- if (inputs.at (1 ).is_weights ())
1570
- {
1571
- ShapedWeights weights = inputs.at (1 ).weights ();
1572
- if (transB)
1573
- {
1574
- auto transposedWeights = ctx->createTempWeights (weights.type , weights.shape );
1575
- ASSERT (transposeWeights (weights, {1 , 0 }, &transposedWeights, ctx) && " Failed to transpose input tensor B." , ErrorCode::kUNSUPPORTED_NODE );
1576
- weights = transposedWeights;
1577
- // Since we've already transposed now, we can set transpose to false.
1578
- transB = false ;
1579
- }
1580
- nvinfer1::IConstantLayer* weightsLayer
1581
- = ctx->network ()->addConstant (weights.shape , static_cast <nvinfer1::Weights>(weights));
1582
- // Map the constant layer to the weights name.
1583
- ctx->registerLayer (weightsLayer, node.input (1 ));
1584
- ctx->network ()->setWeightsName (weights, weights.getName ());
1585
- inputB = weightsLayer->getOutput (0 );
1586
- }
1587
- else
1588
- {
1589
- inputB = &inputs.at (1 ).tensor ();
1590
- }
1591
-
1592
- nvinfer1::ITensor* inputASqueezed = &inputA;
1593
- nvinfer1::Dims newDims = squeeze_trailing_dims (inputA.getDimensions ());
1594
- // When A has more than 2 dimensions, it needs to be flattened.
1595
- if (newDims.nbDims > 2 )
1596
- {
1597
- newDims = nvinfer1::Dims{1 , {-1 }};
1598
- }
1599
- // Due to other TRT layers, inputA may sometimes have trailing 1s that need to be removed.
1600
- if (newDims.nbDims < inputA.getDimensions ().nbDims )
1601
- {
1602
- nvinfer1::IShuffleLayer* squeeze = ctx->network ()->addShuffle (inputA);
1603
- squeeze->setReshapeDimensions (newDims);
1604
- squeeze->setZeroIsPlaceholder (false );
1605
- inputASqueezed = squeeze->getOutput (0 );
1606
- }
1530
+ ASSERT (!inputs.at (0 ).isInt32 () && !inputs.at (1 ).isInt32 () && " TensorRT doesn't support INT32 inputs for GEMM!" ,
1531
+ ErrorCode::kUNSUPPORTED_NODE );
1607
1532
1608
1533
const auto getMatrixOp = [](const nvinfer1::ITensor& input, bool transpose) {
1609
1534
if (input.getDimensions ().nbDims == 1 )
@@ -1617,13 +1542,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
1617
1542
return nvinfer1::MatrixOperation::kNONE ;
1618
1543
};
1619
1544
1620
- nvinfer1::MatrixOperation opA = getMatrixOp (*inputASqueezed , transA);
1621
- nvinfer1::MatrixOperation opB = getMatrixOp (* inputB, transB);
1545
+ nvinfer1::MatrixOperation opA = getMatrixOp (inputA , transA);
1546
+ nvinfer1::MatrixOperation opB = getMatrixOp (inputB, transB);
1622
1547
1623
1548
LOG_VERBOSE (" Using opA: " << static_cast <int >(opA) << " opB: " << static_cast <int >(opB));
1624
- LOG_VERBOSE (" GEMM: A, after squeezing: " << inputASqueezed->getDimensions ());
1625
1549
1626
- nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network ()->addMatrixMultiply (*inputASqueezed , opA, * inputB, opB);
1550
+ nvinfer1::IMatrixMultiplyLayer* matmul = ctx->network ()->addMatrixMultiply (inputA , opA, inputB, opB);
1627
1551
ctx->registerLayer (matmul, getNodeName (node));
1628
1552
nvinfer1::ITensor* matmulTensor = matmul->getOutput (0 );
1629
1553
@@ -1655,12 +1579,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
1655
1579
*betaConstantTensor, *biasTensor, nvinfer1::ElementWiseOperation::kPROD );
1656
1580
biasTensor = scaledBias->getOutput (0 );
1657
1581
}
1658
- // A*B may be lower rank than C in TRT, so need to squeeze C.
1659
- if (ctx->getOpsetVersion () < 7 && !attrs.get (" broadcast" , false ))
1660
- {
1661
- nvinfer1::Dims squeezeDims = squeeze_leading_dims (biasTensor->getDimensions ());
1662
- biasTensor = reshapeTensor (ctx, *biasTensor, squeezeDims);
1663
- }
1664
1582
CHECK (broadcastTensors (ctx, matmulTensor, biasTensor));
1665
1583
nvinfer1::IElementWiseLayer* biasAdd
1666
1584
= ctx->network ()->addElementWise (*matmulTensor, *biasTensor, nvinfer1::ElementWiseOperation::kSUM );
0 commit comments