diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py new file mode 100644 index 0000000000..be91ac6efe --- /dev/null +++ b/tools/llm/run_vlm.py @@ -0,0 +1,615 @@ +""" +.. _run_vlm: + +Benchmarking VLM Inference with Torch-TensorRT +========================================================== + +This script provides a framework for benchmarking the performance of Visual-Language +Models (VLMs). It optimizes the two most computationally intensive components of a +VLM—the language model and the vision model (image feature extraction)—using +the Torch-TensorRT dynamo backend. + +Key Features: +- **Component-wise Optimization**: Compiles both the language and vision models + separately with Torch-TensorRT to accelerate inference. +- **Performance Benchmarking**: Runs the model for multiple iterations to + measure and compare inference latency against the PyTorch baseline. +- **Output Verification**: Checks for token-level consistency between the optimized + TensorRT model and the original PyTorch model to ensure correctness. +- **KV Cache Testing**: Includes options to test inference with and without + KV caching to evaluate its impact on performance. + +This tool mirrors the style and structure of `run_llm.py`, providing a clear +workflow for VLM optimization and analysis. +""" + +import argparse +import copy +import os +import sys +from contextlib import nullcontext +from typing import Tuple + +import requests +import torch +import torch_tensorrt +from PIL import Image +from torchtrt_ext import register_sdpa +from transformers import AutoConfig, AutoModel, AutoProcessor +from utils import ( + export_llm, + generate_mm, + generate_mm_qwen2_5_vl, + generate_mm_qwen2_5_vl_with_static_cache, + generate_mm_with_static_cache, + record_stats, + time_generate_mm, +) + +# -----------------------------------------------------------------------------# +# Global configuration +# -----------------------------------------------------------------------------# +DEVICE = torch.device("cuda:0") + +# --- WORKAROUND FOR EAGLE2 SDPA COMPILATION --- +# Eagle2's language model (Qwen2) implicitly defaults to "flash_attention_2" +# due to settings in its remote code and config.json. This prevents direct +# compilation with SDPA. To work around this without modifying the library, + +# we "monkey-patch" the global attention function map for Qwen2. +# This ensures that any part of the code (including torch.export) requesting +# "flash_attention_2" will receive the "sdpa" implementation instead. +# This patch is global for the script's execution context. +import transformers.models.qwen2.modeling_qwen2 as mq + +mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] +# --- END WORKAROUND --- + +# --- Model-specific constants for benchmark and compilation --- +# Centralizing these values improves readability and maintainability. +MODEL_CONSTANTS = { + "nvidia/Eagle2-2B": { + "EXAMPLE_SEQLEN": 2560, # A fixed sequence length for creating the example tensor for TRT compilation. + "IMAGE_TOKENS": 1792, # Number of special tokens used to represent the image patch embeddings in the input sequence for Eagle2-2B VLM. + "PROMPT_WRAPPER_TOKENS": 26, # The number of special/processing tokens added by the processor's chat template in benchmark mode. 
+ }, + "Qwen/Qwen2.5-VL-3B-Instruct": { + "EXAMPLE_SEQLEN": 2560, + "IMAGE_TOKENS": 1426, + "PROMPT_WRAPPER_TOKENS": 21, + }, +} +# --- END Model-specific constants --- + +# -----------------------------------------------------------------------------# +# Model loading helpers +# -----------------------------------------------------------------------------# + + +def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): + """ + Load nvidia/Eagle2-2B model and processor, ensuring the language model uses SDPA. + + Returns + ------- + tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding] + The model, its processor and the language-model input embedding layer. + """ + model_id = "nvidia/Eagle2-2B" + with torch.no_grad(): + model = ( + AutoModel.from_pretrained( + model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + # attn_implementation="sdpa" is ignored due to the model's remote code. + ) + .eval() + .to(device) + ) + + processor = AutoProcessor.from_pretrained( + model_id, trust_remote_code=True, use_fast=True + ) + if hasattr(processor, "tokenizer"): + processor.tokenizer.padding_side = "left" + + emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + +def _load_qwen2_5_vl(device, torch_dtype: torch.dtype): + """ + Load Qwen2.5-VL model and processor. + """ + from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + model_id = "Qwen/Qwen2.5-VL-3B-Instruct" + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch_dtype, device_map=device + ).eval() + processor = AutoProcessor.from_pretrained(model_id) + emb_layer = model.model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + +def load_model( + model_name: str, device: torch.device, torch_dtype: torch.dtype +) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: + """Dispatch helper for supported VLMs.""" + if model_name == "nvidia/Eagle2-2B": + return _load_eagle2(device, torch_dtype) + elif model_name == "Qwen/Qwen2.5-VL-3B-Instruct": + return _load_qwen2_5_vl(device, torch_dtype) + msg = f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Input loading helpers +# -----------------------------------------------------------------------------# + + +def load_inputs(args: argparse.Namespace, processor, device: torch.device): + """ + Loads and constructs the input dictionary for the specified VLM model. + """ + url = "https://www.ilankelman.org/stopsigns/australia.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + if args.benchmark: + model_constants = MODEL_CONSTANTS[args.model] + image_tokens = model_constants["IMAGE_TOKENS"] + wrapper_tokens = model_constants["PROMPT_WRAPPER_TOKENS"] + + prompt_len = args.isl - image_tokens - wrapper_tokens + prompt_txt = " ".join(["token"] * max(prompt_len, 0)) + else: + prompt_txt = args.prompt or "Describe this image." 
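+    # Worked example (illustrative): in benchmark mode with Eagle2-2B and the
+    # default --isl 2048, prompt_len = 2048 - 1792 (image tokens) - 26 (chat
+    # template tokens) = 230, so the text prompt is padded with 230 filler
+    # words; negative values are clamped to zero above.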
+ + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt_txt}, + ], + } + ] + + text = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + ] + + # --- Model-specific vision processing --- + if "qwen" in args.model.lower(): + from qwen_vl_utils import process_vision_info + + image_inputs, video_inputs = process_vision_info(messages) + else: # eagle2 + image_inputs, video_inputs = processor.process_vision_info(messages) + + inputs = processor( + text=text, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ).to(device) + + return inputs + + +# -----------------------------------------------------------------------------# +# Torch-TensorRT compilation helpers +# -----------------------------------------------------------------------------# + + +class _LMNoCache(torch.nn.Module): + """ + Thin wrapper that exposes a language model via ``inputs_embeds`` without KV-cache. + """ + + def __init__(self, lm): + super().__init__() + self.lm = lm + + def forward(self, inputs_embeds, position_ids): + out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids) + return ( + out.logits + if hasattr(out, "logits") + else out.last_hidden_state if hasattr(out, "last_hidden_state") else out + ) + + +def _compile_lm( + language_model: torch.nn.Module, + input_embeds: torch.Tensor, + args: argparse.Namespace, +) -> torch.nn.Module: + """ + Compile the language model component of a VLM with Torch-TensorRT + """ + lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() + max_seq_len = input_embeds.shape[1] + args.num_tokens + + seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + + dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} + + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: # FP32 + enabled_precisions = {torch.float32} + + exported_program = export_llm( + lm_wrap, input_embeds, min_seq_len=1, max_seq_len=2560 + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[input_embeds, position_ids], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_lm_torchtrt( + model: torch.nn.Module, args: argparse.Namespace +) -> torch.nn.Module: + """ + Compiles the Language Model (LLM) component of the VLM using Torch-TensorRT. + """ + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + lm_model = model.model if "qwen" in args.model.lower() else model.language_model + + model_constants = MODEL_CONSTANTS.get( + args.model, {"EXAMPLE_SEQLEN": args.num_tokens} + ) + example_seq_len = model_constants["EXAMPLE_SEQLEN"] + + example_embeds = torch.randn( + args.batch_size, + example_seq_len, + lm_model.config.hidden_size, + dtype=torch_dtype, + device=DEVICE, + ) + + # All supported models use the same compilation helper. 
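+    # Note: _compile_lm exports the wrapper with a dynamic sequence dimension
+    # (via export_llm), so one engine can serve the growing inputs_embeds fed
+    # back each step by the no-cache greedy decoding loops in utils.py.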
+ if args.model in ["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"]: + return _compile_lm(lm_model, example_embeds, args) + else: + msg = f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" + raise ValueError(msg) + + +def _compile_eagle2_vision( + vision_model: torch.nn.Module, + example_pixel_values: torch.Tensor, + args: argparse.Namespace, +) -> torch.nn.Module: + """ + Compile Eagle2 vision model with Torch-TensorRT. + """ + # Set precision-specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: # FP32 + enabled_precisions = {torch.float32} + + with torch.inference_mode(): + exported_program = torch.export.export( + vision_model, + (example_pixel_values,), + strict=False, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[example_pixel_values], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_vision_torchtrt( + model: torch.nn.Module, + args: argparse.Namespace, + example_pixel_values: torch.Tensor, +) -> torch.nn.Module: + """ + Dispatcher function for vision model compilation. + """ + if args.model == "nvidia/Eagle2-2B": + return _compile_eagle2_vision(model.vision_model, example_pixel_values, args) + elif args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + # TODO: Vision model compilation for Qwen2.5-VL is currently skipped. + # The model's `get_window_index` method uses dynamic Python list operations + # (e.g., .tolist(), .extend()) to process variable-sized image grids for + # windowed attention. These operations are incompatible with torch.export's + # static graph tracing, preventing successful compilation. 
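+        # Consequently the Qwen2.5-VL vision tower is returned unmodified and
+        # runs in eager PyTorch; only its language model is TensorRT-compiled.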
+ return model.visual + else: + raise ValueError(f"Unsupported model: {args.model}") + + +# -----------------------------------------------------------------------------# +# Utility helpers +# -----------------------------------------------------------------------------# + + +def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): + """Pretty-print generated text for comparison.""" + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +# -----------------------------------------------------------------------------# +# Main driver +# -----------------------------------------------------------------------------# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run VLM inference (PyTorch & TensorRT back-ends)" + ) + parser.add_argument( + "--model", + default="Qwen/Qwen2.5-VL-3B-Instruct", + choices=["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"], + help="VLM model name", + ) + parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") + parser.add_argument( + "--precision", + default="FP16", + choices=["FP16", "BF16", "FP32"], + help="Computation precision", + ) + parser.add_argument("--iterations", type=int, default=5, help="# iterations") + parser.add_argument("--min_block_size", type=int, default=1, help="Min block size") + parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--isl", type=int, default=2048, help="Input seq length") + parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Run the PyTorch baseline as well", + ) + parser.add_argument( + "--cache", + default="", + choices=["", "static_v1"], + help="KV-cache variant to use", + ) + parser.add_argument( + "--debug", action="store_true", help="Enable Torch-TensorRT debug logs" + ) + parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmarking mode" + ) + + args = parser.parse_args() + + # -------------------------------------------------------------------------# + # 1. Model / processor / embeddings + # -------------------------------------------------------------------------# + dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + model, processor, emb_layer = load_model(args.model, DEVICE, dtype) + + # -------------------------------------------------------------------------# + # 2. Input construction (image + text prompt) + # -------------------------------------------------------------------------# + inputs = load_inputs(args, processor, DEVICE) + + max_output_len = inputs["input_ids"].shape[1] + args.num_tokens + + # -------------------------------------------------------------------------# + # 3. 
Optional: PyTorch baseline + # -------------------------------------------------------------------------# + pyt_gen_tokens = pyt_timings = pyt_stats = None + if args.enable_pytorch_run: + if "qwen" in args.model.lower(): + pyt_gen_tokens = generate_mm_qwen2_5_vl( + model, + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + ) + else: # eagle2 + pyt_gen_tokens = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + ) + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + if args.benchmark: + # Prepare args for the timing function + time_generate_args = { + "model": model, + "pixel_values": inputs["pixel_values"].clone(), + "input_ids": inputs["input_ids"].clone(), + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + + # Select the correct generation function and add model-specific args + if "qwen" in args.model.lower(): + generate_fn_for_timing = generate_mm_qwen2_5_vl + time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] + else: # eagle2 + generate_fn_for_timing = generate_mm + + pyt_timings = time_generate_mm( + generate_fn_for_timing, iterations=args.iterations, **time_generate_args + ) + pyt_stats = record_stats( + "PyTorch", + pyt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + # -------------------------------------------------------------------------# + # 4. Torch-TensorRT compile & run + # -------------------------------------------------------------------------# + + trt_model = copy.deepcopy(model) + # 4.1. Vision model compilation + # --- Add vision model compilation --- # + example_pixel_values = inputs["pixel_values"] + trt_vision = compile_vision_torchtrt(model, args, example_pixel_values) + if "qwen" in args.model.lower(): + trt_model.visual = trt_vision + else: + trt_model.vision_model = trt_vision + + # -------------------------------------------------------------------------# + # 4.2. Language model compilation + # -------------------------------------------------------------------------# + # Register static cache lowering passes if requested + # Cache is not applied to vision model. 
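+    # Importing static_cache_v1 registers lowering passes that expose the KV
+    # cache as explicit engine inputs/outputs; the *_with_static_cache generate
+    # helpers in utils.py then pass (*kv_cache, start_idx, end_idx) each step.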
+ if args.cache == "static_v1": + import static_cache_v1 # noqa: F401 + + trt_lm = compile_lm_torchtrt(model, args) + if "qwen" in args.model.lower(): + trt_model.model = trt_lm + else: + trt_model.language_model = trt_lm + + emb_layer = emb_layer.to(DEVICE) + if "qwen" in args.model.lower(): + trt_model.lm_head = trt_model.lm_head.to(DEVICE) + + if args.cache == "static_v1": + if "qwen" in args.model.lower(): + trt_generate = generate_mm_qwen2_5_vl_with_static_cache + else: # eagle2 + trt_generate = generate_mm_with_static_cache + else: + if "qwen" in args.model.lower(): + trt_generate = generate_mm_qwen2_5_vl + else: # eagle2 + trt_generate = generate_mm + + # Prepare args for generate function + generate_args = { + "model": trt_model, + "pixel_values": inputs["pixel_values"], + "input_ids": inputs["input_ids"], + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + if "qwen" in args.model.lower(): + generate_args["image_grid_thw"] = inputs["image_grid_thw"] + if args.cache == "static_v1": + generate_args["device"] = DEVICE + + trt_gen_tokens = trt_generate(**generate_args) + + if args.benchmark: + # Prepare args for the timing function + time_generate_args = { + "model": trt_model, + "pixel_values": inputs["pixel_values"].clone(), + "input_ids": inputs["input_ids"].clone(), + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + + # Add model-specific args + if "qwen" in args.model.lower(): + time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] + if args.cache == "static_v1": + time_generate_args["device"] = DEVICE + + trt_timings = time_generate_mm( + trt_generate, + iterations=args.iterations, + **time_generate_args, + ) + trt_stats = record_stats( + "TensorRT", + trt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + # -------------------------------------------------------------------------# + # 5. Reporting + # -------------------------------------------------------------------------# + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + print_outputs("TensorRT", trt_gen_tokens, processor.tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: " + f"{torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("========= PyTorch PERFORMANCE =========\n") + print(pyt_stats) + print("=====================\n") + print("========= TensorRT PERFORMANCE =========\n") + print(trt_stats) diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py index b60396c08b..58daacedf5 100644 --- a/tools/llm/static_cache_v1.py +++ b/tools/llm/static_cache_v1.py @@ -201,7 +201,7 @@ def insert_kv_slicing_before_sdpa( args=(slice_7, 3), kwargs={}, ) - # =============================================== # + # Concatenate the sliced tensors to build KV cache cat = gm.graph.create_node( "call_function", diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py index 60482bf22d..1c1366bd0c 100644 --- a/tools/llm/test_qwen2.5_components.py +++ b/tools/llm/test_qwen2.5_components.py @@ -16,7 +16,7 @@ # Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from register_sdpa import * +from torchtrt_ext import register_sdpa ATOL = 1e-5 RTOL = 1e-5 diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..09fa662299 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -242,3 +242,849 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None) "Compile Time(s)": compile_time_s, } return stats + + +def generate_mm( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +): + """Greedy decode for Eagle2-style VLM. + + Parameters + ---------- + model : nn.Module + Must expose vision_model, mlp1, language_model, pixel_shuffle, downsample_ratio, image_token_index. + pixel_values : Tensor | None + Input image batch (B,C,H,W) or None. + input_ids : LongTensor (B, N_prompt) + Text prompt token ids including [IMG] placeholder(s). + max_output_seq_length : int + Maximum tokens to generate **in addition to** the prompt. + eos_token_id : int + Stop generation when all sequences emit EOS. + emb_layer : nn.Embedding + Embedding layer for input_ids. + """ + + vit_embeds = None + + if pixel_values is not None: + # --- Vision encoder timing --- + vis_s = torch.cuda.Event(enable_timing=True) + vis_e = torch.cuda.Event(enable_timing=True) + vis_s.record() + vit_out = model.vision_model(pixel_values) + vis_e.record() + torch.cuda.synchronize() + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + + # 2) Text token embeddings + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + + mask = seq_tokens.view(B * N) == model.image_token_index + try: + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + except Exception: + # Fallback in unlikely size-mismatch cases + flat_emb[mask] = vit_embeds.reshape(-1, C)[: mask.sum()].to(flat_emb.dtype) + seq_embeds = flat_emb.view(B, N, C) + + # ───────────────────────────────── Greedy loop ─────────────────────────────────────────────────── + isl = seq_tokens.shape[1] + osl = max_output_seq_length - isl + + generated = 0 + + while generated < osl: + cur_embeds = seq_embeds # full seq first step or cache off + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model( + inputs_embeds=cur_embeds, position_ids=position_ids + ) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) # (B,) + # append token & embed + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + return seq_tokens[:, input_ids.shape[1] :] + + +@torch.inference_mode() +def generate_mm_with_static_cache( + model, # Complete VLM module + pixel_values: torch.Tensor | 
None, + input_ids: torch.Tensor, # (B, N_prompt) + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: # (B, N_prompt + new) + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1). + Basic structure is identical to LM version (generate_with_static_cache) but + * Input is `inputs_embeds` + * Vision tokens are sent together only in the first step + """ + + # ───────────────────── Vision encoding ───────────────────── + vit_embeds = None + if pixel_values is not None: + vit_latent = model.vision_model(pixel_values) + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) + vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) + vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) + vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) + + # ───────────────────── Text embedding & [IMG] replacement ───────────── + seq_tokens = input_ids.clone() # (B, N_txt) + seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs(model.language_model) + start_idx = 0 # First token index + end_idx = seq_embeds.size(1) # Prompt length + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + + # When using static cache: + # - First step: Use full prompt embedding + # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + # position_ids: Same pattern as generate_with_static_cache + # - First step: Position of entire sequence + # - Subsequent steps: Position of current token only + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + # is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + # is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + # Prepare for next step - Static cache only needs new token + next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) + seq_embeds = next_embed # Next step uses only new token + + generated += 1 + start_idx = end_idx + end_idx += 1 + # is_causal = True # Causal mask active from now on + + if (next_tok == eos_token_id).all(): + break + + return output_tokens + + +def generate_mm_with_timing( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + use_cache: bool = False, +): + # Create timing events + overall_start = 
torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + vit_embeds = None + if pixel_values is not None: + vision_start.record() + vit_out = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + seq_embeds = flat_emb.view(B, N, C) + + step_times = [] + generated = 0 + past_key_values = None + + while generated < max_new_tokens: + lm_start.record() + cur_embeds = seq_embeds + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model( + inputs_embeds=cur_embeds, position_ids=position_ids + ) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + return seq_tokens, step_times, overall_time, vision_time, mlp_time + + +@torch.inference_mode() +def generate_mm_with_static_cache_timing( + model, # Complete VLM module + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, # (B, N_prompt) + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + device: str = "cuda:0", +) -> tuple: # (seq_tokens, step_times, overall_time, vision_time, mlp_time) + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1) + detailed timing measurement. 
+ + Returns: + seq_tokens: Generated token sequence + step_times: Language model inference time for each step (ms) + overall_time: Total execution time (ms) + vision_time: Vision encoding time (ms) + mlp_time: MLP processing time (ms) + """ + + # ───────────────────── Create timing events ───────────────────── + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + # ───────────────────── Vision encoding ───────────────────── + vit_embeds = None + vision_time = 0.0 + mlp_time = 0.0 + + if pixel_values is not None: + vision_start.record() + vit_latent = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) + vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) + vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) + vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + # ───────────────────── Text embedding & [IMG] replacement ───────────── + seq_tokens = input_ids.clone() # (B, N_txt) + seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs(model.language_model) + start_idx = 0 # First token index + end_idx = seq_embeds.size(1) # Prompt length + generated = 0 + max_total_len = end_idx + max_new_tokens + output_tokens = seq_tokens.clone() + step_times = [] # Timing for each step + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + lm_start.record() + + # When using static cache: + # - First step: Use full prompt embedding + # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + # position_ids: Same pattern as generate_with_static_cache + # - First step: Position of entire sequence + # - Subsequent steps: Position of current token only + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + # is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + # is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + output_tokens = torch.cat([output_tokens, next_tok[:, 
None]], dim=-1) + + # Prepare for next step - Static cache only needs new token + next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) + seq_embeds = next_embed # Next step uses only new token + + generated += 1 + start_idx = end_idx + end_idx += 1 + + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + if (next_tok == eos_token_id).all(): + break + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + return output_tokens, step_times, overall_time, vision_time, mlp_time + + +def time_generate_mm( + generate_fn, + iterations=10, + **kwargs, +): + """ + Measure the time for generating a sentence over certain number of iterations. + Accepts generation function arguments via kwargs. + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + _ = generate_fn(**kwargs) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings + + +def generate_mm_qwen2_5_vl( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model. + Performs greedy decoding without caching, using inputs_embeds instead of input_ids. + """ + # 1. Calculate image embeddings (if pixel_values are provided) + image_embeds = None + if pixel_values is not None: + image_embeds = model.visual(pixel_values, image_grid_thw) + + # 2. Create initial sequence embeddings + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + # 3. Insert image embeddings at image token positions + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + osl = max_output_seq_length - seq_tokens.shape[1] + # 5. Greedy generation loop + generated = 0 + while generated < osl: + # 5.1. Calculate position_ids + position_ids = ( + torch.arange( + 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device + ) + .unsqueeze(0) + .expand(seq_embeds.size(0), seq_embeds.size(1)) + ) + + # 5.2. Call the language model + with torch.no_grad(): + outputs = model.model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + ) + hidden_states = ( + outputs + if isinstance(outputs, torch.Tensor) + else outputs.last_hidden_state + ) + + # 5.3. Calculate logits for the last token + logits = model.lm_head(hidden_states[:, -1, :]) + + # 5.4. Select the next token (greedy decoding) + next_tok = torch.argmax(logits, dim=-1) + + # 5.5. Append token and embedding to the sequence + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + next_emb = emb_layer(next_tok)[:, None, :] + seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) + + generated += 1 + + # 6. 
Return generated tokens (only the part after the prompt) + return seq_tokens[:, input_ids.shape[1] :] + + +def generate_mm_qwen2_5_vl_with_static_cache( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: + """ + Greedy Decoder for Qwen-2.5-VL using static KV-cache. + Identical to `generate_mm_with_static_cache` but adapted for Qwen-2.5-VL's + specific architecture (e.g., separate visual encoder call, lm_head). + """ + # 1. Vision encoding + image_embeds = None + if pixel_values is not None: + image_embeds = model.visual(pixel_values, image_grid_thw) + + # 2. Text embedding & image token replacement + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match " + f"number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + # 3. KV-cache initialization + kv_cache = get_zeroed_static_cache_inputs(model.model) + start_idx = 0 + end_idx = seq_embeds.size(1) + generated = 0 + osl = max_output_seq_length - seq_tokens.shape[1] + output_tokens = seq_tokens.clone() + + # 4. Greedy loop + while generated < osl: + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + ) + + outputs_and_kv = model.model(*input_signature) + # With the fix in static_cache_v1.py, the model output is now clean: + # (hidden_state, updated_kv_cache[72]) + hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:] + + # Use logit_pos to get the correct logit based on whether we padded or not. 
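+        # (No padding is applied in this path: the hidden state at the last
+        # position is projected through lm_head to obtain next-token logits.)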
+ logits = model.lm_head(hidden_states[:, -1, :]) + + next_tok = logits.argmax(dim=-1) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed + + generated += 1 + start_idx = end_idx + end_idx += 1 + + return output_tokens + + +def generate_mm_paligemma( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + vit_embeds = None + if pixel_values is not None: + vit_out = model.vision_tower(pixel_values) + vit_embeds = model.multi_modal_projector(vit_out.last_hidden_state) + vit_embeds = vit_embeds / (model.config.text_config.hidden_size**0.5) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.config.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + B = seq_tokens.size(0) + cache_position = torch.arange(seq_tokens.size(1), device=seq_tokens.device) + position_ids = cache_position.unsqueeze(0) + 1 + + generated = 0 + while generated < max_output_seq_length: + causal_mask = model.model._update_causal_mask( + attention_mask=None, + token_type_ids=None, + past_key_values=None, + cache_position=cache_position, + input_tensor=seq_embeds, + is_training=False, + ) + + with torch.no_grad(): + out = model.language_model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + attention_mask=causal_mask, + use_cache=False, + ) + logits = out.last_hidden_state if hasattr(out, "last_hidden_state") else out + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + position_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=1) + cache_position = torch.arange(seq_tokens.size(1), device=seq_tokens.device) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + return seq_tokens + + +@torch.inference_mode() +def generate_mm_paligemma_with_static_cache( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: + vit_embeds = None + if pixel_values is not None: + vit_latent = model.vision_tower(pixel_values) + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + vit_embeds = model.multi_modal_projector(vit_embeds) + vit_embeds = vit_embeds / (model.config.text_config.hidden_size**0.5) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device) + start_idx = 0 + end_idx = seq_embeds.size(1) + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + while output_tokens.size(1) < max_total_len: + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + if generated == 0: + position_ids = ( + 
torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed + + generated += 1 + start_idx = end_idx + end_idx += 1 + is_causal = True + + if (next_tok == eos_token_id).all(): + break + + return output_tokens + + +@torch.inference_mode() +def generate_mm_qwen2_5_vl_with_timing( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model with timing. + """ + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + vision_time = 0.0 + image_embeds = None + if pixel_values is not None: + vision_start.record() + image_embeds = model.visual(pixel_values, image_grid_thw) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + step_times = [] + generated = 0 + while generated < max_new_tokens: + lm_start.record() + position_ids = ( + torch.arange( + 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device + ) + .unsqueeze(0) + .expand(seq_embeds.size(0), seq_embeds.size(1)) + ) + + with torch.no_grad(): + outputs = model.model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + ) + hidden_states = ( + outputs + if isinstance(outputs, torch.Tensor) + else outputs.last_hidden_state + ) + + logits = model.lm_head(hidden_states[:, -1, :]) + next_tok = torch.argmax(logits, dim=-1) + + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + next_emb = emb_layer(next_tok)[:, None, :] + seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) + + generated += 1 + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0. + return seq_tokens, step_times, overall_time, vision_time, 0.0
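+
+
+# ---------------------------------------------------------------------------#
+# Usage sketch (illustrative only): timing a compiled VLM with the helpers
+# above. Assumes `trt_model`, `inputs`, `processor`, and `emb_layer` have been
+# prepared as in tools/llm/run_vlm.py; the variable names and numbers below are
+# examples, not part of this module's API.
+#
+#   timings = time_generate_mm(
+#       generate_mm,  # use generate_mm_qwen2_5_vl (plus image_grid_thw) for Qwen2.5-VL
+#       iterations=5,
+#       model=trt_model,
+#       pixel_values=inputs["pixel_values"],
+#       input_ids=inputs["input_ids"],
+#       max_output_seq_length=inputs["input_ids"].shape[1] + 128,
+#       eos_token_id=processor.tokenizer.eos_token_id,
+#       emb_layer=emb_layer,
+#   )
+#   stats = record_stats("TensorRT", timings, "FP16", batch_size=1)
+# ---------------------------------------------------------------------------#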