@@ -1644,56 +1644,62 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 ggml_tensor * llm_graph_context::build_rs(
          ggml_tensor * s,
-         ggml_tensor * state_copy,
+         ggml_tensor * state_copy_main,
+         ggml_tensor * state_copy_extra,
              int32_t   state_size,
              int32_t   n_seqs,
-            uint32_t   n_kv,
-            uint32_t   kv_head,
-            uint32_t   kv_size,
+            uint32_t   n_rs,
+            uint32_t   rs_head,
+            uint32_t   rs_size,
              int32_t   rs_zero,
     const llm_graph_get_rows_fn & get_state_rows) const {

-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);

     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
     ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

     // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // {state_size, kv_size} -> {state_size, n_seqs}
-    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+    // {state_size, rs_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
     ggml_build_forward_expand(gf, output_states);

-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy extra states which won't be changed further (between n_seqs and n_rs)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));

     return output_states;
 }
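The clearing step above is branch-free on purpose: when rs_zero is negative there is no state to clear, and both the view's element count, state_size*(rs_zero >= 0), and its byte offset collapse to zero, so the ggml_scale_inplace becomes a no-op on a zero-sized view instead of a conditional in the graph. A quick worked check of the arithmetic (values illustrative):

    // rs_zero == 3:  length = state_size*1, offset = 3*states->nb[1]     -> row 3 is zeroed
    // rs_zero == -1: length = state_size*0, offset = -1*states->nb[1]*0  -> zero-sized view, no-op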

 static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
           ggml_context * ctx0,
+    const llama_ubatch & ubatch,
     const llama_memory_recurrent_context * mctx_cur) {

     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);

-    const auto n_rs = mctx_cur->get_n_rs();
+    const int64_t n_rs   = mctx_cur->get_n_rs();
+    const int64_t n_seqs = ubatch.n_seqs;

     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);

+    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
+
     return inp;
 }

 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

-    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);

     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
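Net effect of this hunk: the cache-flavored names (n_kv, kv_head, kv_size) become recurrent-state names (n_rs, rs_head, rs_size), and the two ggml_view_1d calls that build_rs used to construct inline on every invocation are now built once in build_rs_inp_impl and passed in as s_copy_main (the first n_seqs indices, i.e. the states this ubatch updates) and s_copy_extra (the remaining n_rs - n_seqs indices, copied through unchanged). For orientation, a hedged sketch of how the I32 index buffer might be populated at set-input time; the real fill code lives in llm_graph_input_rs::set_input outside this diff, and source_row_for is a hypothetical helper:

    // Sketch (assumed semantics): entry i of s_copy names the source row whose
    // state should land in row rs_head + i of the recurrent-state buffer.
    int32_t * data = (int32_t *) inp->s_copy->data;
    for (int64_t i = 0; i < n_rs; ++i) {
        // i <  n_seqs: gathered via s_copy_main and updated this ubatch
        // i >= n_seqs: gathered via s_copy_extra and written back verbatim
        data[i] = source_row_for(i);
    }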
@@ -1706,7 +1712,9 @@ ggml_tensor * llm_graph_context::build_rs(
     const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;

-    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+            kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+            get_state_rows);
 }

 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
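Callers that go through this overload are untouched by the split, since the views travel inside the input struct. A hedged usage sketch from a model graph's point of view (states_all and state_size are illustrative stand-ins, and get_state_rows is assumed to keep a default of ggml_get_rows):

    // Sketch: gather the per-sequence recurrent states for this ubatch.
    // inp comes from build_rs_inp(); states_all is the memory's state tensor.
    ggml_tensor * states = build_rs(inp, states_all, state_size, n_seqs);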
@@ -1753,7 +1761,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);

-    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());

     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
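The hybrid path only needs to thread ubatch through so the recurrent half gets the same main/extra split. The matching member additions are presumably in the header (llama-graph.h), which this commit's hunks don't show; a hedged sketch of what they would look like:

    // Assumed shape of the llm_graph_input_rs members used above
    // (constructor, set_input, mctx omitted; not taken from this diff):
    struct llm_graph_input_rs : public llm_graph_input_i {
        ggml_tensor * s_copy;        // I32 [n_rs] source row indices (model input)
        ggml_tensor * s_copy_main;   // view: first n_seqs entries
        ggml_tensor * s_copy_extra;  // view: remaining n_rs - n_seqs entries
    };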