// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <type_traits>

#include "helper.h"
#include "paddle/extension.h"
// Rebuilds per-slot draft-model state after the base (target) model has
// verified a round of speculative tokens.
//
// One thread owns one batch slot; the kernel is launched as a SINGLE block,
// so it handles at most THREADBLOCK_SIZE slots (bsz must not exceed it).
// EAGLE toggles the EAGLE-style draft model, whose first-token bookkeeping
// is shifted by one position relative to the plain draft model.
//
// Writes:
//   - draft_tokens / input_ids / base_model_draft_tokens: token buffers,
//     row-major with the given per-row lengths.
//   - stop_flags, seq_lens_*, step_idx, first_token_record: per-slot state.
//   - not_need_stop[0]: true iff any slot is still active (block reduction).
template <int THREADBLOCK_SIZE, bool EAGLE>
__global__ void draft_model_preprocess_kernel(
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    int* first_token_record,
    bool* not_need_stop,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int max_draft_token,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len) {
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int64_t not_stop_flag = 0;

  int tid = threadIdx.x;

  if (tid < bsz) {
    auto base_model_step_idx_now = base_model_step_idx[tid];
    auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
    auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
    auto accept_num_now = accept_num[tid];
    auto* input_ids_now = input_ids + tid * input_ids_len;
    auto* base_model_draft_tokens_now =
        base_model_draft_tokens + tid * base_model_draft_tokens_len;
    // Reset the base model's draft-token slots to -1. Slot 0 is deliberately
    // left untouched (presumably it carries the last verified token --
    // TODO confirm against the consumer of base_model_draft_tokens).
#pragma unroll
    for (int i = 1; i < base_model_draft_tokens_len; i++) {
      base_model_draft_tokens_now[i] = -1;
    }

    if (!base_model_stop_flags[tid]) {
      not_stop_flag = 1;
      // 1. first token
      if (base_model_step_idx_now == 0) {
        // Base model has not emitted anything yet: nothing to draft this step.
        seq_lens_this_time[tid] = 0;
        not_stop_flag = 0;
      } else if (base_model_step_idx_now == 1 && first_token_record[tid] > 0) {
        // Can be extended to first few tokens
        seq_lens_encoder[tid] = first_token_record[tid];
        first_token_record[tid] = -1;  // consume the one-shot record
        stop_flags[tid] = false;
        int64_t base_model_first_token = accept_tokens_now[0];
        int position = base_model_seq_lens_decoder[tid];
        if (EAGLE) {
          // EAGLE feeds the first verified token one position earlier.
          input_ids_now[position - 1] = base_model_first_token;
          seq_lens_this_time[tid] = base_model_seq_lens_decoder[tid];
        } else {
          input_ids_now[position] = base_model_first_token;
          seq_lens_this_time[tid] = base_model_seq_lens_decoder[tid] + 1;
        }
      } else if (accept_num_now <=
                 max_draft_token) /* Accept partial draft tokens*/ {
        // Base Model reject stop
        if (stop_flags[tid]) {
          // Draft model had stopped but the base model kept going:
          // resynchronize decoder length and step index from the base model.
          stop_flags[tid] = false;
          seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
          step_idx[tid] = base_model_step_idx[tid];
        } else {
          // Roll state back by the number of rejected draft tokens.
          seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
          step_idx[tid] -= max_draft_token - accept_num_now;
        }
        int64_t modified_token = accept_tokens_now[accept_num_now - 1];
        draft_tokens_now[0] = modified_token;
        seq_lens_this_time[tid] = 1;

      } else /* Accept all draft tokens*/ {
        // Everything was accepted: also feed the base model's bonus token.
        draft_tokens_now[1] = accept_tokens_now[max_draft_token];
        seq_lens_this_time[tid] = 2;
      }
    } else {
      // Base model stopped this slot: mirror the stop on the draft side.
      stop_flags[tid] = true;
      seq_lens_this_time[tid] = 0;
      seq_lens_decoder[tid] = 0;
    }
  }
  __syncthreads();
  // Every thread in the block (including tid >= bsz, which contribute 0)
  // must participate in the reduction.
  int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
  if (tid == 0) {
    not_need_stop[0] = not_stop_flag_sum > 0;
  }
}
// Host launcher for draft_model_preprocess_kernel.
//
// Reads batch geometry from the input tensors, launches one single-block
// kernel instantiation selected by `draft_type` ("eagle" enables the EAGLE
// variant), then copies the reduced "any slot still running" flag back into
// the CPU-side `not_need_stop` tensor in place.
//
// All tensors except `not_need_stop` are updated in place on device (see the
// op's SetInplaceMap); `not_need_stop` lives on the host, hence the
// round-trip through a device copy.
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& input_ids,
                          const paddle::Tensor& stop_flags,
                          const paddle::Tensor& seq_lens_this_time,
                          const paddle::Tensor& seq_lens_encoder,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& first_token_record,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const std::string& draft_type) {
  int real_bsz = seq_lens_this_time.shape()[0];
  int accept_tokens_len = accept_tokens.shape()[1];
  int input_ids_len = input_ids.shape()[1];
  int draft_tokens_len = draft_tokens.shape()[1];
  auto cu_stream = seq_lens_this_time.stream();
  constexpr int BlockSize = 256;
  int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
  // The kernel runs as one block with one thread per batch slot; a larger
  // batch would be silently ignored by the kernel, so fail loudly instead.
  PD_CHECK(real_bsz <= BlockSize,
           "draft_model_preprocess only supports batch sizes up to 256");
  auto not_need_stop_gpu =
      not_need_stop.copy_to(seq_lens_this_time.place(), false);

  // The argument list is identical for both template instantiations; write
  // it once and dispatch on a compile-time bool.
  auto launch = [&](auto use_eagle) {
    draft_model_preprocess_kernel<BlockSize, decltype(use_eagle)::value>
        <<<1, BlockSize, 0, cu_stream>>>(
            const_cast<int64_t*>(draft_tokens.data<int64_t>()),
            const_cast<int64_t*>(input_ids.data<int64_t>()),
            const_cast<bool*>(stop_flags.data<bool>()),
            const_cast<int*>(seq_lens_this_time.data<int>()),
            const_cast<int*>(seq_lens_encoder.data<int>()),
            const_cast<int*>(seq_lens_decoder.data<int>()),
            const_cast<int64_t*>(step_idx.data<int64_t>()),
            const_cast<int*>(first_token_record.data<int>()),
            const_cast<bool*>(not_need_stop_gpu.data<bool>()),
            accept_tokens.data<int64_t>(),
            accept_num.data<int>(),
            base_model_seq_lens_encoder.data<int>(),
            base_model_seq_lens_decoder.data<int>(),
            base_model_step_idx.data<int64_t>(),
            base_model_stop_flags.data<bool>(),
            const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
            real_bsz,
            max_draft_token,
            accept_tokens_len,
            draft_tokens_len,
            input_ids_len,
            base_model_draft_tokens_len);
  };

  if (draft_type == "eagle") {
    launch(std::true_type{});
  } else {
    launch(std::false_type{});
  }

  // Mirror the reduced flag back into the host-side tensor in place.
  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
// Registers the custom op. Every input except not_need_stop is modified on
// device in place; the inplace map below aliases each such input to its
// corresponding "_out" output so the framework tracks the mutation.
PD_BUILD_OP(draft_model_preprocess)
    .Inputs({"draft_tokens",
             "input_ids",
             "stop_flags",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "first_token_record",
             "not_need_stop",
             "accept_tokens",
             "accept_num",
             "base_model_seq_lens_encoder",
             "base_model_seq_lens_decoder",
             "base_model_step_idx",
             "base_model_stop_flags",
             "base_model_draft_tokens"})
    .Outputs({"draft_tokens_out",
              "input_ids_out",
              "stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "step_idx_out",
              "not_need_stop_out",
              "first_token_record_out"})
    .Attrs({"max_draft_token: int", "draft_type: std::string"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"input_ids", "input_ids_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"step_idx", "step_idx_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"first_token_record", "first_token_record_out"}})
    .SetKernelFn(PD_KERNEL(DraftModelPreprocess));