Skip to content

Commit 18f3b5f

Browse files
committed
tests : add non-cont K,V FA tests
ggml-ci
1 parent 7233358 commit 18f3b5f

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

tests/test-backend-ops.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4366,26 +4366,32 @@ struct test_flash_attn_ext : public test_case {
43664366
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
43674367
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
43684368

4369-
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
4369+
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * {
43704370
int64_t ne[4] = {ne0, ne1, ne2, ne3};
43714371
int64_t ne_perm[4];
43724372
for (int i = 0; i < 4; ++i) {
43734373
ne_perm[permute[i]] = ne[i];
43744374
}
4375-
ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
4375+
ggml_tensor * t;
4376+
if (is_view) {
4377+
ggml_tensor * t0 = ggml_new_tensor_4d(ctx, type, ne_perm[0], 2*ne_perm[1], ne_perm[2], ne_perm[3]);
4378+
t = ggml_view_4d(ctx, t0, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3], t0->nb[1], t0->nb[2], t0->nb[3], 0);
4379+
} else {
4380+
t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
4381+
}
43764382
if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
43774383
t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
43784384
}
43794385
return t;
43804386
};
43814387

4382-
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]);
4388+
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false);
43834389
ggml_set_name(q, "q");
43844390

4385-
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]);
4391+
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
43864392
ggml_set_name(k, "k");
43874393

4388-
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]);
4394+
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
43894395
ggml_set_name(v, "v");
43904396

43914397
ggml_tensor * m = nullptr;

0 commit comments

Comments
 (0)