@@ -67,29 +67,20 @@ struct LayoutDetailsB<TypeB, arch::Sm70> {
67
67
using Operator = cutlass::arch::OpMultiplyAdd;
68
68
};
69
69
70
- // Specializations for Turing+ when B is FP16. These are currently only used for
71
- // MoE networks.
72
- template <typename Arch>
73
- struct LayoutDetailsB <
74
- half_t ,
75
- Arch,
76
- typename platform::enable_if<Arch::kMinComputeCapability >= 75 >::type> {
77
- static constexpr int ThreadblockK = 64 ;
78
- using Layout = layout::RowMajor;
79
- static constexpr int ElementsPerAccess =
80
- 128 / cutlass::sizeof_bits<half_t >::value;
81
- using Operator = cutlass::arch::OpMultiplyAdd;
82
- };
70
+ // Specialization for Turing+ (SM75 and newer) when B is a 16-bit type
71
+ // (half_t or bfloat16_t). Currently only used for MoE networks.
83
72
84
- template <typename Arch>
73
+ template <typename TypeB, typename Arch>
85
74
struct LayoutDetailsB <
86
- bfloat16_t ,
75
+ TypeB ,
87
76
Arch,
88
- typename platform::enable_if<Arch::kMinComputeCapability >= 75 >::type> {
77
+ typename platform::enable_if<
78
+ Arch::kMinComputeCapability >= 75 &&
79
+ (platform::is_same<TypeB, half_t >::value ||
80
+ platform::is_same<TypeB, bfloat16_t >::value)>::type> {
89
81
static constexpr int ThreadblockK = 64 ;
90
82
using Layout = layout::RowMajor;
91
- static constexpr int ElementsPerAccess =
92
- 128 / cutlass::sizeof_bits<bfloat16_t >::value;
83
+ static constexpr int ElementsPerAccess = 8 ;
93
84
using Operator = cutlass::arch::OpMultiplyAdd;
94
85
};
95
86
0 commit comments