Skip to content

Commit 39c90f9

Browse files
authored
Merge pull request #5380 from quic/topic/sgemm_direct_sme1_alpha_beta
SME1 based direct kernel (with alpha and beta) for cblas_sgemm level 3
2 parents ac8cbfd + eae0abf commit 39c90f9

File tree

9 files changed

+266
-1
lines changed

9 files changed

+266
-1
lines changed

common_level3.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
5252
float * B, BLASLONG strideB,
5353
float * R, BLASLONG strideR);
5454

55+
/* Direct single-precision GEMM kernel that takes explicit alpha and beta
 * scaling factors in addition to the operands of sgemm_direct.
 * A, B and R are the two input matrices and the result matrix, each
 * followed by its stride argument.
 * NOTE(review): presumably computes R = alpha*A*B + beta*R for row-major,
 * non-transposed operands -- confirm against the kernel implementation. */
void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K,
56+
float alpha,
57+
float * A, BLASLONG strideA,
58+
float * B, BLASLONG strideB,
59+
float beta,
60+
float * R, BLASLONG strideR);
61+
5562
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
5663

5764
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
256256
#endif
257257
#ifdef ARCH_ARM64
258258
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
259+
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
259260
#endif
260261

261262

common_s.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
5151
#define SGEMM_DIRECT sgemm_direct
52+
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta
5253

5354
#define SGEMM_ONCOPY sgemm_oncopy
5455
#define SGEMM_OTCOPY sgemm_otcopy
@@ -218,6 +219,7 @@
218219
#elif ARCH_ARM64
219220
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
220221
#define SGEMM_DIRECT gotoblas -> sgemm_direct
222+
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
221223
#endif
222224

223225
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy

interface/gemm.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
441441
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
442442
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
443443
return;
444+
}else if (order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
445+
SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
446+
return;
444447
}
445448
#endif
446449
#endif

kernel/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,11 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
255255
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE)
256256
elseif (ARM64)
257257
set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c)
258+
set (SGEMMDIRECTKERNEL_ALPHA_BETA sgemm_direct_alpha_beta_arm64_sme1.c)
258259
set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S)
259260
set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S)
260261
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
262+
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL_ALPHA_BETA}" "" "gemm_direct_alpha_beta" false "" "" false SINGLE)
261263
if (HAVE_SME)
262264
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE)
263265
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE)

kernel/Makefile.L3

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ ifeq ($(TARGET_CORE), ARMV9SME)
132132
HAVE_SME = 1
133133
endif
134134
SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c
135+
SGEMMDIRECTKERNEL_ALPHA_BETA = sgemm_direct_alpha_beta_arm64_sme1.c
135136
endif
136137
endif
137138
endif
@@ -208,7 +209,8 @@ SKERNELOBJS += \
208209
endif
209210
ifeq ($(ARCH), arm64)
210211
SKERNELOBJS += \
211-
sgemm_direct$(TSUFFIX).$(SUFFIX)
212+
sgemm_direct$(TSUFFIX).$(SUFFIX) \
213+
sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX)
212214
ifdef HAVE_SME
213215
SKERNELOBJS += \
214216
sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \
@@ -969,6 +971,8 @@ endif
969971
ifeq ($(ARCH), arm64)
970972
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
971973
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
974+
$(KDIR)sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL_ALPHA_BETA)
975+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
972976
ifdef HAVE_SME
973977
$(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) :
974978
$(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
3+
SPDX-License-Identifier: BSD-3-Clause-Clear
4+
*/
5+
6+
#include "common.h"
7+
#include <stdlib.h>
8+
#include <inttypes.h>
9+
#include <math.h>
10+
#include "sme_abi.h"
11+
#if defined(HAVE_SME)
12+
13+
#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
14+
#include <arm_sme.h>
15+
#endif
16+
17+
/* Function prototypes */
18+
/* External routine bound to the assembler symbol
 * "sgemm_direct_sme1_preprocess". It reads the matrix `a` and writes a
 * rearranged copy into `a_mod`; it is invoked in this file as
 * (M, K, A, A_mod), so nbr/nbc receive the row and column counts of A.
 * NOTE(review): the exact output layout is defined by the assembly source,
 * which is not visible here. */
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
19+
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
20+
21+
/* Function Definitions */
22+
/* Returns the number of 32-bit elements in one streaming vector.
 * Implemented with inline assembly so it does not depend on the SME
 * intrinsics header being available. */
static uint64_t sve_cntw() {
23+
uint64_t cnt;
24+
asm volatile(
25+
/* rdsvl with immediate 1 yields the streaming vector length in bytes */
"rdsvl %[res], #1\n"
26+
/* divide by 4 (bytes per 32-bit word) to get the element count */
"lsr %[res], %[res], #2\n"
27+
: [res] "=r" (cnt) ::
28+
);
29+
return cnt;
30+
}
31+
32+
#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
33+
// Outer product kernel.
34+
// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA.
35+
__attribute__((always_inline)) inline void
36+
kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim,
37+
size_t ldc, size_t block_rows, size_t block_cols, float alpha, float beta)
38+
__arm_out("za") __arm_streaming {
39+
40+
const uint64_t svl = svcntw();
41+
size_t ldb = ldc;
42+
// Predicate set-up
43+
svbool_t pg = svptrue_b32();
44+
svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
45+
svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);
46+
47+
svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
48+
svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);
49+
50+
#define pg_c_0 pg_b_0
51+
#define pg_c_1 pg_b_1
52+
53+
svzero_za();
54+
svfloat32_t beta_vec = svdup_f32(beta);
55+
// Load C to ZA
56+
for (size_t i = 0; i < MIN(svl, block_rows); i++) {
57+
svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
58+
row_c_0 = svmul_x(pg, beta_vec, row_c_0);
59+
svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0);
60+
61+
svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
62+
row_c_1 = svmul_x(pg, beta_vec, row_c_1);
63+
svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1);
64+
}
65+
for (size_t i = svl; i < block_rows; i++) {
66+
svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
67+
row_c_0 = svmul_x(pg, beta_vec, row_c_0);
68+
svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0);
69+
70+
svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
71+
row_c_1 = svmul_x(pg, beta_vec, row_c_1);
72+
svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1);
73+
}
74+
75+
svfloat32_t alpha_vec = svdup_f32(alpha);
76+
// Iterate through shared dimension (K)
77+
for (size_t k = 0; k < shared_dim; k++) {
78+
// Load column of A
79+
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]);
80+
col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
81+
svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]);
82+
col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
83+
// Load row of B
84+
svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]);
85+
svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]);
86+
// Perform outer product
87+
svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0);
88+
svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1);
89+
svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0);
90+
svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1);
91+
}
92+
93+
// Store to C from ZA
94+
for (size_t i = 0; i < MIN(svl, block_rows); i++) {
95+
svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
96+
svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
97+
}
98+
for (size_t i = svl; i < block_rows; i++) {
99+
svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
100+
svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
101+
}
102+
}
103+
104+
/* Driver that walks the result matrix C (m x n, row stride n) in blocks of
 * at most 2*SVL x 2*SVL elements and hands each block to kernel_2x2.
 *
 *   m, k, n  dimensions: C is m x n, k is the shared dimension
 *   alpha    pointer to the scale applied to the A*B product
 *   ba       left operand panel; the panel for row offset r starts at ba[r * k]
 *   bb       right operand, indexed with row stride n
 *   beta     pointer to the scale applied to the incoming C
 *   C        result matrix, updated in place
 */
__arm_new("za") __arm_locally_streaming
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,
                   const float *ba, const float *restrict bb, const float* beta,
                   float *restrict C) {

  /* Edge length of a full block: two streaming vectors of floats. */
  const uint64_t block_span = 2 * svcntw();
  const uint64_t c_stride = n;

  /* Block over row dimension of C (panels of A) */
  for (uint64_t row0 = 0; row0 < m; row0 += block_span) {
    const uint64_t rows_left = m - row0;
    const uint64_t rows_here = rows_left < block_span ? rows_left : block_span;

    /* Block over column dimension of C */
    for (uint64_t col0 = 0; col0 < n; col0 += block_span) {
      const uint64_t cols_left = n - col0;
      const uint64_t cols_here = cols_left < block_span ? cols_left : block_span;

      const float *a_panel = &ba[row0 * k];
      const float *b_cols = &bb[col0];
      float *c_block = &C[row0 * c_stride + col0];

      kernel_2x2(a_panel, b_cols, c_block, k,
                 c_stride, rows_here, cols_here, *alpha, *beta);
    }
  }
}
142+
143+
#else
144+
/* Fallback definition used when the compiler does not advertise the SME
 * features tested by the preceding #if. The body is empty: calling it
 * leaves C untouched and performs no computation. */
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
145+
const float *ba, const float *restrict bb, const float* beta,\
146+
float *restrict C){}
147+
#endif
148+
149+
/*void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K,\
150+
float * __restrict A, BLASLONG strideA, float * __restrict B,\
151+
BLASLONG strideB , float * __restrict R, BLASLONG strideR)
152+
*/
153+
/*
 * Direct SGEMM entry point: R = alpha * A * B + beta * R.
 *
 *   M, N, K   R is M x N, K is the shared dimension
 *   alpha     scale applied to the A*B product
 *   A, B      input matrices
 *   beta      scale applied to the incoming contents of R
 *   R         result matrix, updated in place
 *
 * strideA, strideB and strideR are accepted for interface compatibility but
 * are not referenced by this implementation.
 *
 * A temporary buffer holding a rearranged copy of A (rows rounded up to a
 * multiple of the streaming vector length) is allocated for the duration of
 * the call. If that allocation fails the process is aborted rather than
 * letting the assembly routine write through a null pointer.
 */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,
            BLASLONG strideA, float * __restrict B, BLASLONG strideB ,
            float beta, float * __restrict R, BLASLONG strideR){

  const uint64_t vl_elms = sve_cntw();

  /* Round M up to the next multiple of vl_elms using integer arithmetic. */
  const uint64_t m_mod = (((uint64_t)M + vl_elms - 1) / vl_elms) * vl_elms;
  const uint64_t a_mod_elems = m_mod * (uint64_t)K;

  float *A_mod = malloc(a_mod_elems * sizeof *A_mod);
  /* malloc(0) may legitimately return NULL, so only a failed non-empty
   * request is treated as fatal. */
  if (A_mod == NULL && a_mod_elems != 0) {
    abort();
  }

  /* Empty asm statement that declares every predicate (p0-p15) and vector
   * (z0-z31) register as clobbered, so the compiler does not keep values
   * live in them across the call into the assembly routine below and
   * re-reads them from memory instead.
   * */
  asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
               "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
               "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
               "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
               "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
               "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

  /* Pre-process the left matrix to make it suitable for
     matrix sum of outer-product calculation
   */
  sgemm_direct_sme1_preprocess(M, K, A, A_mod);

  /* Same register clobber barrier after the assembly routine returns. */
  asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
               "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
               "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
               "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
               "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
               "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

  /* Calculate C = alpha*A*B + beta*C */
  sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R);

  free(A_mod);
}
192+
193+
#else
194+
195+
/* Build without HAVE_SME: the entry point is still defined so the symbol
 * exists, but its body is empty and it performs no computation. */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
196+
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
197+
float beta, float * __restrict R, BLASLONG strideR){}
198+
199+
#endif

kernel/arm64/sme_abi.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/***************************************************************************
2+
* Copyright (c) 2024, The OpenBLAS Project
3+
* All rights reserved.
4+
* Redistribution and use in source and binary forms, with or without
5+
* modification, are permitted provided that the following conditions are
6+
* met:
7+
* 1. Redistributions of source code must retain the above copyright
8+
* notice, this list of conditions and the following disclaimer.
9+
* 2. Redistributions in binary form must reproduce the above copyright
10+
* notice, this list of conditions and the following disclaimer in
11+
* the documentation and/or other materials provided with the
12+
* distribution.
13+
* 3. Neither the name of the OpenBLAS project nor the names of
14+
* its contributors may be used to endorse or promote products
15+
* derived from this software without specific prior written permission.
16+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
22+
* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23+
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24+
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
25+
* THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
* *****************************************************************************/
27+
28+
#pragma once
29+
30+
#include <stdlib.h>
31+
32+
/**
33+
* * These are SME ABI routines for saving & restoring SME state.
34+
* * They are typically provided by a compiler runtime library such
35+
* * as libgcc or compiler-rt, but support for these routines is not
36+
* * yet available on all platforms.
37+
* *
38+
* * Define these as aborting stubs so that we loudly fail on nested
39+
* * usage of SME state.
40+
* *
41+
* * These are defined as weak symbols so that a compiler runtime can
42+
* * override them if supported.
43+
* */
44+
/* Weak default for the SME state-save routine: aborts if ever called. */
__attribute__((weak)) void __arm_tpidr2_save() { abort(); }
45+
/* Weak default for the SME state-restore routine: aborts if ever called. */
__attribute__((weak)) void __arm_tpidr2_restore() { abort(); }
46+

kernel/setparam-ref.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ gotoblas_t TABLE_NAME = {
215215
#endif
216216
#ifdef ARCH_ARM64
217217
sgemm_directTS,
218+
sgemm_direct_alpha_betaTS,
218219
#endif
219220

220221
sgemm_kernelTS, sgemm_betaTS,

0 commit comments

Comments
 (0)