20
20
21
21
#include " zero.cc"
22
22
23
- template <typename T_in, typename T_out, int rowA, int colA, int colB>
23
+ template <typename T_in, typename T_out, int rowA, int colA, int colB,
24
+ bool b_row_maj = true , bool c_row_maj = true >
24
25
static inline void matmul_scalar (T_in *a, T_in *b, T_out *c) {
25
26
event0 ();
26
27
for (int row = 0 ; row < rowA; row++) {
27
28
for (int col = 0 ; col < colB; col++) {
28
29
T_out running_sum = 0 ;
29
30
for (int i = 0 ; i < colA; i++) {
30
- running_sum += a[row * colA + i] * b[i * colB + col];
31
+ T_in a_val = a[row * colA + i];
32
+ T_in b_val;
33
+ if constexpr (b_row_maj) {
34
+ b_val = b[i * colB + col];
35
+ } else {
36
+ b_val = b[i + col * colA];
37
+ }
38
+ running_sum += a_val * b_val;
39
+ }
40
+ T_out *c_ptr;
41
+ if constexpr (c_row_maj) {
42
+ c_ptr = &c[row * colB + col];
43
+ } else {
44
+ c_ptr = &c[row + col * rowA];
31
45
}
32
- c[row * colB + col] += running_sum;
46
+ *c_ptr += running_sum;
33
47
}
34
48
}
35
49
event1 ();
@@ -65,7 +79,7 @@ static inline void matmul_scalar(T_in *a, T_in *b, T_out *c) {
65
79
*/
66
80
template <typename T_in, typename T_out, unsigned rowA, unsigned colA,
67
81
unsigned colB, unsigned r, unsigned s, unsigned t,
68
- bool b_row_maj = true >
82
+ bool b_row_maj = true , bool c_row_maj = true >
69
83
static inline void matmul_vectorized_2x2_mmul (const T_in *__restrict pA,
70
84
const T_in *__restrict pB,
71
85
T_out *__restrict pC) {
@@ -76,14 +90,24 @@ static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
76
90
77
91
for (unsigned z = 0 ; z < rowA; z += 2 )
78
92
chess_prepare_for_pipelining chess_loop_range (4 , ) {
79
- T_out *__restrict pC1 = pC + (z * colB) * MMUL::size_C;
80
- T_out *__restrict pC2 = pC + ((z + 1 ) * colB) * MMUL::size_C;
93
+
94
+ T_out *__restrict pC1;
95
+ T_out *__restrict pC2;
96
+ if constexpr (c_row_maj) {
97
+ pC1 = pC + (z * colB) * MMUL::size_C;
98
+ pC2 = pC + ((z + 1 ) * colB) * MMUL::size_C;
99
+ }
81
100
82
101
for (unsigned j = 0 ; j < colB; j += 2 )
83
102
#ifdef OPT_PERF_ENABLED
84
103
chess_flatten_loop
85
104
#endif
86
105
{
106
+
107
+ if constexpr (!c_row_maj) {
108
+ pC1 = pC + j * rowA * MMUL::size_C + z * MMUL::size_C;
109
+ pC2 = pC + (j + 1 ) * rowA * MMUL::size_C + z * MMUL::size_C;
110
+ }
87
111
const T_in *__restrict pA1 = pA + (z * colA) * MMUL::size_A;
88
112
const T_in *__restrict pA2 = pA + ((z + 1 ) * colA) * MMUL::size_A;
89
113
const T_in *__restrict pB1;
@@ -95,6 +119,7 @@ static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
95
119
pB1 = pB + (j * colA) * MMUL::size_B;
96
120
pB2 = pB + ((j + 1 ) * colA) * MMUL::size_B;
97
121
}
122
+
98
123
aie::vector<T_in, MMUL::size_A> A0;
99
124
aie::vector<T_in, MMUL::size_A> A1;
100
125
aie::vector<T_in, MMUL::size_B> B0;
@@ -103,14 +128,23 @@ static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
103
128
// Load partial results from C buffer for accumulation in-place. The
104
129
// zero.cc function handles the zeroing of data when a new
105
130
// accumulation is needed (after the 'K' reduction dimension)
106
- aie::vector<T_out, MMUL::size_C> acc_C00 =
107
- aie::load_v<MMUL::size_C>(pC1);
108
- aie::vector<T_out, MMUL::size_C> acc_C01 =
109
- aie::load_v<MMUL::size_C>(pC1 + MMUL::size_C);
110
- aie::vector<T_out, MMUL::size_C> acc_C10 =
111
- aie::load_v<MMUL::size_C>(pC2);
112
- aie::vector<T_out, MMUL::size_C> acc_C11 =
113
- aie::load_v<MMUL::size_C>(pC2 + MMUL::size_C);
131
+ aie::vector<T_out, MMUL::size_C> acc_C00;
132
+ aie::vector<T_out, MMUL::size_C> acc_C01;
133
+ aie::vector<T_out, MMUL::size_C> acc_C10;
134
+ aie::vector<T_out, MMUL::size_C> acc_C11;
135
+ if constexpr (c_row_maj) {
136
+ acc_C00 = aie::load_v<MMUL::size_C>(pC1);
137
+ acc_C01 = aie::load_v<MMUL::size_C>(pC1 + MMUL::size_C);
138
+ acc_C10 = aie::load_v<MMUL::size_C>(pC2);
139
+ acc_C11 = aie::load_v<MMUL::size_C>(pC2 + MMUL::size_C);
140
+ } else {
141
+ acc_C00 = aie::transpose (aie::load_v<MMUL::size_C>(pC1), t, r);
142
+ acc_C01 = aie::transpose (aie::load_v<MMUL::size_C>(pC2), t, r);
143
+ acc_C10 = aie::transpose (
144
+ aie::load_v<MMUL::size_C>(pC1 + MMUL::size_C), t, r);
145
+ acc_C11 = aie::transpose (
146
+ aie::load_v<MMUL::size_C>(pC2 + MMUL::size_C), t, r);
147
+ }
114
148
115
149
MMUL C00 (acc_C00);
116
150
MMUL C01 (acc_C01);
@@ -149,14 +183,30 @@ static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
149
183
// example below shows how to shift right 10 bits
150
184
// #define SHIFT 10
151
185
// aie::store_v(pC1, C00.template to_vector<T_out>(SHIFT));
152
- aie::store_v (pC1, C00.template to_vector <T_out>());
153
- pC1 += MMUL::size_C;
154
- aie::store_v (pC1, C01.template to_vector <T_out>());
155
- pC1 += MMUL::size_C;
156
- aie::store_v (pC2, C10.template to_vector <T_out>());
157
- pC2 += MMUL::size_C;
158
- aie::store_v (pC2, C11.template to_vector <T_out>());
159
- pC2 += MMUL::size_C;
186
+
187
+ if constexpr (c_row_maj) {
188
+ aie::store_v (pC1, C00.template to_vector <T_out>());
189
+ pC1 += MMUL::size_C;
190
+ aie::store_v (pC1, C01.template to_vector <T_out>());
191
+ pC1 += MMUL::size_C;
192
+ aie::store_v (pC2, C10.template to_vector <T_out>());
193
+ pC2 += MMUL::size_C;
194
+ aie::store_v (pC2, C11.template to_vector <T_out>());
195
+ pC2 += MMUL::size_C;
196
+ } else {
197
+ aie::store_v (pC1,
198
+ aie::transpose (C00.template to_vector <T_out>(), r, t));
199
+ pC1 += MMUL::size_C;
200
+ aie::store_v (pC2,
201
+ aie::transpose (C01.template to_vector <T_out>(), r, t));
202
+ pC2 += MMUL::size_C;
203
+ aie::store_v (pC1,
204
+ aie::transpose (C10.template to_vector <T_out>(), r, t));
205
+ pC1 += MMUL::size_C;
206
+ aie::store_v (pC2,
207
+ aie::transpose (C11.template to_vector <T_out>(), r, t));
208
+ pC2 += MMUL::size_C;
209
+ }
160
210
}
161
211
}
162
212
@@ -169,6 +219,12 @@ constexpr bool is_b_row_maj = false;
169
219
constexpr bool is_b_row_maj = true ;
170
220
#endif
171
221
222
// Compile-time selection of the C (output) buffer layout: building with
// -DC_COL_MAJ switches the output matrix to column-major; the default is
// row-major. Mirrors the existing B_COL_MAJ / is_b_row_maj selector above.
#ifdef C_COL_MAJ
constexpr bool is_c_row_maj = false;
#else
constexpr bool is_c_row_maj = true;
#endif
227
+
172
228
// The following kernel definitions use mmul shapes that have been found to be
173
229
// optimal for AIE2P in combination with the 2x2 mmul expanded kernel.
174
230
//
@@ -195,7 +251,8 @@ static inline void matmul_vectorized_4x4x8_i16_i16(const int16 *__restrict pA,
195
251
static_assert (n % (2 * t) == 0 );
196
252
197
253
return matmul_vectorized_2x2_mmul<int16, int16, (m / r), (k / s), (n / t), r,
198
- s, t, is_b_row_maj>(pA, pB, pC);
254
+ s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
255
+ pC);
199
256
}
200
257
201
258
template <unsigned m, unsigned k, unsigned n>
@@ -211,7 +268,8 @@ static inline void matmul_vectorized_4x4x8_i16_i32(const int16 *__restrict pA,
211
268
static_assert (n % (2 * t) == 0 );
212
269
213
270
return matmul_vectorized_2x2_mmul<int16, int32, (m / r), (k / s), (n / t), r,
214
- s, t, is_b_row_maj>(pA, pB, pC);
271
+ s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
272
+ pC);
215
273
}
216
274
217
275
template <unsigned m, unsigned k, unsigned n>
@@ -228,7 +286,8 @@ matmul_vectorized_4x8x8_bf16_bf16(const bfloat16 *__restrict pA,
228
286
static_assert (n % (2 * t) == 0 );
229
287
230
288
return matmul_vectorized_2x2_mmul<bfloat16, bfloat16, (m / r), (k / s),
231
- (n / t), r, s, t, is_b_row_maj>(pA, pB, pC);
289
+ (n / t), r, s, t, is_b_row_maj,
290
+ is_c_row_maj>(pA, pB, pC);
232
291
}
233
292
234
293
// Note that this shape is only possible for bf16 when using bfp16 emulation
@@ -247,7 +306,8 @@ matmul_vectorized_8x8x8_bf16_bf16(const bfloat16 *__restrict pA,
247
306
static_assert (n % (2 * t) == 0 );
248
307
249
308
return matmul_vectorized_2x2_mmul<bfloat16, bfloat16, (m / r), (k / s),
250
- (n / t), r, s, t, is_b_row_maj>(pA, pB, pC);
309
+ (n / t), r, s, t, is_b_row_maj,
310
+ is_c_row_maj>(pA, pB, pC);
251
311
}
252
312
253
313
template <unsigned m, unsigned k, unsigned n>
@@ -264,7 +324,8 @@ matmul_vectorized_4x8x8_bf16_f32(const bfloat16 *__restrict pA,
264
324
static_assert (n % (2 * t) == 0 );
265
325
266
326
return matmul_vectorized_2x2_mmul<bfloat16, float , (m / r), (k / s), (n / t),
267
- r, s, t, is_b_row_maj>(pA, pB, pC);
327
+ r, s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
328
+ pC);
268
329
}
269
330
270
331
template <unsigned m, unsigned k, unsigned n>
@@ -281,7 +342,7 @@ matmul_vectorized_8x8x8_bf16_f32(const bfloat16 *__restrict pA,
281
342
static_assert (n % (2 * t) == 0 );
282
343
283
344
return matmul_vectorized_2x2_mmul<bfloat16, float , (m / r), (k / s), (n / t),
284
- r, s, t, is_b_row_maj>(pA, pB, pC);
345
+ r, s, t, is_b_row_maj, is_c_row_maj >(pA, pB, pC);
285
346
}
286
347
287
348
template <unsigned m, unsigned k, unsigned n>
@@ -297,7 +358,7 @@ static inline void matmul_vectorized_8x8x8_i8_i8(const int8 *__restrict pA,
297
358
static_assert (n % (2 * t) == 0 );
298
359
299
360
return matmul_vectorized_2x2_mmul<int8, int8, (m / r), (k / s), (n / t), r, s,
300
- t, is_b_row_maj>(pA, pB, pC);
361
+ t, is_b_row_maj, is_c_row_maj >(pA, pB, pC);
301
362
}
302
363
303
364
template <unsigned m, unsigned k, unsigned n>
@@ -313,7 +374,8 @@ static inline void matmul_vectorized_8x8x8_i8_i16(const int8 *__restrict pA,
313
374
static_assert (n % (2 * t) == 0 );
314
375
315
376
return matmul_vectorized_2x2_mmul<int8, int16, (m / r), (k / s), (n / t), r,
316
- s, t, is_b_row_maj>(pA, pB, pC);
377
+ s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
378
+ pC);
317
379
}
318
380
319
381
template <unsigned m, unsigned k, unsigned n>
@@ -329,7 +391,8 @@ static inline void matmul_vectorized_8x8x8_i8_i32(const int8 *__restrict pA,
329
391
static_assert (n % (2 * t) == 0 );
330
392
331
393
return matmul_vectorized_2x2_mmul<int8, int32, (m / r), (k / s), (n / t), r,
332
- s, t, is_b_row_maj>(pA, pB, pC);
394
+ s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
395
+ pC);
333
396
}
334
397
335
398
extern " C" {
@@ -418,7 +481,7 @@ extern "C" {
418
481
r, s, t) \
419
482
void matmul_scalar_##mlir_type_in##_##mlir_type_out( \
420
483
ctype_in *a_in, ctype_in *b_in, ctype_out *c_out) { \
421
- matmul_scalar<ctype_in, ctype_out, DIM_M, DIM_K, DIM_N>(a_in, b_in, \
484
+ matmul_scalar<ctype_in, ctype_out, DIM_M, DIM_K, DIM_N, is_b_row_maj, is_c_row_maj >(a_in, b_in, \
422
485
c_out); \
423
486
}
424
487
0 commit comments