Skip to content

Commit 6dda7cf

Browse files
committed
Support for an SME1-based ssymm_direct kernel for the cblas_ssymm level-3 API
1 parent e2f9f57 commit 6dda7cf

File tree

9 files changed

+306
-0
lines changed

9 files changed

+306
-0
lines changed

common_level3.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@ void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K,
5959
float beta,
6060
float * R, BLASLONG strideR);
6161

62+
void ssymm_direct_alpha_betaLU(BLASLONG M, BLASLONG N,
63+
float alpha,
64+
float * A, BLASLONG strideA,
65+
float * B, BLASLONG strideB,
66+
float beta,
67+
float * R, BLASLONG strideR);
68+
void ssymm_direct_alpha_betaLL(BLASLONG M, BLASLONG N,
69+
float alpha,
70+
float * A, BLASLONG strideA,
71+
float * B, BLASLONG strideB,
72+
float beta,
73+
float * R, BLASLONG strideR);
74+
6275
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
6376

6477
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
257257
#ifdef ARCH_ARM64
258258
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
259259
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
260+
void (*ssymm_direct_alpha_betaLU) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
261+
void (*ssymm_direct_alpha_betaLL) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
260262
#endif
261263

262264

common_s.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
5151
#define SGEMM_DIRECT sgemm_direct
5252
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta
53+
#define SSYMM_DIRECT_ALPHA_BETA_LU ssymm_direct_alpha_betaLU
54+
#define SSYMM_DIRECT_ALPHA_BETA_LL ssymm_direct_alpha_betaLL
5355

5456
#define SGEMM_ONCOPY sgemm_oncopy
5557
#define SGEMM_OTCOPY sgemm_otcopy
@@ -220,6 +222,8 @@
220222
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
221223
#define SGEMM_DIRECT gotoblas -> sgemm_direct
222224
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
225+
#define SSYMM_DIRECT_ALPHA_BETA_LU gotoblas -> ssymm_direct_alpha_betaLU
226+
#define SSYMM_DIRECT_ALPHA_BETA_LL gotoblas -> ssymm_direct_alpha_betaLL
223227
#endif
224228

225229
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy

interface/symm.c

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,22 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo,
269269

270270
PRINT_DEBUG_CNAME;
271271

272+
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
273+
#if defined(ARCH_ARM64) && (defined(USE_SSYMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
274+
#if defined(DYNAMIC_ARCH)
275+
if (support_sme1())
276+
#endif
277+
if (order == CblasRowMajor && m == lda && n == ldb && n == ldc)
278+
{
279+
if (Side == CblasLeft && Uplo == CblasUpper) {
280+
SSYMM_DIRECT_ALPHA_BETA_LU(m, n, alpha, a, lda, b, ldb, beta, c, ldc); return;
281+
}
282+
else if (Side == CblasLeft && Uplo == CblasLower) {
283+
SSYMM_DIRECT_ALPHA_BETA_LL(m, n, alpha, a, lda, b, ldb, beta, c, ldc); return;
284+
}
285+
}
286+
#endif
287+
#endif
272288
#ifndef COMPLEX
273289
args.alpha = (void *)α
274290
args.beta = (void *)β

kernel/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
241241
if (X86_64 OR ARM64)
242242
set(USE_DIRECT_SGEMM true)
243243
endif()
244+
set(USE_DIRECT_SSYMM false)
245+
if (ARM64)
246+
set(USE_DIRECT_SSYMM true)
247+
endif()
244248
if (UC_TARGET_CORE MATCHES ARMV9SME)
245249
set (HAVE_SME true)
246250
endif ()
@@ -267,6 +271,13 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
267271
endif ()
268272
endif()
269273

274+
if (USE_DIRECT_SSYMM)
275+
if (ARM64)
276+
set (SSYMMDIRECTKERNEL_ALPHA_BETA ssymm_direct_alpha_beta_arm64_sme1.c)
277+
GenerateNamedObjects("${KERNELDIR}/${SSYMMDIRECTKERNEL_ALPHA_BETA}" "" "symm_direct_alpha_beta" false "" "" false SINGLE)
278+
endif ()
279+
endif()
280+
270281
foreach (float_type SINGLE DOUBLE)
271282
string(SUBSTRING ${float_type} 0 1 float_char)
272283
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})

kernel/Makefile.L3

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ endif
5252
ifeq ($(ARCH), arm64)
5353
USE_TRMM = 1
5454
USE_DIRECT_SGEMM = 1
55+
USE_DIRECT_SSYMM = 1
5556
endif
5657

5758
ifeq ($(ARCH), riscv64)
@@ -137,6 +138,17 @@ endif
137138
endif
138139
endif
139140

141+
ifdef USE_DIRECT_SSYMM
142+
ifndef SSYMMDIRECTKERNEL_ALPHA_BETA
143+
ifeq ($(ARCH), arm64)
144+
ifeq ($(TARGET_CORE), ARMV9SME)
145+
HAVE_SME = 1
146+
endif
147+
SSYMMDIRECTKERNEL_ALPHA_BETA = ssymm_direct_alpha_beta_arm64_sme1.c
148+
endif
149+
endif
150+
endif
151+
140152
ifeq ($(BUILD_BFLOAT16), 1)
141153
ifndef BGEMMKERNEL
142154
BGEMM_BETA = ../generic/gemm_beta.c
@@ -220,6 +232,14 @@ endif
220232
endif
221233
endif
222234

235+
ifdef USE_DIRECT_SSYMM
236+
ifeq ($(ARCH), arm64)
237+
SKERNELOBJS += \
238+
ssymm_direct_alpha_betaLU$(TSUFFIX).$(SUFFIX) \
239+
ssymm_direct_alpha_betaLL$(TSUFFIX).$(SUFFIX)
240+
endif
241+
endif
242+
223243
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
224244
DKERNELOBJS += \
225245
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -982,6 +1002,15 @@ endif
9821002
endif
9831003
endif
9841004

1005+
ifdef USE_DIRECT_SSYMM
1006+
ifeq ($(ARCH), arm64)
1007+
$(KDIR)ssymm_direct_alpha_betaLU$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMDIRECTKERNEL_ALPHA_BETA)
1008+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DLEFT -DUPPER $< -o $@
1009+
$(KDIR)ssymm_direct_alpha_betaLL$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMDIRECTKERNEL_ALPHA_BETA)
1010+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DLEFT -DLOWER $< -o $@
1011+
endif
1012+
endif
1013+
9851014
ifeq ($(BUILD_BFLOAT16), 1)
9861015
$(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL)
9871016
$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
/*
2+
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
3+
SPDX-License-Identifier: BSD-3-Clause-Clear
4+
*/
5+
6+
#include "common.h"
7+
#include <stdlib.h>
8+
#include <inttypes.h>
9+
#include <math.h>
10+
#include "sme_abi.h"
11+
#if defined(HAVE_SME)
12+
13+
#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
14+
#include <arm_sme.h>
15+
#endif
16+
17+
/* Function prototypes */
18+
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
19+
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
20+
21+
extern void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
22+
const float *ba, const float *restrict bb, const float* beta,\
23+
float *restrict C);
24+
/* Function Definitions */
25+
/*
 * Return the number of 32-bit elements in one streaming SVE vector
 * (SVL / 32).  RDSVL reads the streaming vector length in bytes and can
 * be executed outside streaming mode; the shift right by 2 converts the
 * byte count to a float lane count.
 */
static uint64_t sve_cntw(void) {
  uint64_t cnt;
  asm volatile(
    "rdsvl %[res], #1\n"
    "lsr %[res], %[res], #2\n"
    : [res] "=r" (cnt) ::
  );
  return cnt;
}
34+
35+
#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
36+
37+
/*
 * Expand a row-major symmetric matrix stored in its UPPER triangle into
 * a_mod, the SVL-blocked, transposed layout consumed by
 * sgemm_direct_alpha_beta_sme1_2VLx2VL.  Works one SVL x SVL ZA tile at
 * a time; the strictly-lower part of `a` is never read — it is
 * reconstructed from the mirrored upper part.
 */
__arm_new("za") __arm_locally_streaming
static void ssymm_direct_sme1_preprocessLU(uint64_t nbr, uint64_t nbc,
                                           const float *restrict a, float *restrict a_mod)
{
    const uint64_t svl = svcntw();   /* 32-bit lanes per streaming vector */
    uint64_t row_batch = svl;        /* rows handled per outer iteration  */

    /* One reused source/destination pointer pair.  `src` is only ever
       read; the const qualifier of `a` is dropped to share the variable. */
    float *restrict src;
    float *restrict dst;

    for (uint64_t row_idx = 0; row_idx < nbr; row_idx += row_batch) {
        row_batch = MIN(row_batch, nbr - row_idx);

        /* Fill in the lower triangle and transpose one SVL x N panel of A. */
        uint64_t col_batch = svl;

        for (uint64_t col_idx = 0; col_idx < nbc; col_idx += col_batch) {
            svzero_za();

            if (col_idx == row_idx) {
                /* Diagonal tile: load each row slice both horizontally and
                   vertically so the ZA tile holds the stored upper half
                   mirrored into the lower half, then store it transposed
                   (vertical slices). */
                src = (float *)&a[row_idx * nbc + col_idx];
                dst = &a_mod[col_idx * svl + row_idx * nbc];
                const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
                for (int64_t row = row_batch - 1; row >= 0; row--) {
                    svld1_hor_za32(0, row, pg_row, &src[row * nbc]);
                    svld1_ver_za32(0, row, pg_row, &src[row * nbc]);
                }
                col_batch = MIN(col_batch, nbc - col_idx);
                for (uint64_t col = 0; col < col_batch; col++) {
                    svst1_ver_za32(0, col, svptrue_b32(), &dst[col * svl]);
                }
            } else if (col_idx > row_idx) {
                /* Tile above the diagonal: present in storage — just
                   transpose it through the ZA tile. */
                src = (float *)&a[row_idx * nbc + col_idx];
                dst = &a_mod[col_idx * svl + row_idx * nbc];
                const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
                for (uint64_t row = 0; row < row_batch; row++) {
                    svld1_hor_za32(0, row, pg_row, &src[row * nbc]);
                }
                col_batch = MIN(col_batch, nbc - col_idx);
                for (uint64_t col = 0; col < col_batch; col++) {
                    svst1_ver_za32(0, col, svptrue_b32(), &dst[col * svl]);
                }
            } else { /* col_idx < row_idx */
                /* Tile below the diagonal: read the mirrored tile from the
                   stored upper triangle and emit it with horizontal stores
                   (already in the required orientation). */
                src = (float *)&a[row_idx + col_idx * nbc];
                dst = &a_mod[col_idx * svl + row_idx * nbc];
                const svbool_t pg_row = svwhilelt_b32_u64(row_idx, nbc);
                for (uint64_t row = 0; row < svl; row++) {
                    svld1_hor_za32(0, row, pg_row, &src[row * nbc]);
                }
                col_batch = MIN(col_batch, nbc - col_idx);
                for (uint64_t col = 0; col < svl; col++) {
                    svst1_hor_za32(0, col, svptrue_b32(), &dst[col * svl]);
                }
            }
        }
    }
}
114+
115+
//
116+
/*
 * Counterpart of ssymm_direct_sme1_preprocessLU for a symmetric matrix
 * stored in its LOWER triangle: expands a row-major `a` (nbr x nbc) into
 * a_mod, the SVL-blocked, transposed layout consumed by
 * sgemm_direct_alpha_beta_sme1_2VLx2VL.  The strictly-upper part of `a`
 * is never read — it is reconstructed from the mirrored lower part.
 */
__arm_new("za") __arm_locally_streaming
static void ssymm_direct_sme1_preprocessLL(uint64_t nbr, uint64_t nbc,
                                           const float *restrict a, float *restrict a_mod)
{
    const uint64_t svl = svcntw();   /* 32-bit lanes per streaming vector */
    uint64_t row_batch = svl;        /* rows handled per outer iteration  */

    /* One reused source/destination pointer pair.  `src` is only ever
       read; the const qualifier of `a` is dropped to share the variable. */
    float *restrict src;
    float *restrict dst;

    for (uint64_t row_idx = 0; row_idx < nbr; row_idx += row_batch) {
        row_batch = MIN(row_batch, nbr - row_idx);

        /* Fill in the upper triangle and transpose one SVL x N panel of A. */
        uint64_t col_batch = svl;

        for (uint64_t col_idx = 0; col_idx < nbc; col_idx += col_batch) {
            svzero_za();

            if (col_idx == row_idx) {
                /* Diagonal tile: load each row slice both horizontally and
                   vertically so the ZA tile holds the stored lower half
                   mirrored into the upper half, then store it transposed
                   (vertical slices). */
                src = (float *)&a[row_idx * nbc + col_idx];
                dst = &a_mod[col_idx * svl + row_idx * nbc];
                const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
                for (uint64_t row = 0; row < row_batch; row++) {
                    svld1_hor_za32(0, row, pg_row, &src[row * nbc]);
                    svld1_ver_za32(0, row, pg_row, &src[row * nbc]);
                }
                col_batch = MIN(col_batch, nbc - col_idx);
                for (uint64_t col = 0; col < col_batch; col++) {
                    svst1_ver_za32(0, col, svptrue_b32(), &dst[col * svl]);
                }
            } else if (col_idx > row_idx) {
                /* Tile above the diagonal: not stored — read the mirrored
                   tile from the lower triangle and emit it with horizontal
                   stores (already in the required orientation). */
                src = (float *)&a[row_idx + col_idx * nbc];
                dst = &a_mod[col_idx * svl + row_idx * nbc];
                const svbool_t pg_row = svptrue_b32();
                for (uint64_t row = 0; row < row_batch; row++) {
                    svld1_hor_za32(0, row, pg_row, &src[row * nbc]);
                }
                col_batch = MIN(col_batch, nbc - col_idx);
                for (uint64_t col = 0; col < col_batch; col++) {
                    svst1_hor_za32(0, col, svptrue_b32(), &dst[col * svl]);
                }
            } else { /* col_idx < row_idx */
                /* Tile below the diagonal: present in storage — just
                   transpose it through the ZA tile. */
                src = (float *)&a[row_idx * nbc + col_idx];
                dst = &a_mod[col_idx * svl + row_idx * nbc];
                const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
                for (uint64_t row = 0; row < row_batch; row++) {
                    svld1_hor_za32(0, row, pg_row, &src[row * nbc]);
                }
                col_batch = MIN(col_batch, nbc - col_idx);
                for (uint64_t col = 0; col < col_batch; col++) {
                    svst1_ver_za32(0, col, svptrue_b32(), &dst[col * svl]);
                }
            }
        }
    }
}
192+
193+
#endif
194+
195+
//
196+
/*
 * Direct SSYMM (single precision, Side=Left): R = alpha * A * B + beta * R,
 * where A is an M x M symmetric matrix and B/R are M x N, all row-major.
 *
 * The interface layer only dispatches here for row-major calls with
 * lda == m, ldb == n and ldc == n, so the stride arguments are unused.
 *
 * A is first expanded (UPPER or LOWER storage, selected at compile time
 * via -DUPPER / -DLOWER) into an SVL-blocked buffer A_mod, which is then
 * fed to the generic SME1 outer-product GEMM kernel.
 */
void CNAME(BLASLONG M, BLASLONG N, float alpha, float *__restrict A,
           BLASLONG strideA, float *__restrict B, BLASLONG strideB,
           float beta, float *__restrict R, BLASLONG strideR)
{
    /* Elements (32-bit lanes) per streaming SVE vector. */
    uint64_t vl_elms = sve_cntw();

    /* Round M up to a multiple of the vector length using exact integer
       arithmetic instead of a double-precision ceil() round-trip. */
    uint64_t m_mod = (((uint64_t)M + vl_elms - 1) / vl_elms) * vl_elms;

    /* Pre-process the left matrix to make it suitable for
       matrix sum of outer-product calculation. */
    float *A_mod = (float *)malloc(m_mod * M * sizeof(float));
    if (A_mod == NULL)
        return; /* OOM: this void API has no error path; leave R untouched */

#if defined(UPPER)
    ssymm_direct_sme1_preprocessLU(M, M, A, A_mod);
#elif defined(LOWER)
    ssymm_direct_sme1_preprocessLL(M, M, A, A_mod);
#endif

    /* Calculate C = alpha*A*B + beta*C */
    sgemm_direct_alpha_beta_sme1_2VLx2VL(M, M, N, &alpha, A_mod, B, &beta, R);
    free(A_mod);
}
218+
219+
#else
220+
221+
/*
 * Fallback stub compiled when SME support is unavailable (HAVE_SME not
 * defined): keeps the symbol defined for the dispatch table but performs
 * no work.  NOTE(review): the dispatcher is expected never to route here
 * at runtime — confirm with the DYNAMIC_ARCH selection logic.
 */
void CNAME(BLASLONG M, BLASLONG N, float alpha, float *__restrict A,
           BLASLONG strideA, float *__restrict B, BLASLONG strideB,
           float beta, float *__restrict R, BLASLONG strideR)
{
    /* Explicitly discard all parameters to silence unused warnings. */
    (void)M; (void)N; (void)alpha; (void)A; (void)strideA;
    (void)B; (void)strideB; (void)beta; (void)R; (void)strideR;
}
224+
225+
#endif

kernel/setparam-ref.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ gotoblas_t TABLE_NAME = {
216216
#ifdef ARCH_ARM64
217217
sgemm_directTS,
218218
sgemm_direct_alpha_betaTS,
219+
ssymm_direct_alpha_betaLUTS,
220+
ssymm_direct_alpha_betaLLTS,
219221
#endif
220222

221223
sgemm_kernelTS, sgemm_betaTS,

0 commit comments

Comments
 (0)