
Commit 5aeb609

RV64: Add bounds assertions
Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent 28472b7 commit 5aeb609

6 files changed: 306 additions & 58 deletions

.github/workflows/ci.yml
Lines changed: 2 additions & 2 deletions

@@ -150,7 +150,7 @@ jobs:
 compile_mode: native
 cflags: "-DMLKEM_DEBUG -fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all"
 check_namespace: 'false'
-- name: build + test (+debug+memsan+ubsan, cross, opt)
+- name: build + test (+debug, cross, opt)
 uses: ./.github/actions/multi-functest
 # There is no native code yet on PPC64LE, riscv32 or AArch64_be, so no point running opt tests
 if: ${{ matrix.target.mode != 'native' && (matrix.target.arch != 'ppc64le' && matrix.target.arch != 'riscv32' && matrix.target.arch != 'aarch64_be') }}
@@ -159,7 +159,7 @@ jobs:
 nix-cache: ${{ matrix.target.mode == 'native' && 'false' || 'true' }}
 gh_token: ${{ secrets.GITHUB_TOKEN }}
 compile_mode: ${{ matrix.target.mode }}
-cflags: "-DMLKEM_DEBUG -fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all"
+cflags: "-DMLKEM_DEBUG"
 opt: 'opt'
 backend_tests:
 name: AArch64 FIPS202 backends (${{ matrix.backend }})

mlkem/mlkem_native.S
Lines changed: 6 additions & 0 deletions

@@ -596,6 +596,12 @@
 #undef MLK_RVV_MONT_R2
 #undef MLK_RVV_QI
 #undef MLK_RVV_VLEN
+#undef mlk_assert_abs_bound_int16m1
+#undef mlk_assert_abs_bound_int16m2
+#undef mlk_assert_bound_int16m1
+#undef mlk_assert_bound_int16m2
+#undef mlk_debug_check_bounds_int16m1
+#undef mlk_debug_check_bounds_int16m2
 #endif /* MLK_SYS_RISCV64 */
 #endif /* MLK_CONFIG_USE_NATIVE_BACKEND_ARITH */
 #endif /* !MLK_CONFIG_MONOBUILD_KEEP_SHARED_HEADERS */

mlkem/mlkem_native.c
Lines changed: 7 additions & 0 deletions

@@ -85,6 +85,7 @@
 #include "src/native/x86_64/src/rej_uniform_table.c"
 #endif
 #if defined(MLK_SYS_RISCV64)
+#include "src/native/riscv64/src/rv64v_debug.c"
 #include "src/native/riscv64/src/rv64v_poly.c"
 #endif
 #endif /* MLK_CONFIG_USE_NATIVE_BACKEND_ARITH */
@@ -584,6 +585,12 @@
 #undef MLK_RVV_MONT_R2
 #undef MLK_RVV_QI
 #undef MLK_RVV_VLEN
+#undef mlk_assert_abs_bound_int16m1
+#undef mlk_assert_abs_bound_int16m2
+#undef mlk_assert_bound_int16m1
+#undef mlk_assert_bound_int16m2
+#undef mlk_debug_check_bounds_int16m1
+#undef mlk_debug_check_bounds_int16m2
 #endif /* MLK_SYS_RISCV64 */
 #endif /* MLK_CONFIG_USE_NATIVE_BACKEND_ARITH */
 #endif /* !MLK_CONFIG_MONOBUILD_KEEP_SHARED_HEADERS */
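
The six new #undef lines follow the same pattern as the MLK_RVV_* entries above: in a single-compilation-unit (monobuild) configuration, the backend sources are #include'd into one translation unit, so macros a component defines for itself must be undefined afterwards or they would leak into the components included next. A minimal, generic illustration of this pattern (the file and macro names below are made up):

/* scu.c -- hypothetical single-compilation-unit build (illustrative names) */
#include "component_a.c" /* internally defines a helper macro A_SCRATCH */

/* Undefine component-private macros before including the next component,
 * so component_b.c cannot accidentally pick up A_SCRATCH. */
#undef A_SCRATCH

#include "component_b.c"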
mlkem/src/native/riscv64/src/rv64v_debug.c (new file)
Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) The mlkem-native project authors
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
+ */
+
+/* NOTE: You can remove this file unless you compile with MLKEM_DEBUG. */
+
+#include "../../../common.h"
+
+#if defined(MLK_ARITH_BACKEND_RISCV64) && \
+    !defined(MLK_CONFIG_MULTILEVEL_NO_SHARED) && defined(MLKEM_DEBUG)
+
+#include <riscv_vector.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include "../../../debug.h"
+#include "rv64v_settings.h"
+
+#define MLK_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] "
+
+/*************************************************
+ * Name:        mlk_debug_check_bounds_int16m1
+ *
+ * Description: Check whether values in a vint16m1_t vector
+ *              are within specified bounds.
+ *
+ * Implementation: Extract vector elements to a temporary array
+ *                 and reuse existing array bounds checking.
+ **************************************************/
+void mlk_debug_check_bounds_int16m1(const char *file, int line, vint16m1_t vec,
+                                    size_t vl, int lower_bound_exclusive,
+                                    int upper_bound_exclusive)
+{
+  /* Allocate temporary array to store vector elements
+   * We use the maximum possible vector length to be safe */
+  int16_t temp_array[MLK_RVV_E16M1_VL];
+
+  /* Store vector elements to temporary array for inspection */
+  __riscv_vse16_v_i16m1(temp_array, vec, vl);
+
+  /* Reuse existing array bounds checking function */
+  mlk_debug_check_bounds(file, line, temp_array, (unsigned)vl,
+                         lower_bound_exclusive, upper_bound_exclusive);
+}
+
+/*************************************************
+ * Name:        mlk_debug_check_bounds_int16m2
+ *
+ * Description: Check whether values in a vint16m2_t vector
+ *              are within specified bounds.
+ *
+ * Implementation: Extract vector elements to a temporary array
+ *                 and reuse existing array bounds checking.
+ **************************************************/
+void mlk_debug_check_bounds_int16m2(const char *file, int line, vint16m2_t vec,
+                                    size_t vl, int lower_bound_exclusive,
+                                    int upper_bound_exclusive)
+{
+  /* Allocate temporary array to store vector elements
+   * m2 vectors hold 2x the elements of m1 vectors */
+  int16_t temp_array[2 * MLK_RVV_E16M1_VL];
+
+  /* Store vector elements to temporary array for inspection */
+  __riscv_vse16_v_i16m2(temp_array, vec, 2 * vl);
+
+  /* Reuse existing array bounds checking function for all elements */
+  mlk_debug_check_bounds(file, line, temp_array, (unsigned)(2 * vl),
+                         lower_bound_exclusive, upper_bound_exclusive);
+}
+
+#else /* MLK_ARITH_BACKEND_RISCV64 && !MLK_CONFIG_MULTILEVEL_NO_SHARED && \
+         MLKEM_DEBUG */
+
+MLK_EMPTY_CU(rv64v_debug)
+
+#endif /* !(MLK_ARITH_BACKEND_RISCV64 && !MLK_CONFIG_MULTILEVEL_NO_SHARED && \
+            MLKEM_DEBUG) */
+
+/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
+ * Don't modify by hand -- this is auto-generated by scripts/autogen. */
+#undef MLK_DEBUG_ERROR_HEADER
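
For orientation, the sketch below shows one plausible way the assertion macros used by rv64v_poly.c (mlk_assert_bound_int16m1, mlk_assert_abs_bound_int16m1 and the m2 variants) could forward to the checkers above. Their actual definitions are not part of this diff, so the exclusive-bound convention and the no-op fallback here are assumptions modelled on mlk_debug_check_bounds.

/* Hypothetical sketch, NOT part of this commit: wiring the m1 assertion
 * macros to the vector checker above, assuming the same exclusive-bound
 * convention as mlk_debug_check_bounds. The m2 variants would be analogous. */
#if defined(MLKEM_DEBUG)
#define mlk_assert_bound_int16m1(vec, vl, value_lb, value_ub)     \
  mlk_debug_check_bounds_int16m1(__FILE__, __LINE__, (vec), (vl), \
                                 (value_lb) - 1, (value_ub))
#define mlk_assert_abs_bound_int16m1(vec, vl, value_abs_bd)       \
  mlk_debug_check_bounds_int16m1(__FILE__, __LINE__, (vec), (vl), \
                                 -(value_abs_bd), (value_abs_bd))
#else /* MLKEM_DEBUG */
#define mlk_assert_bound_int16m1(vec, vl, value_lb, value_ub) \
  do                                                          \
  {                                                           \
  } while (0)
#define mlk_assert_abs_bound_int16m1(vec, vl, value_abs_bd) \
  do                                                        \
  {                                                         \
  } while (0)
#endif /* !MLKEM_DEBUG */

With this wiring, mlk_assert_bound_int16m1(v, vl, 0, MLKEM_Q) checks 0 <= x < MLKEM_Q for each of the vl active lanes, and mlk_assert_abs_bound_int16m1(v, vl, B) checks |x| < B, matching how the assertions are used in rv64v_poly.c below.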

mlkem/src/native/riscv64/src/rv64v_poly.c
Lines changed: 74 additions & 56 deletions

@@ -21,7 +21,7 @@ static inline vint16m1_t fq_redc(vint16m1_t rh, vint16m1_t rl, size_t vl)
 
   t = __riscv_vmul_vx_i16m1(rl, MLK_RVV_QI, vl); /* t = l * -Q^-1 */
   t = __riscv_vmulh_vx_i16m1(t, MLKEM_Q, vl);    /* t = (t*Q) / R */
-  c = __riscv_vmsne_vx_i16m1_b16(rl, 0, vl);     /* c = l == 0 */
+  c = __riscv_vmsne_vx_i16m1_b16(rl, 0, vl);     /* c = (l != 0) */
  t = __riscv_vadc_vvm_i16m1(t, rh, c, vl);       /* t += h + c */
 
   return t;
@@ -55,6 +55,7 @@ static inline vint16m1_t fq_barrett(vint16m1_t a, size_t vl)
   t = __riscv_vmul_vx_i16m1(t, MLKEM_Q, vl);
   t = __riscv_vsub_vv_i16m1(a, t, vl);
 
+  mlk_assert_abs_bound_int16m1(t, vl, MLKEM_Q_HALF);
   return t;
 }
 
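The assertion added to fq_barrett above states that, after Barrett reduction, every active lane satisfies |t| < MLKEM_Q_HALF. The scalar model below checks that property exhaustively; it follows the standard ML-KEM reference Barrett reduction (constant 20159 = round(2^26 / Q)) rather than the RVV code above, and it assumes MLKEM_Q_HALF = (MLKEM_Q + 1) / 2 = 1665 and an arithmetic right shift for negative values, as in the reference code.

#include <assert.h>
#include <stdint.h>

#define MLKEM_Q 3329
#define MLKEM_Q_HALF ((MLKEM_Q + 1) / 2) /* assumed definition: 1665 */

/* Scalar Barrett reduction (reference-style): returns r with
 * r = a (mod MLKEM_Q) and |r| <= (MLKEM_Q - 1) / 2. */
static int16_t barrett_reduce(int16_t a)
{
  const int32_t v = ((1 << 26) + MLKEM_Q / 2) / MLKEM_Q; /* 20159 */
  int32_t t = (v * (int32_t)a + (1 << 25)) >> 26;        /* t ~= round(a/Q) */
  return (int16_t)((int32_t)a - t * MLKEM_Q);
}

int main(void)
{
  int32_t a;
  for (a = INT16_MIN; a <= INT16_MAX; a++)
  {
    int16_t r = barrett_reduce((int16_t)a);
    assert(r > -MLKEM_Q_HALF && r < MLKEM_Q_HALF); /* |r| < 1665 */
    assert((a - r) % MLKEM_Q == 0);                /* r = a mod Q */
  }
  return 0;
}
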
@@ -66,6 +67,7 @@ static inline vint16m1_t fq_cadd(vint16m1_t rx, size_t vl)
 
   bn = __riscv_vmslt_vx_i16m1_b16(rx, 0, vl);             /* if x < 0: */
   rx = __riscv_vadd_vx_i16m1_mu(bn, rx, rx, MLKEM_Q, vl); /* x += Q */
+
   return rx;
 }
 
@@ -75,8 +77,9 @@ static inline vint16m1_t fq_csub(vint16m1_t rx, size_t vl)
 {
   vbool16_t bn;
 
-  bn = __riscv_vmsge_vx_i16m1_b16(rx, MLKEM_Q, vl);       /* if x >= 0: */
+  bn = __riscv_vmsge_vx_i16m1_b16(rx, MLKEM_Q, vl);       /* if x >= Q: */
   rx = __riscv_vsub_vx_i16m1_mu(bn, rx, rx, MLKEM_Q, vl); /* x -= Q */
+
   return rx;
 }
 
@@ -106,7 +109,12 @@ static inline vint16m1_t fq_mul_vx(vint16m1_t rx, int16_t ry, size_t vl)
 
 static inline vint16m1_t fq_mulq_vx(vint16m1_t rx, int16_t ry, size_t vl)
 {
-  return fq_cadd(fq_mul_vx(rx, ry, vl), vl);
+  vint16m1_t result;
+
+  result = fq_cadd(fq_mul_vx(rx, ry, vl), vl);
+
+  mlk_assert_bound_int16m1(result, vl, 0, MLKEM_Q);
+  return result;
 }
 
 /* create a permutation for swapping index bits a and b, a < b */
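
The new assertion in fq_mulq_vx records that the result lies in the canonical range [0, MLKEM_Q). The reasoning: fq_mul_vx returns a Montgomery-reduced value, assumed here (as is standard for this reduction style) to lie in (-MLKEM_Q, MLKEM_Q), and the single conditional add in fq_cadd then maps it into [0, MLKEM_Q). A scalar model of that last step:

#include <assert.h>
#include <stdint.h>

#define MLKEM_Q 3329

/* Hypothetical scalar model of the bound asserted after fq_mulq_vx: given a
 * value x with |x| < MLKEM_Q, one conditional add of Q (mirroring fq_cadd)
 * lands in [0, MLKEM_Q). */
static int16_t scalar_cadd(int16_t x)
{
  return (int16_t)(x < 0 ? x + MLKEM_Q : x);
}

int main(void)
{
  int32_t x;
  for (x = -(MLKEM_Q - 1); x <= MLKEM_Q - 1; x++) /* |x| < Q, the assumed input range */
  {
    int16_t r = scalar_cadd((int16_t)x);
    assert(r >= 0 && r < MLKEM_Q); /* matches mlk_assert_bound_int16m1(..., 0, MLKEM_Q) */
  }
  return 0;
}
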
@@ -142,18 +150,30 @@ static vuint16m2_t bitswap_perm(unsigned a, unsigned b, size_t vl)
 
 /* forward butterfly operation */
 
-#define MLK_RVV_BFLY_FX(u0, u1, ut, uc, vl)  \
-  {                                          \
-    ut = fq_mul_vx(u1, uc, vl);              \
-    u1 = __riscv_vsub_vv_i16m1(u0, ut, vl);  \
-    u0 = __riscv_vadd_vv_i16m1(u0, ut, vl);  \
+#define MLK_RVV_BFLY_FX(u0, u1, ut, uc, vl, layer)                     \
+  {                                                                    \
+    mlk_assert_abs_bound(&uc, 1, MLKEM_Q_HALF);                        \
+                                                                       \
+    ut = fq_mul_vx(u1, uc, vl);                                        \
+    u1 = __riscv_vsub_vv_i16m1(u0, ut, vl);                            \
+    u0 = __riscv_vadd_vv_i16m1(u0, ut, vl);                            \
+                                                                       \
+    /* mlk_assert_abs_bound_int16m1(u0, vl, (layer + 1) * MLKEM_Q); */ \
+    /* mlk_assert_abs_bound_int16m1(u1, vl, (layer + 1) * MLKEM_Q); */ \
+    /* mlk_assert_abs_bound_int16m1(ut, vl, MLKEM_Q); */               \
   }
 
-#define MLK_RVV_BFLY_FV(u0, u1, ut, uc, vl)  \
-  {                                          \
-    ut = fq_mul_vv(u1, uc, vl);              \
-    u1 = __riscv_vsub_vv_i16m1(u0, ut, vl);  \
-    u0 = __riscv_vadd_vv_i16m1(u0, ut, vl);  \
+#define MLK_RVV_BFLY_FV(u0, u1, ut, uc, vl, layer)                     \
+  {                                                                    \
+    mlk_assert_abs_bound_int16m1(uc, vl, MLKEM_Q_HALF);                \
+                                                                       \
+    ut = fq_mul_vv(u1, uc, vl);                                        \
+    u1 = __riscv_vsub_vv_i16m1(u0, ut, vl);                            \
+    u0 = __riscv_vadd_vv_i16m1(u0, ut, vl);                            \
+                                                                       \
+    /* mlk_assert_abs_bound_int16m1(ut, vl, MLKEM_Q); */               \
+    /* mlk_assert_abs_bound_int16m1(u0, vl, (layer + 1) * MLKEM_Q); */ \
+    /* mlk_assert_abs_bound_int16m1(u1, vl, (layer + 1) * MLKEM_Q); */ \
   }
 
 static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz)
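
The extra layer argument threaded through the butterfly macros exists to feed the (currently commented-out) growth assertions: if the inputs to layer k are bounded by k*Q in absolute value and the twiddle product ut satisfies |ut| < Q, then u0 + ut and u0 - ut are bounded by (k+1)*Q, which still fits in int16_t for all seven layers (8*Q = 26632 < 32768). The scalar sketch below spells out this bookkeeping; it is a model of the bound argument, not of the RVV macros, and the |ut| < Q premise is an assumption about fq_mul_vx/fq_mul_vv.

#include <assert.h>
#include <stdint.h>

#define MLKEM_Q 3329

static int32_t iabs32(int32_t x) { return x < 0 ? -x : x; }

/* Scalar model of one butterfly layer's bound bookkeeping (hypothetical;
 * `t` stands in for the Montgomery product computed by fq_mul_vx/fq_mul_vv). */
static void scalar_bfly(int16_t *u0, int16_t *u1, int16_t t, int layer)
{
  int32_t a, b;

  assert(iabs32(t) < MLKEM_Q);           /* premise: |ut| < Q    */
  assert(iabs32(*u0) < layer * MLKEM_Q); /* premise: input bound */

  b = (int32_t)*u0 - t; /* u1 = u0 - ut */
  a = (int32_t)*u0 + t; /* u0 = u0 + ut */

  /* |u0 +/- ut| <= |u0| + |ut| < layer*Q + Q = (layer + 1)*Q */
  assert(iabs32(a) < (layer + 1) * MLKEM_Q);
  assert(iabs32(b) < (layer + 1) * MLKEM_Q);

  /* (layer + 1)*Q <= 8*Q = 26632 < 32768 for layer <= 7, so the results
   * still fit in int16_t without an intermediate reduction. */
  *u0 = (int16_t)a;
  *u1 = (int16_t)b;
}

int main(void)
{
  int16_t u0 = 3 * MLKEM_Q - 1, u1 = 0; /* example inputs entering layer 3 */
  scalar_bfly(&u0, &u1, 1234, 3);
  return 0;
}
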
@@ -185,7 +205,7 @@ static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz)
   t1 = __riscv_vget_v_i16m2_i16m1(vp, 1);
 
   c0 = __riscv_vrgather_vv_i16m1(cz, cs8, vl);
-  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl);
+  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 5);
 
   /* swap 4 */
   vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1);
@@ -194,7 +214,7 @@ static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz)
   t1 = __riscv_vget_v_i16m2_i16m1(vp, 1);
 
   c0 = __riscv_vrgather_vv_i16m1(cz, cs4, vl);
-  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl);
+  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 6);
 
   /* swap 2 */
   vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1);
@@ -203,7 +223,7 @@ static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz)
   t1 = __riscv_vget_v_i16m2_i16m1(vp, 1);
 
   c0 = __riscv_vrgather_vv_i16m1(cz, cs2, vl);
-  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl);
+  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 7);
 
   /* normalize */
   t0 = fq_mulq_vx(t0, MLK_RVV_MONT_R1, vl);
@@ -253,41 +273,41 @@ void mlk_rv64v_poly_ntt(int16_t *r)
   ve = __riscv_vle16_v_i16m1(&r[0xe0], vl);
   vf = __riscv_vle16_v_i16m1(&r[0xf0], vl);
 
-  MLK_RVV_BFLY_FX(v0, v8, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v1, v9, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v2, va, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v3, vb, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v4, vc, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v5, vd, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v6, ve, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v7, vf, vt, zeta[0x01], vl);
-
-  MLK_RVV_BFLY_FX(v0, v4, vt, zeta[0x10], vl);
-  MLK_RVV_BFLY_FX(v1, v5, vt, zeta[0x10], vl);
-  MLK_RVV_BFLY_FX(v2, v6, vt, zeta[0x10], vl);
-  MLK_RVV_BFLY_FX(v3, v7, vt, zeta[0x10], vl);
-  MLK_RVV_BFLY_FX(v8, vc, vt, zeta[0x11], vl);
-  MLK_RVV_BFLY_FX(v9, vd, vt, zeta[0x11], vl);
-  MLK_RVV_BFLY_FX(va, ve, vt, zeta[0x11], vl);
-  MLK_RVV_BFLY_FX(vb, vf, vt, zeta[0x11], vl);
-
-  MLK_RVV_BFLY_FX(v0, v2, vt, zeta[0x20], vl);
-  MLK_RVV_BFLY_FX(v1, v3, vt, zeta[0x20], vl);
-  MLK_RVV_BFLY_FX(v4, v6, vt, zeta[0x21], vl);
-  MLK_RVV_BFLY_FX(v5, v7, vt, zeta[0x21], vl);
-  MLK_RVV_BFLY_FX(v8, va, vt, zeta[0x30], vl);
-  MLK_RVV_BFLY_FX(v9, vb, vt, zeta[0x30], vl);
-  MLK_RVV_BFLY_FX(vc, ve, vt, zeta[0x31], vl);
-  MLK_RVV_BFLY_FX(vd, vf, vt, zeta[0x31], vl);
-
-  MLK_RVV_BFLY_FX(v0, v1, vt, zeta[0x40], vl);
-  MLK_RVV_BFLY_FX(v2, v3, vt, zeta[0x41], vl);
-  MLK_RVV_BFLY_FX(v4, v5, vt, zeta[0x50], vl);
-  MLK_RVV_BFLY_FX(v6, v7, vt, zeta[0x51], vl);
-  MLK_RVV_BFLY_FX(v8, v9, vt, zeta[0x60], vl);
-  MLK_RVV_BFLY_FX(va, vb, vt, zeta[0x61], vl);
-  MLK_RVV_BFLY_FX(vc, vd, vt, zeta[0x70], vl);
-  MLK_RVV_BFLY_FX(ve, vf, vt, zeta[0x71], vl);
+  MLK_RVV_BFLY_FX(v0, v8, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v1, v9, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v2, va, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v3, vb, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v4, vc, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v5, vd, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v6, ve, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v7, vf, vt, zeta[0x01], vl, 1);
+
+  MLK_RVV_BFLY_FX(v0, v4, vt, zeta[0x10], vl, 2);
+  MLK_RVV_BFLY_FX(v1, v5, vt, zeta[0x10], vl, 2);
+  MLK_RVV_BFLY_FX(v2, v6, vt, zeta[0x10], vl, 2);
+  MLK_RVV_BFLY_FX(v3, v7, vt, zeta[0x10], vl, 2);
+  MLK_RVV_BFLY_FX(v8, vc, vt, zeta[0x11], vl, 2);
+  MLK_RVV_BFLY_FX(v9, vd, vt, zeta[0x11], vl, 2);
+  MLK_RVV_BFLY_FX(va, ve, vt, zeta[0x11], vl, 2);
+  MLK_RVV_BFLY_FX(vb, vf, vt, zeta[0x11], vl, 2);
+
+  MLK_RVV_BFLY_FX(v0, v2, vt, zeta[0x20], vl, 3);
+  MLK_RVV_BFLY_FX(v1, v3, vt, zeta[0x20], vl, 3);
+  MLK_RVV_BFLY_FX(v4, v6, vt, zeta[0x21], vl, 3);
+  MLK_RVV_BFLY_FX(v5, v7, vt, zeta[0x21], vl, 3);
+  MLK_RVV_BFLY_FX(v8, va, vt, zeta[0x30], vl, 3);
+  MLK_RVV_BFLY_FX(v9, vb, vt, zeta[0x30], vl, 3);
+  MLK_RVV_BFLY_FX(vc, ve, vt, zeta[0x31], vl, 3);
+  MLK_RVV_BFLY_FX(vd, vf, vt, zeta[0x31], vl, 3);
+
+  MLK_RVV_BFLY_FX(v0, v1, vt, zeta[0x40], vl, 4);
+  MLK_RVV_BFLY_FX(v2, v3, vt, zeta[0x41], vl, 4);
+  MLK_RVV_BFLY_FX(v4, v5, vt, zeta[0x50], vl, 4);
+  MLK_RVV_BFLY_FX(v6, v7, vt, zeta[0x51], vl, 4);
+  MLK_RVV_BFLY_FX(v8, v9, vt, zeta[0x60], vl, 4);
+  MLK_RVV_BFLY_FX(va, vb, vt, zeta[0x61], vl, 4);
+  MLK_RVV_BFLY_FX(vc, vd, vt, zeta[0x70], vl, 4);
+  MLK_RVV_BFLY_FX(ve, vf, vt, zeta[0x71], vl, 4);
 
   __riscv_vse16_v_i16m2(
       &r[0x00], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(v0, v1), z0), vl2);
@@ -614,9 +634,9 @@ void mlk_rv64v_poly_tomont(int16_t *r)
 
   for (size_t i = 0; i < MLKEM_N; i += vl)
   {
-    __riscv_vse16_v_i16m1(
-        &r[i], fq_mul_vx(__riscv_vle16_v_i16m1(&r[i], vl), MLK_RVV_MONT_R2, vl),
-        vl);
+    vint16m1_t vec = __riscv_vle16_v_i16m1(&r[i], vl);
+    vec = fq_mul_vx(vec, MLK_RVV_MONT_R2, vl);
+    __riscv_vse16_v_i16m1(&r[i], vec, vl);
   }
 }
 
@@ -638,8 +658,6 @@ void mlk_rv64v_poly_reduce(int16_t *r)
   {
     vt = __riscv_vle16_v_i16m1(&r[i], vl);
     vt = fq_barrett(vt, vl);
-
-    /* make positive */
     vt = fq_cadd(vt, vl);
     __riscv_vse16_v_i16m1(&r[i], vt, vl);
   }