Skip to content

Commit bb103a3

Browse files
authored
[Inference] Support eagle for llama (#9812)
1 parent bb0c9ad commit bb103a3

14 files changed

+1710
-49
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/extension.h"
16+
17+
18+
// Recomputes base_model_seq_lens_this_time from the draft-token buffer.
//
// Per batch slot:
//   - active decode request (stop flag clear, encoder length == 0): count the
//     valid entries (!= -1) in its row of base_model_draft_tokens and store
//     the count as the base model's seq_len for this step;
//   - stopped request: force the length to 0;
//   - a request still in the encoder phase is left untouched.
//
// Fix vs. original: the one-thread-per-slot mapping (`tid = threadIdx.x`)
// silently skipped slots with tid >= blockDim.x when bsz exceeded the fixed
// block size; the stride loop covers any bsz with any launch configuration.
__global__ void draft_model_update_seq_lens_this_time_kernel(
    const int64_t* base_model_draft_tokens,
    int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const bool* base_model_stop_flags,
    int bsz,
    int base_model_draft_token_len) {
  for (int tid = threadIdx.x; tid < bsz; tid += blockDim.x) {
    if (!base_model_stop_flags[tid] && base_model_seq_lens_encoder[tid] == 0) {
      const int64_t* base_model_draft_tokens_now =
          base_model_draft_tokens + tid * base_model_draft_token_len;
      int token_num = 0;
      // -1 marks an unused slot in the draft-token row.
      for (int i = 0; i < base_model_draft_token_len; ++i) {
        if (base_model_draft_tokens_now[i] != -1) {
          token_num++;
        }
      }
      base_model_seq_lens_this_time[tid] = token_num;
    } else if (base_model_stop_flags[tid]) {
      base_model_seq_lens_this_time[tid] = 0;
    }
  }
}
43+
44+
45+
// Host launcher: refreshes base_model_seq_lens_this_time in place from the
// draft-token buffer and the per-request stop flags. Runs on the stream of
// base_model_seq_lens_this_time; single block of kBlockSize threads.
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
                           const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_stop_flags) {
  constexpr int kBlockSize = 512;
  const int real_bsz = base_model_seq_lens_this_time.shape()[0];
  const int draft_token_len = base_model_draft_tokens.shape()[1];
  auto stream = base_model_seq_lens_this_time.stream();
  draft_model_update_seq_lens_this_time_kernel<<<1, kBlockSize, 0, stream>>>(
      base_model_draft_tokens.data<int64_t>(),
      // Written in place; Paddle hands the input tensor out as const.
      const_cast<int*>(base_model_seq_lens_this_time.data<int>()),
      base_model_seq_lens_encoder.data<int>(),
      base_model_stop_flags.data<bool>(),
      real_bsz,
      draft_token_len);
}
61+
62+
63+
// Registers the `draft_model_postprocess` custom op with Paddle.
// All three outputs alias their inputs via the in-place map, so the op
// mutates base_model_draft_tokens / base_model_seq_lens_this_time /
// base_model_stop_flags rather than allocating new tensors.
// base_model_seq_lens_encoder is read-only and has no _out counterpart.
PD_BUILD_OP(draft_model_postprocess)
    .Inputs({"base_model_draft_tokens",
             "base_model_seq_lens_this_time",
             "base_model_seq_lens_encoder",
             "base_model_stop_flags"})
    .Outputs({"base_model_draft_tokens_out",
              "base_model_seq_lens_this_time_out",
              "base_model_stop_flags_out"})
    .SetInplaceMap({{"base_model_draft_tokens", "base_model_draft_tokens_out"},
                    {"base_model_seq_lens_this_time",
                     "base_model_seq_lens_this_time_out"},
                    {"base_model_stop_flags", "base_model_stop_flags_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPostprocess));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "helper.h"
16+
#include "paddle/extension.h"
17+
18+
// Prepares the draft model's per-request state from the base model's
// verification results for one speculative-decoding step.
//
// One thread handles one batch slot (tid < bsz); the whole block then
// participates in a CUB BlockReduce to decide whether any request is still
// running, publishing the result to not_need_stop[0].
//
// NOTE(review): correctness assumes bsz <= THREADBLOCK_SIZE (the host
// launches a single block) — slots beyond blockDim.x would be skipped;
// confirm against callers. A stride-loop rewrite is not applied here because
// BlockReduce requires full-block participation.
//
// EAGLE (compile-time flag) switches the first-token write position and the
// resulting seq_lens_this_time, matching EAGLE's shifted-input convention
// (position - 1) versus the default draft model (position).
template <int THREADBLOCK_SIZE, bool EAGLE>
__global__ void draft_model_preprocess_kernel(
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    int* first_token_record,
    bool* not_need_stop,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  // 1 iff this thread's request keeps generating after this step.
  int64_t not_stop_flag = 0;

  int tid = threadIdx.x;

  if (tid < bsz) {
    auto base_model_step_idx_now = base_model_step_idx[tid];
    auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
    auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
    auto accept_num_now = accept_num[tid];
    auto* input_ids_now = input_ids + tid * input_ids_len;
    auto* base_model_draft_tokens_now =
        base_model_draft_tokens + tid * base_model_draft_tokens_len;
    // Clear the base model's draft-token row for the next round; slot 0 is
    // deliberately preserved (it holds the row's anchor token).
#pragma unroll
    for (int i = 1; i < base_model_draft_tokens_len; i++) {
      base_model_draft_tokens_now[i] = -1;
    }

    if (!base_model_stop_flags[tid]) {
      not_stop_flag = 1;
      // 1. first token
      if (base_model_step_idx_now == 0) {
        // Base model has not produced anything yet: nothing to draft.
        seq_lens_this_time[tid] = 0;
        not_stop_flag = 0;
      } else if (base_model_step_idx_now == 1 && first_token_record[tid] > 0) {
        // Can be extended to first few tokens
        // First base-model token just arrived: start the draft request.
        seq_lens_encoder[tid] = first_token_record[tid];
        first_token_record[tid] = -1;  // consume the one-shot record
        stop_flags[tid] = false;
        int64_t base_model_first_token = accept_tokens_now[0];
        int position = base_model_seq_lens_decoder[tid];
        if (EAGLE) {
          // EAGLE consumes inputs shifted by one position.
          input_ids_now[position - 1] = base_model_first_token;
          seq_lens_this_time[tid] = base_model_seq_lens_decoder[tid];
        } else {
          input_ids_now[position] = base_model_first_token;
          seq_lens_this_time[tid] = base_model_seq_lens_decoder[tid] + 1;
        }
      } else if (accept_num_now <=
                 max_draft_token) /*Accept partial draft tokens*/ {
        // Base Model reject stop
        if (stop_flags[tid]) {
          // Draft model had stopped but the base model kept going:
          // resynchronize draft state from the base model's counters.
          stop_flags[tid] = false;
          seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
          step_idx[tid] = base_model_step_idx[tid];
        } else {
          // Roll back the draft tokens the base model rejected.
          seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
          step_idx[tid] -= max_draft_token - accept_num_now;
        }
        // Restart drafting from the last accepted token.
        int64_t modified_token = accept_tokens_now[accept_num_now - 1];
        draft_tokens_now[0] = modified_token;
        seq_lens_this_time[tid] = 1;

      } else /*Accept all draft tokens*/ {
        // Everything accepted plus one bonus token: feed two tokens.
        draft_tokens_now[1] = accept_tokens_now[max_draft_token];
        seq_lens_this_time[tid] = 2;
      }
    } else {
      // Base model stopped this request: stop the draft side too.
      stop_flags[tid] = true;
      seq_lens_this_time[tid] = 0;
      seq_lens_decoder[tid] = 0;
    }
  }
  __syncthreads();
  // Block-wide vote: any request still running? All threads (including
  // tid >= bsz, contributing 0) must reach the reduce.
  int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
  if (tid == 0) {
    not_need_stop[0] = not_stop_flag_sum > 0;
  }
}
112+
113+
114+
// Performs the single kernel launch with the EAGLE flag resolved at compile
// time. Factored out so the 20-argument launch exists exactly once instead of
// being duplicated in both branches of the draft_type dispatch (the original
// copies could silently drift apart).
template <bool EAGLE>
static void DispatchDraftModelPreprocessKernel(
    const paddle::Tensor& draft_tokens,
    const paddle::Tensor& input_ids,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& step_idx,
    const paddle::Tensor& first_token_record,
    const paddle::Tensor& not_need_stop_gpu,
    const paddle::Tensor& accept_tokens,
    const paddle::Tensor& accept_num,
    const paddle::Tensor& base_model_seq_lens_encoder,
    const paddle::Tensor& base_model_seq_lens_decoder,
    const paddle::Tensor& base_model_step_idx,
    const paddle::Tensor& base_model_stop_flags,
    const paddle::Tensor& base_model_draft_tokens,
    const int max_draft_token) {
  constexpr int BlockSize = 256;
  int real_bsz = seq_lens_this_time.shape()[0];
  int accept_tokens_len = accept_tokens.shape()[1];
  int input_ids_len = input_ids.shape()[1];
  int draft_tokens_len = draft_tokens.shape()[1];
  int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
  auto cu_stream = seq_lens_this_time.stream();
  draft_model_preprocess_kernel<BlockSize, EAGLE>
      <<<1, BlockSize, 0, cu_stream>>>(
          const_cast<int64_t*>(draft_tokens.data<int64_t>()),
          const_cast<int64_t*>(input_ids.data<int64_t>()),
          const_cast<bool*>(stop_flags.data<bool>()),
          const_cast<int*>(seq_lens_this_time.data<int>()),
          const_cast<int*>(seq_lens_encoder.data<int>()),
          const_cast<int*>(seq_lens_decoder.data<int>()),
          const_cast<int64_t*>(step_idx.data<int64_t>()),
          const_cast<int*>(first_token_record.data<int>()),
          const_cast<bool*>(not_need_stop_gpu.data<bool>()),
          accept_tokens.data<int64_t>(),
          accept_num.data<int>(),
          base_model_seq_lens_encoder.data<int>(),
          base_model_seq_lens_decoder.data<int>(),
          base_model_step_idx.data<int64_t>(),
          base_model_stop_flags.data<bool>(),
          const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
          real_bsz,
          max_draft_token,
          accept_tokens_len,
          draft_tokens_len,
          input_ids_len,
          base_model_draft_tokens_len);
}

// Host entry point: prepares the draft model's per-request state from the
// base model's verification results (see draft_model_preprocess_kernel), then
// refreshes the host-side not_need_stop scalar.
//
// draft_type selects the EAGLE variant ("eagle") versus the default draft
// model; every other tensor argument is updated in place on the stream of
// seq_lens_this_time. not_need_stop lives on the host: it is staged through a
// device copy for the kernel and copied back afterwards.
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& input_ids,
                          const paddle::Tensor& stop_flags,
                          const paddle::Tensor& seq_lens_this_time,
                          const paddle::Tensor& seq_lens_encoder,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& first_token_record,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const std::string& draft_type) {
  // Device staging buffer for the host-resident not_need_stop flag.
  auto not_need_stop_gpu =
      not_need_stop.copy_to(seq_lens_this_time.place(), false);

  if (draft_type == "eagle") {
    DispatchDraftModelPreprocessKernel<true>(draft_tokens,
                                             input_ids,
                                             stop_flags,
                                             seq_lens_this_time,
                                             seq_lens_encoder,
                                             seq_lens_decoder,
                                             step_idx,
                                             first_token_record,
                                             not_need_stop_gpu,
                                             accept_tokens,
                                             accept_num,
                                             base_model_seq_lens_encoder,
                                             base_model_seq_lens_decoder,
                                             base_model_step_idx,
                                             base_model_stop_flags,
                                             base_model_draft_tokens,
                                             max_draft_token);
  } else {
    DispatchDraftModelPreprocessKernel<false>(draft_tokens,
                                              input_ids,
                                              stop_flags,
                                              seq_lens_this_time,
                                              seq_lens_encoder,
                                              seq_lens_decoder,
                                              step_idx,
                                              first_token_record,
                                              not_need_stop_gpu,
                                              accept_tokens,
                                              accept_num,
                                              base_model_seq_lens_encoder,
                                              base_model_seq_lens_decoder,
                                              base_model_step_idx,
                                              base_model_stop_flags,
                                              base_model_draft_tokens,
                                              max_draft_token);
  }

  // Copy the kernel's verdict back to the host tensor the framework reads.
  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
201+
202+
203+
// Registers the `draft_model_preprocess` custom op with Paddle.
// Nine of the inputs are updated in place (see the in-place map); the
// base_model_* tensors are read-only except base_model_draft_tokens, which
// the kernel clears but which intentionally has no _out alias here.
// Attrs: max_draft_token (int) and draft_type (std::string, "eagle" selects
// the EAGLE kernel variant).
PD_BUILD_OP(draft_model_preprocess)
    .Inputs({"draft_tokens",
             "input_ids",
             "stop_flags",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "first_token_record",
             "not_need_stop",
             "accept_tokens",
             "accept_num",
             "base_model_seq_lens_encoder",
             "base_model_seq_lens_decoder",
             "base_model_step_idx",
             "base_model_stop_flags",
             "base_model_draft_tokens"})
    .Outputs({"draft_tokens_out",
              "input_ids_out",
              "stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "step_idx_out",
              "not_need_stop_out",
              "first_token_record_out"})
    .Attrs({"max_draft_token: int", "draft_type: std::string"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"input_ids", "input_ids_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"step_idx", "step_idx_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"first_token_record", "first_token_record_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPreprocess));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "helper.h"
16+
17+
18+
// Records the tokens produced this step into pre_ids_all (each request's
// running history of generated ids) and resets seq_lens_this_time to 1 for
// the next step.
//
// For an active request (seq_lens_this_time != 0, stop flag clear):
//   - decode step (step_idx - 1 > 0): the last seq_lens_this_time tokens of
//     the request's draft_tokens row are written backwards so that
//     pre_ids_all[step_idx] holds the newest token;
//   - encoder step (step_idx == 1): only the first token is recorded at
//     position 1.
//
// Fix vs. original: the one-thread-per-request mapping (`tid = threadIdx.x`)
// silently skipped requests with tid >= blockDim.x; the stride loop covers
// any bs under any launch configuration without changing per-slot behavior.
__global__ void update_pre_ids_kernel(const int64_t* draft_tokens,
                                      int64_t* pre_ids_all,
                                      const bool* stop_flags,
                                      int* seq_lens_this_time,
                                      const int64_t* step_idx,
                                      int bs,
                                      int pre_id_length,
                                      int max_draft_token) {
  for (int tid = threadIdx.x; tid < bs; tid += blockDim.x) {
    if (seq_lens_this_time[tid] != 0 && !stop_flags[tid]) {
      int64_t* pre_ids_all_now = pre_ids_all + tid * pre_id_length;
      const int64_t* draft_token_now = draft_tokens + tid * max_draft_token;
      const int seq_len_this_time = seq_lens_this_time[tid];
      if (step_idx[tid] - 1 > 0 /*Decoder Step*/) {
        // Newest token lands at pre_ids_all[step_idx], older ones before it.
        for (int i = 0; i < seq_len_this_time; ++i) {
          pre_ids_all_now[step_idx[tid] - i] =
              draft_token_now[seq_len_this_time - 1 - i];
        }
      } else if (step_idx[tid] == 1 /*Encoder Step*/) {
        pre_ids_all_now[1] = draft_token_now[0];
      }
      // Next step always consumes exactly one token.
      seq_lens_this_time[tid] = 1;
    }
  }
}
42+
43+
44+
// Host launcher for update_pre_ids_kernel: appends this step's tokens to the
// per-request pre-id history (in place) and resets seq_lens_this_time.
// Runs on the stream of seq_lens_this_time with a single block whose thread
// count is the batch size rounded up to a warp multiple.
void SpeculateDraftModelUpdate(const paddle::Tensor& draft_tokens,
                               const paddle::Tensor& pre_ids_all,
                               const paddle::Tensor& stop_flags,
                               const paddle::Tensor& seq_lens_this_time,
                               const paddle::Tensor& seq_lens_encoder,
                               const paddle::Tensor& seq_lens_decoder,
                               const paddle::Tensor& step_idx) {
  const int64_t batch = seq_lens_this_time.shape()[0];
  const int64_t pre_ids_width = pre_ids_all.shape()[1];
  const int64_t draft_width = draft_tokens.shape()[1];
  auto stream = seq_lens_this_time.stream();

  // Round the thread count up to the next multiple of a warp (32).
  const int threads = (batch + 32 - 1) / 32 * 32;
  update_pre_ids_kernel<<<1, threads, 0, stream>>>(
      draft_tokens.data<int64_t>(),
      // Mutated in place; Paddle exposes the input as const.
      const_cast<int64_t*>(pre_ids_all.data<int64_t>()),
      stop_flags.data<bool>(),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      step_idx.data<int64_t>(),
      batch,
      pre_ids_width,
      draft_width);
}
67+
68+
// Registers the `draft_model_set_value_by_flags` custom op with Paddle.
// Only pre_ids_all is declared as mutated in place; note that the kernel also
// writes seq_lens_this_time even though it has no _out alias here.
// seq_lens_encoder / seq_lens_decoder are listed as inputs but unused by the
// kernel function — presumably kept for interface symmetry; TODO confirm.
PD_BUILD_OP(draft_model_set_value_by_flags)
    .Inputs({"draft_tokens",
             "pre_ids_all",
             "stop_flags",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx"})
    .Outputs({"pre_ids_all_out"})
    .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateDraftModelUpdate));

0 commit comments

Comments
 (0)