
Commit 366fd0e

[GEMM] Add option for column-major C output (#2387)
1 parent fd2c137 commit 366fd0e

16 files changed: +590 -208 lines changed

aie_kernels/aie2/mm.cc

Lines changed: 363 additions & 144 deletions
Large diffs are not rendered by default.

aie_kernels/aie2p/mm.cc

Lines changed: 95 additions & 32 deletions
@@ -20,16 +20,30 @@
 
 #include "zero.cc"
 
-template <typename T_in, typename T_out, int rowA, int colA, int colB>
+template <typename T_in, typename T_out, int rowA, int colA, int colB,
+          bool b_row_maj = true, bool c_row_maj = true>
 static inline void matmul_scalar(T_in *a, T_in *b, T_out *c) {
   event0();
   for (int row = 0; row < rowA; row++) {
     for (int col = 0; col < colB; col++) {
       T_out running_sum = 0;
       for (int i = 0; i < colA; i++) {
-        running_sum += a[row * colA + i] * b[i * colB + col];
+        T_in a_val = a[row * colA + i];
+        T_in b_val;
+        if constexpr (b_row_maj) {
+          b_val = b[i * colB + col];
+        } else {
+          b_val = b[i + col * colA];
+        }
+        running_sum += a_val * b_val;
+      }
+      T_out *c_ptr;
+      if constexpr (c_row_maj) {
+        c_ptr = &c[row * colB + col];
+      } else {
+        c_ptr = &c[row + col * rowA];
       }
-      c[row * colB + col] += running_sum;
+      *c_ptr += running_sum;
     }
   }
   event1();
@@ -65,7 +79,7 @@ static inline void matmul_scalar(T_in *a, T_in *b, T_out *c) {
  */
 template <typename T_in, typename T_out, unsigned rowA, unsigned colA,
           unsigned colB, unsigned r, unsigned s, unsigned t,
-          bool b_row_maj = true>
+          bool b_row_maj = true, bool c_row_maj = true>
 static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
                                               const T_in *__restrict pB,
                                               T_out *__restrict pC) {
@@ -76,14 +90,24 @@ static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
 
   for (unsigned z = 0; z < rowA; z += 2)
     chess_prepare_for_pipelining chess_loop_range(4, ) {
-      T_out *__restrict pC1 = pC + (z * colB) * MMUL::size_C;
-      T_out *__restrict pC2 = pC + ((z + 1) * colB) * MMUL::size_C;
+
+      T_out *__restrict pC1;
+      T_out *__restrict pC2;
+      if constexpr (c_row_maj) {
+        pC1 = pC + (z * colB) * MMUL::size_C;
+        pC2 = pC + ((z + 1) * colB) * MMUL::size_C;
+      }
 
       for (unsigned j = 0; j < colB; j += 2)
 #ifdef OPT_PERF_ENABLED
         chess_flatten_loop
 #endif
         {
+
+          if constexpr (!c_row_maj) {
+            pC1 = pC + j * rowA * MMUL::size_C + z * MMUL::size_C;
+            pC2 = pC + (j + 1) * rowA * MMUL::size_C + z * MMUL::size_C;
+          }
           const T_in *__restrict pA1 = pA + (z * colA) * MMUL::size_A;
           const T_in *__restrict pA2 = pA + ((z + 1) * colA) * MMUL::size_A;
           const T_in *__restrict pB1;
@@ -95,6 +119,7 @@ static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
             pB1 = pB + (j * colA) * MMUL::size_B;
             pB2 = pB + ((j + 1) * colA) * MMUL::size_B;
           }
+
           aie::vector<T_in, MMUL::size_A> A0;
           aie::vector<T_in, MMUL::size_A> A1;
           aie::vector<T_in, MMUL::size_B> B0;
@@ -103,14 +128,23 @@ static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
           // Load partial results from C buffer for accumulation in-place. The
           // zero.cc function handles the zeroing of data when a new
           // accumulation is needed (after the 'K' reduction dimension)
-          aie::vector<T_out, MMUL::size_C> acc_C00 =
-              aie::load_v<MMUL::size_C>(pC1);
-          aie::vector<T_out, MMUL::size_C> acc_C01 =
-              aie::load_v<MMUL::size_C>(pC1 + MMUL::size_C);
-          aie::vector<T_out, MMUL::size_C> acc_C10 =
-              aie::load_v<MMUL::size_C>(pC2);
-          aie::vector<T_out, MMUL::size_C> acc_C11 =
-              aie::load_v<MMUL::size_C>(pC2 + MMUL::size_C);
+          aie::vector<T_out, MMUL::size_C> acc_C00;
+          aie::vector<T_out, MMUL::size_C> acc_C01;
+          aie::vector<T_out, MMUL::size_C> acc_C10;
+          aie::vector<T_out, MMUL::size_C> acc_C11;
+          if constexpr (c_row_maj) {
+            acc_C00 = aie::load_v<MMUL::size_C>(pC1);
+            acc_C01 = aie::load_v<MMUL::size_C>(pC1 + MMUL::size_C);
+            acc_C10 = aie::load_v<MMUL::size_C>(pC2);
+            acc_C11 = aie::load_v<MMUL::size_C>(pC2 + MMUL::size_C);
+          } else {
+            acc_C00 = aie::transpose(aie::load_v<MMUL::size_C>(pC1), t, r);
+            acc_C01 = aie::transpose(aie::load_v<MMUL::size_C>(pC2), t, r);
+            acc_C10 = aie::transpose(
+                aie::load_v<MMUL::size_C>(pC1 + MMUL::size_C), t, r);
+            acc_C11 = aie::transpose(
+                aie::load_v<MMUL::size_C>(pC2 + MMUL::size_C), t, r);
+          }
 
           MMUL C00(acc_C00);
           MMUL C01(acc_C01);
@@ -149,14 +183,30 @@ static inline void matmul_vectorized_2x2_mmul(const T_in *__restrict pA,
           // example below shows how to shift right 10 bits
           // #define SHIFT 10
          // aie::store_v(pC1, C00.template to_vector<T_out>(SHIFT));
-          aie::store_v(pC1, C00.template to_vector<T_out>());
-          pC1 += MMUL::size_C;
-          aie::store_v(pC1, C01.template to_vector<T_out>());
-          pC1 += MMUL::size_C;
-          aie::store_v(pC2, C10.template to_vector<T_out>());
-          pC2 += MMUL::size_C;
-          aie::store_v(pC2, C11.template to_vector<T_out>());
-          pC2 += MMUL::size_C;
+
+          if constexpr (c_row_maj) {
+            aie::store_v(pC1, C00.template to_vector<T_out>());
+            pC1 += MMUL::size_C;
+            aie::store_v(pC1, C01.template to_vector<T_out>());
+            pC1 += MMUL::size_C;
+            aie::store_v(pC2, C10.template to_vector<T_out>());
+            pC2 += MMUL::size_C;
+            aie::store_v(pC2, C11.template to_vector<T_out>());
+            pC2 += MMUL::size_C;
+          } else {
+            aie::store_v(pC1,
+                         aie::transpose(C00.template to_vector<T_out>(), r, t));
+            pC1 += MMUL::size_C;
+            aie::store_v(pC2,
+                         aie::transpose(C01.template to_vector<T_out>(), r, t));
+            pC2 += MMUL::size_C;
+            aie::store_v(pC1,
+                         aie::transpose(C10.template to_vector<T_out>(), r, t));
+            pC1 += MMUL::size_C;
+            aie::store_v(pC2,
+                         aie::transpose(C11.template to_vector<T_out>(), r, t));
+            pC2 += MMUL::size_C;
+          }
         }
     }
 
@@ -169,6 +219,12 @@ constexpr bool is_b_row_maj = false;
 constexpr bool is_b_row_maj = true;
 #endif
 
+#ifdef C_COL_MAJ
+constexpr bool is_c_row_maj = false;
+#else
+constexpr bool is_c_row_maj = true;
+#endif
+
 // The following kernel definitions use mmul shapes that have been found to be
 // optimal for AIE2P in combination with the 2x2 mmul expanded kernel.
 //
@@ -195,7 +251,8 @@ static inline void matmul_vectorized_4x4x8_i16_i16(const int16 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<int16, int16, (m / r), (k / s), (n / t), r,
-                                    s, t, is_b_row_maj>(pA, pB, pC);
+                                    s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
+                                                                      pC);
 }
 
 template <unsigned m, unsigned k, unsigned n>
@@ -211,7 +268,8 @@ static inline void matmul_vectorized_4x4x8_i16_i32(const int16 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<int16, int32, (m / r), (k / s), (n / t), r,
-                                    s, t, is_b_row_maj>(pA, pB, pC);
+                                    s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
+                                                                      pC);
 }
 
 template <unsigned m, unsigned k, unsigned n>
@@ -228,7 +286,8 @@ matmul_vectorized_4x8x8_bf16_bf16(const bfloat16 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<bfloat16, bfloat16, (m / r), (k / s),
-                                    (n / t), r, s, t, is_b_row_maj>(pA, pB, pC);
+                                    (n / t), r, s, t, is_b_row_maj,
+                                    is_c_row_maj>(pA, pB, pC);
 }
 
 // Note that this shape is only possible for bf16 when using bfp16 emulation
@@ -247,7 +306,8 @@ matmul_vectorized_8x8x8_bf16_bf16(const bfloat16 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<bfloat16, bfloat16, (m / r), (k / s),
-                                    (n / t), r, s, t, is_b_row_maj>(pA, pB, pC);
+                                    (n / t), r, s, t, is_b_row_maj,
+                                    is_c_row_maj>(pA, pB, pC);
 }
 
 template <unsigned m, unsigned k, unsigned n>
@@ -264,7 +324,8 @@ matmul_vectorized_4x8x8_bf16_f32(const bfloat16 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<bfloat16, float, (m / r), (k / s), (n / t),
-                                    r, s, t, is_b_row_maj>(pA, pB, pC);
+                                    r, s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
+                                                                         pC);
 }
 
 template <unsigned m, unsigned k, unsigned n>
@@ -281,7 +342,7 @@ matmul_vectorized_8x8x8_bf16_f32(const bfloat16 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<bfloat16, float, (m / r), (k / s), (n / t),
-                                    r, s, t, is_b_row_maj>(pA, pB, pC);
+                                    r, s, t, is_b_row_maj, is_c_row_maj>(pA, pB, pC);
 }
 
 template <unsigned m, unsigned k, unsigned n>
@@ -297,7 +358,7 @@ static inline void matmul_vectorized_8x8x8_i8_i8(const int8 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<int8, int8, (m / r), (k / s), (n / t), r, s,
-                                    t, is_b_row_maj>(pA, pB, pC);
+                                    t, is_b_row_maj, is_c_row_maj>(pA, pB, pC);
 }
 
 template <unsigned m, unsigned k, unsigned n>
@@ -313,7 +374,8 @@ static inline void matmul_vectorized_8x8x8_i8_i16(const int8 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<int8, int16, (m / r), (k / s), (n / t), r,
-                                    s, t, is_b_row_maj>(pA, pB, pC);
+                                    s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
+                                                                      pC);
 }
 
 template <unsigned m, unsigned k, unsigned n>
@@ -329,7 +391,8 @@ static inline void matmul_vectorized_8x8x8_i8_i32(const int8 *__restrict pA,
   static_assert(n % (2 * t) == 0);
 
   return matmul_vectorized_2x2_mmul<int8, int32, (m / r), (k / s), (n / t), r,
-                                    s, t, is_b_row_maj>(pA, pB, pC);
+                                    s, t, is_b_row_maj, is_c_row_maj>(pA, pB,
+                                                                      pC);
 }
 
 extern "C" {
@@ -418,7 +481,7 @@ extern "C" {
                                 r, s, t) \
   void matmul_scalar_##mlir_type_in##_##mlir_type_out( \
       ctype_in *a_in, ctype_in *b_in, ctype_out *c_out) { \
-    matmul_scalar<ctype_in, ctype_out, DIM_M, DIM_K, DIM_N>(a_in, b_in, \
+    matmul_scalar<ctype_in, ctype_out, DIM_M, DIM_K, DIM_N, is_b_row_maj, is_c_row_maj>(a_in, b_in, \
                                                             c_out); \
   }
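The kernel changes above all follow the same pattern: the C layout is a compile-time template parameter, and if constexpr selects the indexing so the unused branch is compiled out. A minimal, self-contained sketch of that idea in plain C++ (hypothetical names, no AIE intrinsics, assuming row-major A and B) follows.

#include <array>
#include <cstdio>

// Minimal sketch: c_row_maj picks the C indexing at compile time, mirroring
// the c_row_maj template parameter threaded through the kernels above.
template <bool c_row_maj, int M, int K, int N>
void matmul_ref(const float *a, const float *b, float *c) {
  for (int row = 0; row < M; ++row)
    for (int col = 0; col < N; ++col) {
      float sum = 0.0f;
      for (int i = 0; i < K; ++i)
        sum += a[row * K + i] * b[i * N + col]; // A and B row-major in this sketch
      if constexpr (c_row_maj)
        c[row * N + col] += sum; // row-major C: stride N between rows
      else
        c[row + col * M] += sum; // column-major C: stride M between columns
    }
}

int main() {
  std::array<float, 4> a{1, 2, 3, 4}, b{5, 6, 7, 8};
  std::array<float, 4> c_rm{}, c_cm{};
  matmul_ref<true, 2, 2, 2>(a.data(), b.data(), c_rm.data());
  matmul_ref<false, 2, 2, 2>(a.data(), b.data(), c_cm.data());
  // c_rm = {19, 22, 43, 50}; c_cm = {19, 43, 22, 50} (same matrix, stored column by column)
  std::printf("%g %g %g %g\n", c_cm[0], c_cm[1], c_cm[2], c_cm[3]);
  return 0;
}

The vectorized kernel makes the same choice per 2x2 block of MMUL tiles, using aie::transpose to convert each r x t output tile between layouts on load and store.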

programming_examples/basic/matrix_multiplication/common.h

Lines changed: 18 additions & 6 deletions
@@ -55,6 +55,8 @@ void add_default_options(cxxopts::Options &options) {
       "trace_file", "where to store trace output",
       cxxopts::value<std::string>()->default_value("trace.txt"))(
       "b_col_maj", "Is B matrix in colum-major format?",
+      cxxopts::value<int>()->default_value("0"))(
+      "c_col_maj", "Is C matrix in colum-major format?",
       cxxopts::value<int>()->default_value("0"));
 }
 
@@ -109,7 +111,7 @@ std::bfloat16_t get_random<std::bfloat16_t>() {
 
 template <typename Tin, typename Tout, typename Tacc>
 void matmul(int M, int N, int K, const std::vector<Tin> A,
-            const std::vector<Tin> B, std::vector<Tout> &C, int b_col_maj) {
+            const std::vector<Tin> B, std::vector<Tout> &C, int b_col_maj, int c_col_maj) {
   for (int row = 0; row < M; row++) {
     for (int col = 0; col < N; col++) {
       Tacc running_sum = 0;
@@ -120,7 +122,11 @@ void matmul(int M, int N, int K, const std::vector<Tin> A,
           running_sum += Tacc(A[row * K + k] * B[k + col * K]);
         }
       }
-      C[row * N + col] = Tout(running_sum);
+      if (!c_col_maj) {
+        C[row * N + col] = Tout(running_sum);
+      } else {
+        C[row + col * M] = Tout(running_sum);
+      }
     }
   }
 }
@@ -347,14 +353,14 @@ void print_progress_bar(std::ostream &os, double progress, int len = 75) {
 template <typename Tin, typename Tout, typename Tacc>
 int verify(int M, int N, int K, std::vector<Tin> A, std::vector<Tin> B,
            std::vector<Tout> C, int verbosity = 0, float abs_tol = 0.5,
-           float rel_tol = 0.05, int b_col_maj = 0) {
+           float rel_tol = 0.05, int b_col_maj = 0, int c_col_maj = 0) {
   int n_errors = 0;
   std::vector<struct error<Tout>> errors;
   Tout max_rel_error = (Tout)0.0f;
   struct error<Tout> max_error;
 
   std::vector<Tout> CRef(M * N);
-  matmul<Tin, Tout, Tacc>(M, N, K, A, B, CRef, b_col_maj);
+  matmul<Tin, Tout, Tacc>(M, N, K, A, B, CRef, b_col_maj, c_col_maj);
 
   for (int row = 0; row < M; row++) {
     for (int col = 0; col < N; col++) {
@@ -394,7 +400,7 @@ template <typename Tin, typename Tout, typename Tacc>
 int verify_stochastic(int M, int N, int K, std::vector<Tin> A,
                       std::vector<Tin> B, std::vector<Tout> C, int n_samples,
                       int verbosity = 0, float abs_tol = 0.5,
-                      float rel_tol = 0.05, int b_col_maj = 0) {
+                      float rel_tol = 0.05, int b_col_maj = 0, int c_col_maj = 0) {
   std::mt19937 rng;
   auto rows = std::views::iota(0, M);
   auto cols = std::views::iota(0, N);
@@ -420,8 +426,14 @@ int verify_stochastic(int M, int N, int K, std::vector<Tin> A,
       print_progress_bar(std::cerr, progress);
     }
     Tout ref = mul_acc<Tin, Tout, Tacc>(M, N, K, row, col, A, B, b_col_maj);
+    Tout observed;
+    if (!c_col_maj) {
+      observed = C[row * N + col];
+    } else {
+      observed = C[row + col * M];
+    }
     std::optional<struct error<Tout>> error = verify_single(
-        std::cout, row, col, ref, C[row * N + col], abs_tol, rel_tol);
+        std::cout, row, col, ref, observed, abs_tol, rel_tol);
     if (error.has_value()) {
       if (n_errors < max_printable_errors) {
         errors.push_back(*error);
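On the host side, the reference and verification changes above reduce to one question: where does element (row, col) of C live in memory? A small helper capturing that index math is sketched below (hypothetical; common.h keeps the two branches inline instead of factoring them out).

#include <cstddef>

// Hypothetical helper mirroring the branches added to matmul(), verify() and
// verify_stochastic(): row-major C stores (row, col) at row * N + col,
// column-major C stores it at row + col * M.
inline std::size_t c_index(int row, int col, int M, int N, bool c_col_maj) {
  return c_col_maj ? static_cast<std::size_t>(row + col * M)
                   : static_cast<std::size_t>(row * N + col);
}

// Inside the verification loop this would read:
//   Tout observed = C[c_index(row, col, M, N, c_col_maj != 0)];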

programming_examples/basic/matrix_multiplication/test.cpp

Lines changed: 3 additions & 2 deletions
@@ -67,6 +67,7 @@ int main(int argc, const char *argv[]) {
   int n_warmup_iterations = vm["warmup"].as<int>();
   int trace_size = vm["trace_sz"].as<int>();
   int b_col_maj = vm["b_col_maj"].as<int>();
+  int c_col_maj = vm["c_col_maj"].as<int>();
 
   // Fix the seed to ensure reproducibility in CI.
   srand(1726250518); // srand(time(NULL));
@@ -255,10 +256,10 @@ int main(int argc, const char *argv[]) {
       errors = matmul_common::verify_stochastic<A_DATATYPE, C_DATATYPE,
                                                 ACC_DATATYPE>(
           M, N, K, AVec, BVec, CVec, verify_stochastic_n_samples, verbosity,
-          abs_tol, rel_tol, b_col_maj);
+          abs_tol, rel_tol, b_col_maj, c_col_maj);
     } else {
       errors = matmul_common::verify<A_DATATYPE, C_DATATYPE, ACC_DATATYPE>(
-          M, N, K, AVec, BVec, CVec, verbosity, abs_tol, rel_tol, b_col_maj);
+          M, N, K, AVec, BVec, CVec, verbosity, abs_tol, rel_tol, b_col_maj, c_col_maj);
     }
     auto vstop = std::chrono::system_clock::now();
     float vtime =

programming_examples/basic/matrix_multiplication/whole_array/Makefile

Lines changed: 18 additions & 2 deletions
@@ -24,13 +24,20 @@ n?=32
 
 n_aie_cols?=4
 b_col_maj?=0
+c_col_maj?=0
 
 kernels=mm_${m}x${k}x${n}
-aieargs+=-m $m -k $k -n $n --n-aie-cols ${n_aie_cols} --b-col-maj ${b_col_maj}
-runargs+=--b_col_maj ${b_col_maj}
+aieargs+=-m $m -k $k -n $n --n-aie-cols ${n_aie_cols} --b-col-maj ${b_col_maj}
+runargs+=--b_col_maj ${b_col_maj}
 target_suffix=${M}x${K}x${N}_${m}x${k}x${n}_${n_aie_cols}c
 use_placed?=0
 use_iron?=0
+use_scalar?=0
+
+ifeq (${c_col_maj}, 1)
+aieargs+=--c-col-maj ${c_col_maj}
+runargs+=--c_col_maj ${c_col_maj}
+endif
 
 # set this flag to 1 for linear buffer allocation
 # else, 0 for bank-aware
@@ -57,6 +64,15 @@ KERNEL_DEFINES=-D${dtype_in}_${dtype_out}_ONLY -DDIM_M=${m} -DDIM_K=${k} -DDIM_N
 ifeq (${b_col_maj}, 1)
 KERNEL_DEFINES+=-DB_COL_MAJ
 endif
+ifeq (${c_col_maj}, 1)
+KERNEL_DEFINES+=-DC_COL_MAJ
+endif
+ifeq (${use_scalar}, 1)
+KERNEL_DEFINES+=-DSCALAR_ONLY
+aieargs+= --scalar 1
+else
+KERNEL_DEFINES+=-DVECTORIZED_ONLY
+endif
 
 include ${srcdir}/../makefile-common
 
File renamed without changes.
File renamed without changes.
File renamed without changes.
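Taken together, the Makefile changes thread one switch through all three layers: setting c_col_maj=1 appends --c-col-maj to aieargs for the design script, --c_col_maj to runargs for the host executable, and -DC_COL_MAJ to KERNEL_DEFINES, which in turn flips is_c_row_maj in the kernels. A hypothetical invocation would be something like make run c_col_maj=1, with the exact targets coming from the shared makefile-common, which is not part of this diff.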

0 commit comments
