@@ -4366,26 +4366,32 @@ struct test_flash_attn_ext : public test_case {
4366
4366
const int64_t hsk_padded = GGML_PAD (hsk, ggml_blck_size (type_KV));
4367
4367
const int64_t hsv_padded = GGML_PAD (hsv, ggml_blck_size (type_KV));
4368
4368
4369
- auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
4369
+ auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view ) -> ggml_tensor * {
4370
4370
int64_t ne[4 ] = {ne0, ne1, ne2, ne3};
4371
4371
int64_t ne_perm[4 ];
4372
4372
for (int i = 0 ; i < 4 ; ++i) {
4373
4373
ne_perm[permute[i]] = ne[i];
4374
4374
}
4375
- ggml_tensor * t = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4375
+ ggml_tensor * t;
4376
+ if (is_view) {
4377
+ ggml_tensor * t0 = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], 2 *ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4378
+ t = ggml_view_4d (ctx, t0, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ], t0->nb [1 ], t0->nb [2 ], t0->nb [3 ], 0 );
4379
+ } else {
4380
+ t = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4381
+ }
4376
4382
if (permute != std::array<int32_t , 4 >{0 , 1 , 2 , 3 }) {
4377
4383
t = ggml_permute (ctx, t, permute[0 ], permute[1 ], permute[2 ], permute[3 ]);
4378
4384
}
4379
4385
return t;
4380
4386
};
4381
4387
4382
- ggml_tensor * q = create_permuted (GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0 ], nr23[1 ]);
4388
+ ggml_tensor * q = create_permuted (GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0 ], nr23[1 ], false );
4383
4389
ggml_set_name (q, " q" );
4384
4390
4385
- ggml_tensor * k = create_permuted (type_KV, hsk_padded, kv, nh, nr23[1 ]);
4391
+ ggml_tensor * k = create_permuted (type_KV, hsk_padded, kv, nh, nr23[1 ], true ); // the K tensor is usually a view of the K cache
4386
4392
ggml_set_name (k, " k" );
4387
4393
4388
- ggml_tensor * v = create_permuted (type_KV, hsv_padded, kv, nh, nr23[1 ]);
4394
+ ggml_tensor * v = create_permuted (type_KV, hsv_padded, kv, nh, nr23[1 ], true ); // the V tensor is usually a view of the V cache
4389
4395
ggml_set_name (v, " v" );
4390
4396
4391
4397
ggml_tensor * m = nullptr ;
0 commit comments