@@ -21,7 +21,7 @@ static inline vint16m1_t fq_redc(vint16m1_t rh, vint16m1_t rl, size_t vl)

   t = __riscv_vmul_vx_i16m1(rl, MLK_RVV_QI, vl);  /* t = l * -Q^-1 */
   t = __riscv_vmulh_vx_i16m1(t, MLKEM_Q, vl);     /* t = (t*Q) / R */
-  c = __riscv_vmsne_vx_i16m1_b16(rl, 0, vl);      /* c = l == 0 */
+  c = __riscv_vmsne_vx_i16m1_b16(rl, 0, vl);      /* c = (l != 0) */
   t = __riscv_vadc_vvm_i16m1(t, rh, c, vl);       /* t += h + c */

   return t;
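For reference, fq_redc is a signed Montgomery reduction of a 32-bit value split into 16-bit halves rh:rl. Below is a scalar sketch of the same computation, under the assumption that MLK_RVV_QI equals -MLKEM_Q^-1 mod 2^16 (3327 for Q = 3329); fq_redc_ref is a hypothetical name for illustration, not code from this file.

#include <stdint.h>

/* Scalar sketch: given a = h*2^16 + (uint16_t)l, return a representative of
 * a * 2^-16 mod 3329.  The carry c accounts for the low halves of a and t*Q
 * summing to exactly 2^16 whenever l != 0. */
static int16_t fq_redc_ref(int16_t h, int16_t l)
{
  int16_t t = (int16_t)((uint16_t)l * 3327u);        /* t = l * -Q^-1 (low 16 bits) */
  int16_t hi = (int16_t)(((int32_t)t * 3329) >> 16); /* (t*Q) / R, with R = 2^16 */
  int16_t c = (l != 0);                              /* carry out of the low halves */
  return (int16_t)(hi + h + c);                      /* h + (t*Q)/R + c */
}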
@@ -55,6 +55,7 @@ static inline vint16m1_t fq_barrett(vint16m1_t a, size_t vl)
   t = __riscv_vmul_vx_i16m1(t, MLKEM_Q, vl);
   t = __riscv_vsub_vv_i16m1(a, t, vl);

+  mlk_assert_abs_bound_int16m1(t, vl, MLKEM_Q_HALF);
   return t;
 }

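The new assertion documents the range of the Barrett step: the result is the centered representative of a mod Q, so its magnitude stays at or below MLKEM_Q_HALF. A scalar Barrett reduction with the same output range, in the style of the reference implementation, is sketched below; the constant 20159 = round(2^26 / 3329) is an assumption for illustration, not read from this file.

#include <stdint.h>

/* Centered Barrett reduction: returns r == a (mod 3329) with |r| <= 3329/2. */
static int16_t fq_barrett_ref(int16_t a)
{
  const int32_t v = 20159;                          /* round(2^26 / Q) */
  int16_t t = (int16_t)((v * a + (1 << 25)) >> 26); /* t = round(a / Q) */
  return (int16_t)(a - t * 3329);                   /* a - round(a/Q) * Q */
}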
@@ -66,6 +67,7 @@ static inline vint16m1_t fq_cadd(vint16m1_t rx, size_t vl)

   bn = __riscv_vmslt_vx_i16m1_b16(rx, 0, vl);             /* if x < 0: */
   rx = __riscv_vadd_vx_i16m1_mu(bn, rx, rx, MLKEM_Q, vl); /*   x += Q  */
+
   return rx;
 }

@@ -75,8 +77,9 @@ static inline vint16m1_t fq_csub(vint16m1_t rx, size_t vl)
 {
   vbool16_t bn;

-  bn = __riscv_vmsge_vx_i16m1_b16(rx, MLKEM_Q, vl);       /* if x >= 0: */
+  bn = __riscv_vmsge_vx_i16m1_b16(rx, MLKEM_Q, vl);       /* if x >= Q: */
   rx = __riscv_vsub_vx_i16m1_mu(bn, rx, rx, MLKEM_Q, vl); /*   x -= Q   */
+
   return rx;
 }

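Both corrections are branch-free: the comparison produces a mask, and the masked add or subtract is applied only where the mask is set. A scalar branch-free sketch of the same two steps follows (illustrative only; it assumes arithmetic right shift for negative int16_t, as the reference code does).

#include <stdint.h>

/* Conditionally add Q where x < 0: maps (-Q, Q) into [0, Q). */
static int16_t fq_cadd_ref(int16_t x)
{
  return (int16_t)(x + ((x >> 15) & 3329)); /* (x >> 15) is all-ones iff x < 0 */
}

/* Conditionally subtract Q where x >= Q: maps [0, 2*Q) into [0, Q). */
static int16_t fq_csub_ref(int16_t x)
{
  x = (int16_t)(x - 3329);
  return (int16_t)(x + ((x >> 15) & 3329));
}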
@@ -106,7 +109,12 @@ static inline vint16m1_t fq_mul_vx(vint16m1_t rx, int16_t ry, size_t vl)

 static inline vint16m1_t fq_mulq_vx(vint16m1_t rx, int16_t ry, size_t vl)
 {
-  return fq_cadd(fq_mul_vx(rx, ry, vl), vl);
+  vint16m1_t result;
+
+  result = fq_cadd(fq_mul_vx(rx, ry, vl), vl);
+
+  mlk_assert_bound_int16m1(result, vl, 0, MLKEM_Q);
+  return result;
 }

 /* create a permutation for swapping index bits a and b, a < b */
@@ -142,18 +150,30 @@ static vuint16m2_t bitswap_perm(unsigned a, unsigned b, size_t vl)

 /* forward butterfly operation */

-#define MLK_RVV_BFLY_FX(u0, u1, ut, uc, vl)   \
-  {                                           \
-    ut = fq_mul_vx(u1, uc, vl);               \
-    u1 = __riscv_vsub_vv_i16m1(u0, ut, vl);   \
-    u0 = __riscv_vadd_vv_i16m1(u0, ut, vl);   \
+#define MLK_RVV_BFLY_FX(u0, u1, ut, uc, vl, layer)                     \
+  {                                                                    \
+    mlk_assert_abs_bound(&uc, 1, MLKEM_Q_HALF);                        \
+                                                                       \
+    ut = fq_mul_vx(u1, uc, vl);                                        \
+    u1 = __riscv_vsub_vv_i16m1(u0, ut, vl);                            \
+    u0 = __riscv_vadd_vv_i16m1(u0, ut, vl);                            \
+                                                                       \
+    /* mlk_assert_abs_bound_int16m1(u0, vl, (layer + 1) * MLKEM_Q); */ \
+    /* mlk_assert_abs_bound_int16m1(u1, vl, (layer + 1) * MLKEM_Q); */ \
+    /* mlk_assert_abs_bound_int16m1(ut, vl, MLKEM_Q); */               \
   }

-#define MLK_RVV_BFLY_FV(u0, u1, ut, uc, vl)   \
-  {                                           \
-    ut = fq_mul_vv(u1, uc, vl);               \
-    u1 = __riscv_vsub_vv_i16m1(u0, ut, vl);   \
-    u0 = __riscv_vadd_vv_i16m1(u0, ut, vl);   \
+#define MLK_RVV_BFLY_FV(u0, u1, ut, uc, vl, layer)                     \
+  {                                                                    \
+    mlk_assert_abs_bound_int16m1(uc, vl, MLKEM_Q_HALF);                \
+                                                                       \
+    ut = fq_mul_vv(u1, uc, vl);                                        \
+    u1 = __riscv_vsub_vv_i16m1(u0, ut, vl);                            \
+    u0 = __riscv_vadd_vv_i16m1(u0, ut, vl);                            \
+                                                                       \
+    /* mlk_assert_abs_bound_int16m1(ut, vl, MLKEM_Q); */               \
+    /* mlk_assert_abs_bound_int16m1(u0, vl, (layer + 1) * MLKEM_Q); */ \
+    /* mlk_assert_abs_bound_int16m1(u1, vl, (layer + 1) * MLKEM_Q); */ \
   }

 static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz)
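The new layer argument appears to exist for the commented-out bounds: each Cooley-Tukey butterfly grows the coefficient magnitude by at most one Q, because the Montgomery product ut stays below Q in absolute value. A scalar sketch of the butterfly and the bound argument follows, assuming input coefficients bounded in magnitude by Q; fq_mul_ref is a hypothetical stand-in for the Montgomery multiply behind fq_mul_vx/fq_mul_vv, not code from this file.

#include <stdint.h>

/* Hypothetical scalar Montgomery multiply: a * b * 2^-16 mod 3329, with
 * |result| < 3329 whenever |b| <= (3329 - 1) / 2 (the zeta range). */
static int16_t fq_mul_ref(int16_t a, int16_t b)
{
  int32_t p = (int32_t)a * b;
  int16_t m = (int16_t)((int32_t)(int16_t)p * 3327); /* p * -Q^-1 mod 2^16 */
  return (int16_t)((p + (int32_t)m * 3329) >> 16);   /* (p + m*Q) / 2^16, exact */
}

/* Cooley-Tukey butterfly: if |u0|, |u1| <= layer * Q on entry, then
 * |u0|, |u1| <= (layer + 1) * Q on exit, which is the bound the commented-out
 * assertions in MLK_RVV_BFLY_FX/FV would check. */
static void ct_butterfly_ref(int16_t *u0, int16_t *u1, int16_t zeta)
{
  int16_t t = fq_mul_ref(*u1, zeta); /* |t| < Q */
  *u1 = (int16_t)(*u0 - t);          /* |u1| grows by at most Q */
  *u0 = (int16_t)(*u0 + t);          /* |u0| grows by at most Q */
}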
@@ -185,7 +205,7 @@ static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz)
   t1 = __riscv_vget_v_i16m2_i16m1(vp, 1);

   c0 = __riscv_vrgather_vv_i16m1(cz, cs8, vl);
-  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl);
+  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 5);

   /* swap 4 */
   vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1);
@@ -194,7 +214,7 @@ static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz)
   t1 = __riscv_vget_v_i16m2_i16m1(vp, 1);

   c0 = __riscv_vrgather_vv_i16m1(cz, cs4, vl);
-  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl);
+  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 6);

   /* swap 2 */
   vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1);
@@ -203,7 +223,7 @@ static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz)
   t1 = __riscv_vget_v_i16m2_i16m1(vp, 1);

   c0 = __riscv_vrgather_vv_i16m1(cz, cs2, vl);
-  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl);
+  MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 7);

   /* normalize */
   t0 = fq_mulq_vx(t0, MLK_RVV_MONT_R1, vl);
@@ -253,41 +273,41 @@ void mlk_rv64v_poly_ntt(int16_t *r)
   ve = __riscv_vle16_v_i16m1(&r[0xe0], vl);
   vf = __riscv_vle16_v_i16m1(&r[0xf0], vl);

-  MLK_RVV_BFLY_FX(v0, v8, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v1, v9, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v2, va, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v3, vb, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v4, vc, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v5, vd, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v6, ve, vt, zeta[0x01], vl);
-  MLK_RVV_BFLY_FX(v7, vf, vt, zeta[0x01], vl);
-
-  MLK_RVV_BFLY_FX(v0, v4, vt, zeta[0x10], vl);
-  MLK_RVV_BFLY_FX(v1, v5, vt, zeta[0x10], vl);
-  MLK_RVV_BFLY_FX(v2, v6, vt, zeta[0x10], vl);
-  MLK_RVV_BFLY_FX(v3, v7, vt, zeta[0x10], vl);
-  MLK_RVV_BFLY_FX(v8, vc, vt, zeta[0x11], vl);
-  MLK_RVV_BFLY_FX(v9, vd, vt, zeta[0x11], vl);
-  MLK_RVV_BFLY_FX(va, ve, vt, zeta[0x11], vl);
-  MLK_RVV_BFLY_FX(vb, vf, vt, zeta[0x11], vl);
-
-  MLK_RVV_BFLY_FX(v0, v2, vt, zeta[0x20], vl);
-  MLK_RVV_BFLY_FX(v1, v3, vt, zeta[0x20], vl);
-  MLK_RVV_BFLY_FX(v4, v6, vt, zeta[0x21], vl);
-  MLK_RVV_BFLY_FX(v5, v7, vt, zeta[0x21], vl);
-  MLK_RVV_BFLY_FX(v8, va, vt, zeta[0x30], vl);
-  MLK_RVV_BFLY_FX(v9, vb, vt, zeta[0x30], vl);
-  MLK_RVV_BFLY_FX(vc, ve, vt, zeta[0x31], vl);
-  MLK_RVV_BFLY_FX(vd, vf, vt, zeta[0x31], vl);
-
-  MLK_RVV_BFLY_FX(v0, v1, vt, zeta[0x40], vl);
-  MLK_RVV_BFLY_FX(v2, v3, vt, zeta[0x41], vl);
-  MLK_RVV_BFLY_FX(v4, v5, vt, zeta[0x50], vl);
-  MLK_RVV_BFLY_FX(v6, v7, vt, zeta[0x51], vl);
-  MLK_RVV_BFLY_FX(v8, v9, vt, zeta[0x60], vl);
-  MLK_RVV_BFLY_FX(va, vb, vt, zeta[0x61], vl);
-  MLK_RVV_BFLY_FX(vc, vd, vt, zeta[0x70], vl);
-  MLK_RVV_BFLY_FX(ve, vf, vt, zeta[0x71], vl);
+  MLK_RVV_BFLY_FX(v0, v8, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v1, v9, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v2, va, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v3, vb, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v4, vc, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v5, vd, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v6, ve, vt, zeta[0x01], vl, 1);
+  MLK_RVV_BFLY_FX(v7, vf, vt, zeta[0x01], vl, 1);
+
+  MLK_RVV_BFLY_FX(v0, v4, vt, zeta[0x10], vl, 2);
+  MLK_RVV_BFLY_FX(v1, v5, vt, zeta[0x10], vl, 2);
+  MLK_RVV_BFLY_FX(v2, v6, vt, zeta[0x10], vl, 2);
+  MLK_RVV_BFLY_FX(v3, v7, vt, zeta[0x10], vl, 2);
+  MLK_RVV_BFLY_FX(v8, vc, vt, zeta[0x11], vl, 2);
+  MLK_RVV_BFLY_FX(v9, vd, vt, zeta[0x11], vl, 2);
+  MLK_RVV_BFLY_FX(va, ve, vt, zeta[0x11], vl, 2);
+  MLK_RVV_BFLY_FX(vb, vf, vt, zeta[0x11], vl, 2);
+
+  MLK_RVV_BFLY_FX(v0, v2, vt, zeta[0x20], vl, 3);
+  MLK_RVV_BFLY_FX(v1, v3, vt, zeta[0x20], vl, 3);
+  MLK_RVV_BFLY_FX(v4, v6, vt, zeta[0x21], vl, 3);
+  MLK_RVV_BFLY_FX(v5, v7, vt, zeta[0x21], vl, 3);
+  MLK_RVV_BFLY_FX(v8, va, vt, zeta[0x30], vl, 3);
+  MLK_RVV_BFLY_FX(v9, vb, vt, zeta[0x30], vl, 3);
+  MLK_RVV_BFLY_FX(vc, ve, vt, zeta[0x31], vl, 3);
+  MLK_RVV_BFLY_FX(vd, vf, vt, zeta[0x31], vl, 3);
+
+  MLK_RVV_BFLY_FX(v0, v1, vt, zeta[0x40], vl, 4);
+  MLK_RVV_BFLY_FX(v2, v3, vt, zeta[0x41], vl, 4);
+  MLK_RVV_BFLY_FX(v4, v5, vt, zeta[0x50], vl, 4);
+  MLK_RVV_BFLY_FX(v6, v7, vt, zeta[0x51], vl, 4);
+  MLK_RVV_BFLY_FX(v8, v9, vt, zeta[0x60], vl, 4);
+  MLK_RVV_BFLY_FX(va, vb, vt, zeta[0x61], vl, 4);
+  MLK_RVV_BFLY_FX(vc, vd, vt, zeta[0x70], vl, 4);
+  MLK_RVV_BFLY_FX(ve, vf, vt, zeta[0x71], vl, 4);

   __riscv_vse16_v_i16m2(
       &r[0x00], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(v0, v1), z0), vl2);
@@ -614,9 +634,9 @@ void mlk_rv64v_poly_tomont(int16_t *r)

   for (size_t i = 0; i < MLKEM_N; i += vl)
   {
-    __riscv_vse16_v_i16m1(
-        &r[i], fq_mul_vx(__riscv_vle16_v_i16m1(&r[i], vl), MLK_RVV_MONT_R2, vl),
-        vl);
+    vint16m1_t vec = __riscv_vle16_v_i16m1(&r[i], vl);
+    vec = fq_mul_vx(vec, MLK_RVV_MONT_R2, vl);
+    __riscv_vse16_v_i16m1(&r[i], vec, vl);
   }
 }

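fq_mul_vx performs a Montgomery multiplication, so multiplying by R^2 mod Q and dividing by R leaves each coefficient as a*R mod Q, i.e. in the Montgomery domain. A scalar sketch of the per-coefficient step follows, assuming R = 2^16 and MLK_RVV_MONT_R2 == R^2 mod Q == 1353; both are assumptions for illustration, not read from this file.

#include <stdint.h>

/* tomont_ref(a) returns a representative of a * 2^16 mod 3329: multiply by
 * R^2 mod Q, then Montgomery-reduce (divide by R). */
static int16_t tomont_ref(int16_t a)
{
  int32_t p = (int32_t)a * 1353;                     /* a * (R^2 mod Q) */
  int16_t m = (int16_t)((int32_t)(int16_t)p * 3327); /* p * -Q^-1 mod 2^16 */
  return (int16_t)((p + (int32_t)m * 3329) >> 16);   /* a * R mod Q */
}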
@@ -638,8 +658,6 @@ void mlk_rv64v_poly_reduce(int16_t *r)
   {
     vt = __riscv_vle16_v_i16m1(&r[i], vl);
     vt = fq_barrett(vt, vl);
-
-    /* make positive */
     vt = fq_cadd(vt, vl);
     __riscv_vse16_v_i16m1(&r[i], vt, vl);
   }