From ac1497945743c3d508d3c5f754238823acaf0fe9 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Tue, 22 Jul 2025 21:38:23 -0400
Subject: [PATCH 1/2] graph : avoid creating redundant s_copy views

---
 src/llama-graph.cpp | 38 +++++++++++++++++++++++---------------
 src/llama-graph.h   | 16 ++++++++++------
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b63a41053b488..4873cb4f6103a 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1561,16 +1561,17 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
 ggml_tensor * llm_graph_context::build_rs(
          ggml_tensor * s,
-         ggml_tensor * state_copy,
+         ggml_tensor * state_copy_main,
+         ggml_tensor * state_copy_extra,
              int32_t   state_size,
              int32_t   n_seqs,
-            uint32_t   n_kv,
-            uint32_t   kv_head,
-            uint32_t   kv_size,
+            uint32_t   n_rs,
+            uint32_t   rs_head,
+            uint32_t   rs_size,
              int32_t   rs_zero,
     const llm_graph_get_rows_fn & get_state_rows) const {
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1578,39 +1579,44 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
     // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // {state_size, kv_size} -> {state_size, n_seqs}
-    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+    // {state_size, rs_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
     ggml_build_forward_expand(gf, output_states);
 
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy extra states which won't be changed further (between n_seqs and n_rs)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
            states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
 
     return output_states;
 }
 
 static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
         ggml_context * ctx0,
+        const llama_ubatch & ubatch,
         const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = mctx_cur->get_n_rs();
+    const int64_t n_rs   = mctx_cur->get_n_rs();
+    const int64_t n_seqs = ubatch.n_seqs;
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
+
     return inp;
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
@@ -1623,7 +1629,9 @@ ggml_tensor * llm_graph_context::build_rs(
     const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+                    kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+                    get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1670,7 +1678,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index a28a8c4bddad8..174eb5ed4747d 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -214,7 +214,11 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * s_copy; // I32 [kv_size]
+    ggml_tensor * s_copy; // I32 [n_rs]
+
+    // views
+    ggml_tensor * s_copy_main;  // I32 [n_seqs]
+    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
 };
@@ -715,7 +719,6 @@ struct llm_graph_context {
     // recurrent
     //
 
-    // TODO: avoid notion of "kv"
     // TODO: move this implementation to llama_memory_recurrent.
     //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
@@ -723,12 +726,13 @@
     //       `llama_memory_recurrent`
     ggml_tensor * build_rs(
              ggml_tensor * s,
-             ggml_tensor * state_copy,
+             ggml_tensor * state_copy_main,
+             ggml_tensor * state_copy_extra,
                  int32_t   state_size,
                  int32_t   n_seqs,
-                uint32_t   n_kv,
-                uint32_t   kv_head,
-                uint32_t   kv_size,
+                uint32_t   n_rs,
+                uint32_t   rs_head,
+                uint32_t   rs_size,
                  int32_t   rs_zero,
         const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 

From 7b660a917f1a494bf0a2e2d4f5940307c5f8643a Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Wed, 23 Jul 2025 20:27:29 -0400
Subject: [PATCH 2/2] graph : comment the s_copy views

---
 src/llama-graph.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/llama-graph.h b/src/llama-graph.h
index 174eb5ed4747d..b0ecaf7df89cb 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -216,7 +216,8 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     ggml_tensor * s_copy; // I32 [n_rs]
 
-    // views
+    // views of s_copy, computed once per graph
+    // and shared across layers which use build_rs
     ggml_tensor * s_copy_main;  // I32 [n_seqs]
     ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
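
Standalone sketch of the pattern PATCH 1/2 introduces (not part of the patch series; the context size, the example n_rs/n_seqs values, and the main() scaffolding are invented for illustration, and only ggml calls that already appear in the patch are used): allocate the s_copy input once, carve out the s_copy_main/s_copy_extra views a single time next to it, and let every per-layer build_rs() call reuse those views instead of recreating the same ggml_view_1d nodes.

// sketch.cpp -- minimal illustration of "create the s_copy views once, reuse per layer"
#include "ggml.h"

#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // plenty for a few small tensors
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    const int64_t n_rs   = 8; // example: recurrent states tracked for this ubatch
    const int64_t n_seqs = 3; // example: sequences actually present in the ubatch

    // one I32 input tensor holding the copy-source indices ...
    struct ggml_tensor * s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
    ggml_set_input(s_copy);

    // ... and its two views, built a single time instead of once per layer:
    struct ggml_tensor * s_copy_main  = ggml_view_1d(ctx0, s_copy, n_seqs, 0);
    struct ggml_tensor * s_copy_extra = ggml_view_1d(ctx0, s_copy, n_rs - n_seqs, n_seqs*s_copy->nb[0]);

    // each build_rs() call in the graph can now be handed s_copy_main/s_copy_extra
    // directly (get_state_rows(ctx0, states, state_copy_main) in the patch) rather
    // than rebuilding the same ggml_view_1d nodes for every layer.
    printf("s_copy_main: %lld rows, s_copy_extra: %lld rows\n",
           (long long) ggml_nelements(s_copy_main),
           (long long) ggml_nelements(s_copy_extra));

    ggml_free(ctx0);
    return 0;
}

The exact numbers do not matter; the point is that both views share the storage of s_copy and are created once per graph, which is what the "views of s_copy, computed once per graph" comment in PATCH 2/2 documents.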