 import torch_tensorrt
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache
+from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs


 DEVICE = torch.device("cuda:0")
@@ -43,7 +43,7 @@ def get_model(args):
             args.model,
             use_cache=False,
             attn_implementation="sdpa",
-            # num_hidden_layers=1
+            num_hidden_layers=1
         )
         .eval()
         .cuda()
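Cutting the model down to a single decoder layer (the previously commented-out num_hidden_layers override) keeps the exported graph small while the KV-cache lowering is being debugged. A minimal sketch of the same idea, assuming an arbitrary Hugging Face causal-LM checkpoint; num_hidden_layers is forwarded by from_pretrained as a config override:

from transformers import AutoModelForCausalLM

# Load a truncated, single-layer variant of the checkpoint for fast iteration.
# The checkpoint name below is only a placeholder; any causal LM works.
model = (
    AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        use_cache=False,
        attn_implementation="sdpa",
        num_hidden_layers=1,
    )
    .eval()
    .cuda()
)
print(model.config.num_hidden_layers)  # -> 1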
@@ -194,9 +194,10 @@ def measure_perf(trt_model, input_signature, backend_name):
        help="Enable pytorch run (default: False)"
    )
    arg_parser.add_argument(
-        "--kv_cache",
-        action="store_true",
-        help="Enable kv_cache (default: False)"
+        "--cache",
+        type=str,
+        default="static",
+        help="Type of KV cache to use",
    )
    arg_parser.add_argument(
        "--cudagraph",
@@ -220,9 +221,9 @@ def measure_perf(trt_model, input_signature, backend_name):
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    prompt = "What is parallel programming ?"
+    # prompt = "What is the capital of France ?"
    model_inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = model_inputs["input_ids"].to(DEVICE)
-
    # Prepare input prompt
    # word = "What"
    # word_ids = tokenizer(word, return_tensors="pt").input_ids[0]  # Get the first (and only) sequence
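The string-valued --cache option added above replaces the old boolean --kv_cache flag so that more than one cache implementation can be selected from the command line; the branch on its value appears in the next hunk. A minimal sketch of how the option is consumed, mirroring that dispatch (the static_cache2 and dynamic_cache modules are the ones imported there and are assumed to be on the path; registering the lowering passes is a side effect of the import):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cache",
    type=str,
    default="static",
    help="Type of KV cache to use",
)
args = parser.parse_args()

if args.cache == "static":
    # Importing registers the static KV-cache rewrite as a lowering pass.
    from static_cache2 import *
elif args.cache == "dynamic":
    # Importing registers the dynamic KV-cache rewrite as a lowering pass.
    from dynamic_cache import *
else:
    # Any other value compiles the model without a KV cache.
    pass

With this, the script is invoked with --cache static (the default) or --cache dynamic instead of the old --kv_cache switch.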
@@ -252,18 +253,67 @@ def measure_perf(trt_model, input_signature, backend_name):
    )

    # TRT
+    pyt_logits_tok1 = model.cuda()(input_ids)
+    next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1)
+    input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+    pyt_logits_tok2 = model.cuda()(input_seq)
    from lower_sdpa import *
-    if args.kv_cache:
-        # This import is required to register static/dynamic KV cache transformations as lowering passes
-        from static_cache import *
+    if args.cache == "static":
+        # This import is required to register static KV cache transformations as lowering passes
+        from static_cache2 import *
+        trt_model = compile_torchtrt(model, input_ids, args)
+        kv_cache = get_zeroed_kv_cache_inputs(trt_model)
+
+        # First token generation
+        pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt")
+        trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1])
+        print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}")
+        print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}")
+        print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}")
+        print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}")
+        print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}")
+        next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1)
+
+        # Second token generation
+        trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1] + 1)
+        pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt")
+        print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}")
+        print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}")
+        print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}")
+        print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}")
+        print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}")
+        breakpoint()
+    elif args.cache == "dynamic":
+        from dynamic_cache import *
        trt_model = compile_torchtrt(model, input_ids, args)
+        breakpoint()
+        kv_cache = get_zeroed_kv_cache_inputs(trt_model)
    else:
        # pyt_logits = model.cuda()(input_ids.clone())
        trt_model = compile_torchtrt(model, input_ids, args)
        # trt_logits = trt_model(input_ids.clone(), True)
        # print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}")
        # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}")
-    if args.kv_cache:
+    if args.cache == "static":
+        if args.cudagraph:
+            # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
+            # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
+            torch_tensorrt.runtime.set_cudagraphs_mode(True)
+
+        trt_gen_tokens = generate_with_kv_cache(
+            trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
+        )
+
+        if args.benchmark:
+            trt_timings = time_generate(
+                generate_with_kv_cache,
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
+                iterations=args.iterations,
+            )
+    elif args.cache == "dynamic":
        if args.cudagraph:
            # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
            # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
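The checks above exercise the compiled model's positional signature directly: a prefill call (input_ids, True, *kv_cache, 0, seq_len) followed by a one-token step (next_token, False, key_cache, value_cache, seq_len, seq_len + 1), each returning logits plus the updated caches. A sketch of the greedy decode loop this convention implies follows; it is an assumption based only on those two calls, not the actual generate_with_kv_cache implementation from utils, and it presumes the single-layer debug configuration (one key and one value cache tensor):

import torch

def greedy_decode_with_kv_cache(trt_model, input_ids, max_output_seq_length, eos_token_id, kv_cache):
    # Prefill: run the whole prompt once, starting from zeroed caches.
    start, end = 0, input_ids.shape[1]
    logits, key_cache, value_cache, *_ = trt_model(input_ids, True, *kv_cache, start, end)
    next_token = torch.argmax(logits[:, -1, :], dim=-1)
    generated = torch.cat([input_ids, next_token[:, None]], dim=-1)

    # Decode: one token at a time, feeding the updated caches back in.
    while generated.shape[1] < max_output_seq_length and (next_token != eos_token_id).all():
        start, end = end, end + 1
        logits, key_cache, value_cache, *_ = trt_model(
            next_token[:, None], False, key_cache, value_cache, start, end
        )
        next_token = torch.argmax(logits[:, -1, :], dim=-1)
        generated = torch.cat([generated, next_token[:, None]], dim=-1)
    return generated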