Skip to content

Commit 317a3de

Browse files
[weight_only_linear]decrease code (#72842)
* remove not used code * remove not used code * remove not used code * commit * commit
1 parent 34e8aac commit 317a3de

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,29 +67,20 @@ struct LayoutDetailsB<TypeB, arch::Sm70> {
6767
using Operator = cutlass::arch::OpMultiplyAdd;
6868
};
6969

70-
// Specializations for Turing+ when B is FP16. These are currently only used for
71-
// MoE networks.
72-
template <typename Arch>
73-
struct LayoutDetailsB<
74-
half_t,
75-
Arch,
76-
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
77-
static constexpr int ThreadblockK = 64;
78-
using Layout = layout::RowMajor;
79-
static constexpr int ElementsPerAccess =
80-
128 / cutlass::sizeof_bits<half_t>::value;
81-
using Operator = cutlass::arch::OpMultiplyAdd;
82-
};
70+
// Specializations for Turing+ when B is 16 bit. These are currently only used
71+
// for MoE networks.
8372

84-
template <typename Arch>
73+
template <typename TypeB, typename Arch>
8574
struct LayoutDetailsB<
86-
bfloat16_t,
75+
TypeB,
8776
Arch,
88-
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
77+
typename platform::enable_if<
78+
Arch::kMinComputeCapability >= 75 &&
79+
(platform::is_same<TypeB, half_t>::value ||
80+
platform::is_same<TypeB, bfloat16_t>::value)>::type> {
8981
static constexpr int ThreadblockK = 64;
9082
using Layout = layout::RowMajor;
91-
static constexpr int ElementsPerAccess =
92-
128 / cutlass::sizeof_bits<bfloat16_t>::value;
83+
static constexpr int ElementsPerAccess = 8;
9384
using Operator = cutlass::arch::OpMultiplyAdd;
9485
};
9586

0 commit comments

Comments
 (0)