From e4e09bb82964960b4bda44b684d7f0f5e6e933b1 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 2 Jul 2025 14:59:45 +0000 Subject: [PATCH 1/5] integrated vlm code for benchmark --- tools/llm/run_vlm.py | 387 ++++++++++++++++++++++++++++++++++++ tools/llm/utils.py | 461 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 848 insertions(+) create mode 100644 tools/llm/run_vlm.py diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py new file mode 100644 index 0000000000..f6bd62624f --- /dev/null +++ b/tools/llm/run_vlm.py @@ -0,0 +1,387 @@ +""" +.. _run_vlm: + +Running VLM inference with Torch-TensorRT +========================================================== + +This script mirrors the style and structure of *run_llm.py*, illustrating a +Torch-TensorRT (dynamo backend) workflow for Visual-Language Models (VLMs). +""" + +import argparse +import copy +import os +import sys +from contextlib import nullcontext +from typing import Tuple + +import requests +import torch +import torch_tensorrt +from PIL import Image +from torchtrt_ext import register_sdpa +from transformers import AutoModel, AutoProcessor +from utils import ( + generate_mm, + generate_mm_with_static_cache, + record_stats, + time_generate_mm, +) + +# -----------------------------------------------------------------------------# +# Global configuration +# -----------------------------------------------------------------------------# +DEVICE = torch.device("cuda:0") + +# Register SDPA as a standalone operator. Converter & lowering pass are defined +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +import transformers.models.qwen2.modeling_qwen2 as mq # noqa: E402 + +mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] + +# -----------------------------------------------------------------------------# +# Model loading helpers +# -----------------------------------------------------------------------------# + + +def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): + """ + Load Eagle2 model and processor. + + Returns + ------- + tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding] + The model, its processor and the language-model input embedding layer. + """ + model_id = "nvidia/Eagle2-2B" + with torch.no_grad(): + model = ( + AutoModel.from_pretrained( + model_id, trust_remote_code=True, torch_dtype=torch_dtype + ) + .eval() + .to(device) + ) + + processor = AutoProcessor.from_pretrained( + model_id, trust_remote_code=True, use_fast=True + ) + if hasattr(processor, "tokenizer"): + processor.tokenizer.padding_side = "left" + + emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + +def _load_model( + model_name: str, device: torch.device, torch_dtype: torch.dtype +) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: + """Dispatch helper for supported VLMs.""" + if model_name.lower() == "eagle2": + return _load_eagle2(device, torch_dtype) + msg = f"Unsupported model: {model_name}" + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Torch-TensorRT compilation helpers +# -----------------------------------------------------------------------------# + + +class _LMNoCache(torch.nn.Module): + """ + Thin wrapper that exposes a language model via ``inputs_embeds`` without KV-cache. + """ + + def __init__(self, lm): + super().__init__() + self.lm = lm + + def forward(self, inputs_embeds, position_ids): + out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids) + return out.logits if hasattr(out, "logits") else out + + +def _compile_eagle2_lm( + language_model: torch.nn.Module, + input_embeds: torch.Tensor, + args: argparse.Namespace, +) -> torch.nn.Module: + """ + Compile Eagle2 language model with Torch-TensorRT. + + The function follows the same precision-specific flag logic used in + *run_llm.py* for consistency. + """ + lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() + max_seq_len = input_embeds.shape[1] + args.num_tokens + + S = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + dyn_shapes = {"inputs_embeds": {1: S}, "position_ids": {1: S}} + + # Precision-specific flags --------------------------------------------------# + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: # FP32 + enabled_precisions = {torch.float32} + + with torch.inference_mode(): + exported = torch.export.export( + lm_wrap, + (input_embeds, position_ids), + dynamic_shapes=dyn_shapes, + strict=False, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported, + inputs=[input_embeds, position_ids], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_torchtrt( + model: torch.nn.Module, args: argparse.Namespace +) -> torch.nn.Module: + """ + Front-end dispatcher mirroring *run_llm.py*’s `compile_torchtrt`. + + Depending on the target VLM, delegates to the appropriate compile routine. + """ + torch_dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + example_embeds = torch.randn( + 1, + 2560, + model.language_model.config.hidden_size, + dtype=torch_dtype, + device=DEVICE, + ) + + if args.model.lower() == "eagle2": + return _compile_eagle2_lm(model.language_model, example_embeds, args) + + msg = f"Unsupported model for compilation: {args.model}" + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Utility helpers +# -----------------------------------------------------------------------------# + + +def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): + """Pretty-print generated text for comparison.""" + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +# -----------------------------------------------------------------------------# +# Main driver +# -----------------------------------------------------------------------------# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run VLM inference (PyTorch & TensorRT back-ends)" + ) + parser.add_argument("--model", default="eagle2", help="VLM model name") + parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") + parser.add_argument( + "--precision", + default="FP16", + choices=["FP16", "BF16", "FP32"], + help="Computation precision", + ) + parser.add_argument("--iterations", type=int, default=5, help="# iterations") + parser.add_argument("--min_block_size", type=int, default=1, help="Min block size") + parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--isl", type=int, default=2048, help="Input seq length") + parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Run the PyTorch baseline as well", + ) + parser.add_argument( + "--cache", + default="", + choices=["", "static_v1"], + help="KV-cache variant to use", + ) + parser.add_argument( + "--debug", action="store_true", help="Enable Torch-TensorRT debug logs" + ) + parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmarking mode" + ) + + args = parser.parse_args() + + # -------------------------------------------------------------------------# + # 1. Model / processor / embeddings + # -------------------------------------------------------------------------# + dtype = { + "FP16": torch.float16, + "BF16": torch.bfloat16, + }.get(args.precision, torch.float32) + + model, processor, emb_layer = _load_model(args.model, DEVICE, dtype) + + # -------------------------------------------------------------------------# + # 2. Input construction (image + text prompt) + # -------------------------------------------------------------------------# + url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + if args.benchmark: + prompt_len = args.isl - 1792 - 26 + prompt_txt = " ".join(["token"] * max(prompt_len, 0)) + else: + prompt_txt = args.prompt + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt_txt}, + ], + } + ] + + txt = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + ] + img_in, vid_in = processor.process_vision_info(messages) + inputs = processor( + text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True + ).to(DEVICE) + + max_output_len = inputs["input_ids"].shape[1] + args.num_tokens + + # -------------------------------------------------------------------------# + # 3. Optional: PyTorch baseline + # -------------------------------------------------------------------------# + pyt_gen_tokens = pyt_timings = pyt_stats = None + if args.enable_pytorch_run: + pyt_gen_tokens = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + ) + if args.benchmark: + pyt_timings = time_generate_mm( + generate_mm, + model, + inputs["pixel_values"].clone(), + inputs["input_ids"].clone(), + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + iterations=args.iterations, + ) + pyt_stats = record_stats( + "PyTorch", + pyt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + # Register static cache lowering passes if requested + if args.cache == "static_v1": + import static_cache_v1 # noqa: F401 + + # -------------------------------------------------------------------------# + # 4. Torch-TensorRT compile & run + # -------------------------------------------------------------------------# + trt_lm = compile_torchtrt(model, args) + trt_model = copy.deepcopy(model) + trt_model.language_model = trt_lm + + emb_layer = emb_layer.to(DEVICE) + + if args.cache == "static_v1": + trt_generate = generate_mm_with_static_cache + else: + trt_generate = generate_mm + + trt_gen_tokens = trt_generate( + trt_model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + DEVICE if args.cache == "static_v1" else None, # device arg only for static_v1 + ) + + if args.benchmark: + trt_timings = time_generate_mm( + trt_generate, + trt_model, + inputs["pixel_values"].clone(), + inputs["input_ids"].clone(), + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + iterations=args.iterations, + device=DEVICE if args.cache == "static_v1" else None, + ) + trt_stats = record_stats( + "TensorRT", + trt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + # -------------------------------------------------------------------------# + # 5. Reporting + # -------------------------------------------------------------------------# + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + print_outputs("TensorRT", trt_gen_tokens, processor.tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: " + f"{torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("========= PyTorch PERFORMANCE =========\n") + print(pyt_stats) + print("=====================\n") + print("========= TensorRT PERFORMANCE =========\n") + print(trt_stats) diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..5e188f0e8b 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -242,3 +242,464 @@ def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None) "Compile Time(s)": compile_time_s, } return stats + + +def generate_mm( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + """Greedy decode for Eagle2-style VLM. + + Parameters + ---------- + model : nn.Module + Must expose vision_model, mlp1, language_model, pixel_shuffle, downsample_ratio, image_token_index. + pixel_values : Tensor | None + Input image batch (B,C,H,W) or None. + input_ids : LongTensor (B, N_prompt) + Text prompt token ids including [IMG] placeholder(s). + max_output_seq_length : int + Maximum tokens to generate **in addition to** the prompt. + eos_token_id : int + Stop generation when all sequences emit EOS. + emb_layer : nn.Embedding + Embedding layer for input_ids. + """ + + vit_embeds = None + + if pixel_values is not None: + # --- Vision encoder timing --- + vis_s = torch.cuda.Event(enable_timing=True) + vis_e = torch.cuda.Event(enable_timing=True) + vis_s.record() + vit_out = model.vision_model(pixel_values) + vis_e.record() + torch.cuda.synchronize() + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + + # 2) Text token embeddings + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + + mask = seq_tokens.view(B * N) == model.image_token_index + try: + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + except Exception: + # Fallback in unlikely size-mismatch cases + flat_emb[mask] = vit_embeds.reshape(-1, C)[: mask.sum()].to(flat_emb.dtype) + seq_embeds = flat_emb.view(B, N, C) + + # ───────────────────────────────── Greedy loop ─────────────────────────────────────────────────── + isl = seq_tokens.shape[1] + osl = max_output_seq_length - isl + + generated = 0 + + while generated < osl: + cur_embeds = seq_embeds # full seq first step or cache off + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model(inputs_embeds=cur_embeds, position_ids=position_ids) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) # (B,) + # append token & embed + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + return seq_tokens[:, input_ids.shape[1] :] + + +@torch.inference_mode() +def generate_mm_with_static_cache( + model, # Complete VLM module + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, # (B, N_prompt) + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: # (B, N_prompt + new) + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1). + Basic structure is identical to LM version (generate_with_static_cache) but + * Input is `inputs_embeds` + * Vision tokens are sent together only in the first step + """ + + # ───────────────────── Vision encoding ───────────────────── + vit_embeds = None + if pixel_values is not None: + vit_latent = model.vision_model(pixel_values) + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) + vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) + vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) + vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) + + # ───────────────────── Text embedding & [IMG] replacement ───────────── + seq_tokens = input_ids.clone() # (B, N_txt) + seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs( + model.language_model + ) + start_idx = 0 # First token index + end_idx = seq_embeds.size(1) # Prompt length + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + + # When using static cache: + # - First step: Use full prompt embedding + # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + # position_ids: Same pattern as generate_with_static_cache + # - First step: Position of entire sequence + # - Subsequent steps: Position of current token only + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + # is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + # is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + # Prepare for next step - Static cache only needs new token + next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) + seq_embeds = next_embed # Next step uses only new token + + generated += 1 + start_idx = end_idx + end_idx += 1 + # is_causal = True # Causal mask active from now on + + if (next_tok == eos_token_id).all(): + break + + return output_tokens + + +def generate_mm_with_timing( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + use_cache: bool = False, +): + # Create timing events + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + vit_embeds = None + if pixel_values is not None: + vision_start.record() + vit_out = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + vit_embeds = ( + vit_out.last_hidden_state + if hasattr(vit_out, "last_hidden_state") + else vit_out + ) + + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = model.pixel_shuffle( + vit_embeds, scale_factor=model.downsample_ratio + ) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = model.mlp1(vit_embeds) + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat_emb = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat_emb[mask] = vit_embeds.reshape(-1, C).to(flat_emb.dtype)[: mask.sum()] + seq_embeds = flat_emb.view(B, N, C) + + step_times = [] + generated = 0 + past_key_values = None + + while generated < max_output_seq_length: + lm_start.record() + cur_embeds = seq_embeds + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + with torch.no_grad(): + logits = model.language_model( + inputs_embeds=cur_embeds, position_ids=position_ids + ) + if hasattr(logits, "logits"): + logits = logits.logits + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=-1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + return seq_tokens, step_times, overall_time, vision_time, mlp_time + + +@torch.inference_mode() +def generate_mm_with_static_cache_timing( + model, # Complete VLM module + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, # (B, N_prompt) + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, + device: str = "cuda:0", +) -> tuple: # (seq_tokens, step_times, overall_time, vision_time, mlp_time) + """ + Greedy Decoder for multimodal VLM (using static KV-cache v1) + detailed timing measurement. + + Returns: + seq_tokens: Generated token sequence + step_times: Language model inference time for each step (ms) + overall_time: Total execution time (ms) + vision_time: Vision encoding time (ms) + mlp_time: MLP processing time (ms) + """ + + # ───────────────────── Create timing events ───────────────────── + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + mlp_start = torch.cuda.Event(enable_timing=True) + mlp_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + # ───────────────────── Vision encoding ───────────────────── + vit_embeds = None + vision_time = 0.0 + mlp_time = 0.0 + + if pixel_values is not None: + vision_start.record() + vit_latent = model.vision_model(pixel_values) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + + mlp_start.record() + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.view(vit_embeds.size(0), h, w, -1) + vit_embeds = model.pixel_shuffle(vit_embeds, model.downsample_ratio) + vit_embeds = vit_embeds.view(vit_embeds.size(0), -1, vit_embeds.size(-1)) + vit_embeds = model.mlp1(vit_embeds) # (B, N_img, C) + mlp_end.record() + torch.cuda.synchronize() + mlp_time = mlp_start.elapsed_time(mlp_end) + + # ───────────────────── Text embedding & [IMG] replacement ───────────── + seq_tokens = input_ids.clone() # (B, N_txt) + seq_embeds = emb_layer(seq_tokens) # (B, N_txt, C) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + # ───────────────────── KV-cache initialization ───────────────────── + kv_cache = get_zeroed_static_cache_inputs( + model.language_model + ) + start_idx = 0 # First token index + end_idx = seq_embeds.size(1) # Prompt length + generated = 0 + max_total_len = end_idx + max_new_tokens + output_tokens = seq_tokens.clone() + step_times = [] # Timing for each step + + # ───────────────────── Greedy loop ─────────────────────── + while output_tokens.size(1) < max_total_len: + lm_start.record() + + # When using static cache: + # - First step: Use full prompt embedding + # - Subsequent steps: Use only new token embedding (KV cache remembers previous tokens) + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + # position_ids: Same pattern as generate_with_static_cache + # - First step: Position of entire sequence + # - Subsequent steps: Position of current token only + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + + # is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + # is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) # (B,) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + # Prepare for next step - Static cache only needs new token + next_embed = emb_layer(next_tok)[:, None, :] # (B, 1, C) + seq_embeds = next_embed # Next step uses only new token + + generated += 1 + start_idx = end_idx + end_idx += 1 + + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + if (next_tok == eos_token_id).all(): + break + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + return output_tokens, step_times, overall_time, vision_time, mlp_time + + +def time_generate_mm( + generate_fn, + model, + pixel_values, + input_ids, + output_seq_length, + eos_token_id, + emb_layer, + iterations=10, + device="cuda:0", +): + """ + Measure the time for generating a sentence over certain number of iterations + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + _ = generate_fn( + model, pixel_values, input_ids, output_seq_length, eos_token_id, emb_layer + ) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings From 9980c4cc741852a51d13ea698281d14125a2c98b Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 17 Jul 2025 16:16:22 +0000 Subject: [PATCH 2/5] add vision_model compile --- tools/llm/run_vlm.py | 78 +++++++++++++++++++++++++++++++++++++++++--- tools/llm/utils.py | 17 +++++----- 2 files changed, 82 insertions(+), 13 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index f6bd62624f..bae102aef8 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -185,6 +185,64 @@ def compile_torchtrt( raise ValueError(msg) +def _compile_eagle2_vision( + vision_model: torch.nn.Module, + example_pixel_values: torch.Tensor, + args: argparse.Namespace, +) -> torch.nn.Module: + """ + Compile Eagle2 vision model with Torch-TensorRT. + """ + # Set precision-specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: # FP32 + enabled_precisions = {torch.float32} + + with torch.inference_mode(): + exported = torch.export.export( + vision_model, + (example_pixel_values,), + strict=False, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_mod = torch_tensorrt.dynamo.compile( + exported, + inputs=[example_pixel_values], + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + return trt_mod + + +def compile_vision_torchtrt( + model: torch.nn.Module, + args: argparse.Namespace, + example_pixel_values: torch.Tensor, +) -> torch.nn.Module: + """ + Dispatcher function for vision model compilation. + """ + if args.model.lower() == "eagle2": + return _compile_eagle2_vision(model.vision_model, example_pixel_values, args) + else: + raise ValueError(f"Unsupported model: {args.model}") + + # -----------------------------------------------------------------------------# # Utility helpers # -----------------------------------------------------------------------------# @@ -297,6 +355,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): processor.tokenizer.eos_token_id, emb_layer, ) + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) if args.benchmark: pyt_timings = time_generate_mm( generate_mm, @@ -316,15 +375,26 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): compile_time_s=None, ) + # -------------------------------------------------------------------------# + # 4. Torch-TensorRT compile & run + # -------------------------------------------------------------------------# + + trt_model = copy.deepcopy(model) + # 4.1. Vision model compilation + # --- Add vision model compilation --- # + example_pixel_values = inputs["pixel_values"] + trt_vision = compile_vision_torchtrt(model, args, example_pixel_values) + trt_model.vision_model = trt_vision + + # -------------------------------------------------------------------------# + # 4.2. Language model compilation + # -------------------------------------------------------------------------# # Register static cache lowering passes if requested + # Cache is not applied to vision model. if args.cache == "static_v1": import static_cache_v1 # noqa: F401 - # -------------------------------------------------------------------------# - # 4. Torch-TensorRT compile & run - # -------------------------------------------------------------------------# trt_lm = compile_torchtrt(model, args) - trt_model = copy.deepcopy(model) trt_model.language_model = trt_lm emb_layer = emb_layer.to(DEVICE) diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 5e188f0e8b..a5b8662f27 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -251,6 +251,7 @@ def generate_mm( max_output_seq_length: int, eos_token_id: int, emb_layer: torch.nn.Embedding, + device: str = "cuda:0", ): """Greedy decode for Eagle2-style VLM. @@ -320,10 +321,12 @@ def generate_mm( while generated < osl: cur_embeds = seq_embeds # full seq first step or cache off position_ids = ( - torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) - ) + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) with torch.no_grad(): - logits = model.language_model(inputs_embeds=cur_embeds, position_ids=position_ids) + logits = model.language_model( + inputs_embeds=cur_embeds, position_ids=position_ids + ) if hasattr(logits, "logits"): logits = logits.logits @@ -383,9 +386,7 @@ def generate_mm_with_static_cache( seq_embeds = flat.view(B, N, C) # ───────────────────── KV-cache initialization ───────────────────── - kv_cache = get_zeroed_static_cache_inputs( - model.language_model - ) + kv_cache = get_zeroed_static_cache_inputs(model.language_model) start_idx = 0 # First token index end_idx = seq_embeds.size(1) # Prompt length generated = 0 @@ -609,9 +610,7 @@ def generate_mm_with_static_cache_timing( seq_embeds = flat.view(B, N, C) # ───────────────────── KV-cache initialization ───────────────────── - kv_cache = get_zeroed_static_cache_inputs( - model.language_model - ) + kv_cache = get_zeroed_static_cache_inputs(model.language_model) start_idx = 0 # First token index end_idx = seq_embeds.size(1) # Prompt length generated = 0 From e5e63e58dc8138798d9a4b29db359c8b327b5b4c Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 21 Jul 2025 16:20:49 +0000 Subject: [PATCH 3/5] Improve clarity of naming and comments --- tools/llm/run_vlm.py | 183 ++++++++++++++++++++++++++++++------------- 1 file changed, 127 insertions(+), 56 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index bae102aef8..0bafd9ecc9 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -1,11 +1,26 @@ """ .. _run_vlm: -Running VLM inference with Torch-TensorRT +Benchmarking VLM Inference with Torch-TensorRT ========================================================== -This script mirrors the style and structure of *run_llm.py*, illustrating a -Torch-TensorRT (dynamo backend) workflow for Visual-Language Models (VLMs). +This script provides a framework for benchmarking the performance of Visual-Language +Models (VLMs). It optimizes the two most computationally intensive components of a +VLM—the language model and the vision model (image feature extraction)—using +the Torch-TensorRT dynamo backend. + +Key Features: +- **Component-wise Optimization**: Compiles both the language and vision models + separately with Torch-TensorRT to accelerate inference. +- **Performance Benchmarking**: Runs the model for multiple iterations to + measure and compare inference latency against the PyTorch baseline. +- **Output Verification**: Checks for token-level consistency between the optimized + TensorRT model and the original PyTorch model to ensure correctness. +- **KV Cache Testing**: Includes options to test inference with and without + KV caching to evaluate its impact on performance. + +This tool mirrors the style and structure of `run_llm.py`, providing a clear +workflow for VLM optimization and analysis. """ import argparse @@ -20,7 +35,7 @@ import torch_tensorrt from PIL import Image from torchtrt_ext import register_sdpa -from transformers import AutoModel, AutoProcessor +from transformers import AutoConfig, AutoModel, AutoProcessor from utils import ( generate_mm, generate_mm_with_static_cache, @@ -33,11 +48,30 @@ # -----------------------------------------------------------------------------# DEVICE = torch.device("cuda:0") -# Register SDPA as a standalone operator. Converter & lowering pass are defined -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -import transformers.models.qwen2.modeling_qwen2 as mq # noqa: E402 +# --- WORKAROUND FOR EAGLE2 SDPA COMPILATION --- +# Eagle2's language model (Qwen2) implicitly defaults to "flash_attention_2" +# due to settings in its remote code and config.json. This prevents direct +# compilation with SDPA. To work around this without modifying the library, + +# we "monkey-patch" the global attention function map for Qwen2. +# This ensures that any part of the code (including torch.export) requesting +# "flash_attention_2" will receive the "sdpa" implementation instead. +# This patch is global for the script's execution context. +import transformers.models.qwen2.modeling_qwen2 as mq mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] +# --- END WORKAROUND --- + +# --- Model-specific constants for benchmark and compilation --- +# Centralizing these values improves readability and maintainability. +MODEL_CONSTANTS = { + "nvidia/Eagle2-2B": { + "EXAMPLE_SEQLEN": 2560, # A fixed sequence length for creating the example tensor for TRT compilation. + "IMAGE_TOKENS": 1792, # Number of special tokens used to represent the image patch embeddings in the input sequence for Eagle2-2B VLM. + "PROMPT_WRAPPER_TOKENS": 26, # The number of special/processing tokens added by the processor's chat template in benchmark mode. + } +} +# --- END Model-specific constants --- # -----------------------------------------------------------------------------# # Model loading helpers @@ -46,7 +80,7 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): """ - Load Eagle2 model and processor. + Load nvidia/Eagle2-2B model and processor, ensuring the language model uses SDPA. Returns ------- @@ -57,7 +91,10 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): with torch.no_grad(): model = ( AutoModel.from_pretrained( - model_id, trust_remote_code=True, torch_dtype=torch_dtype + model_id, + trust_remote_code=True, + torch_dtype=torch_dtype, + # attn_implementation="sdpa" is ignored due to the model's remote code. ) .eval() .to(device) @@ -73,13 +110,68 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): return model, processor, emb_layer -def _load_model( +def load_model( model_name: str, device: torch.device, torch_dtype: torch.dtype ) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: """Dispatch helper for supported VLMs.""" - if model_name.lower() == "eagle2": + if model_name == "nvidia/Eagle2-2B": return _load_eagle2(device, torch_dtype) - msg = f"Unsupported model: {model_name}" + msg = ( + f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B']" + ) + raise ValueError(msg) + + +# -----------------------------------------------------------------------------# +# Input loading helpers +# -----------------------------------------------------------------------------# + + +def _load_inputs_eagle2(args: argparse.Namespace, processor, device: torch.device): + """ + Loads the input dictionary for the Eagle2 model. + """ + url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + if args.benchmark: + model_constants = MODEL_CONSTANTS[args.model] + image_tokens = model_constants["IMAGE_TOKENS"] + wrapper_tokens = model_constants["PROMPT_WRAPPER_TOKENS"] + + prompt_len = args.isl - image_tokens - wrapper_tokens + prompt_txt = " ".join(["token"] * max(prompt_len, 0)) + else: + prompt_txt = args.prompt + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt_txt}, + ], + } + ] + + txt = [ + processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + ] + img_in, vid_in = processor.process_vision_info(messages) + inputs = processor( + text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True + ).to(device) + return inputs + + +def load_inputs(args: argparse.Namespace, processor, device: torch.device): + """Dispatch helper for input loading for supported VLMs.""" + if args.model == "nvidia/Eagle2-2B": + return _load_inputs_eagle2(args, processor, device) + + msg = f"Unsupported model for input loading: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B']" raise ValueError(msg) @@ -116,9 +208,9 @@ def _compile_eagle2_lm( lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() max_seq_len = input_embeds.shape[1] + args.num_tokens - S = torch.export.Dim("seq", min=1, max=max_seq_len) + seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) - dyn_shapes = {"inputs_embeds": {1: S}, "position_ids": {1: S}} + dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} # Precision-specific flags --------------------------------------------------# use_fp32_acc = False @@ -133,7 +225,7 @@ def _compile_eagle2_lm( enabled_precisions = {torch.float32} with torch.inference_mode(): - exported = torch.export.export( + exported_program = torch.export.export( lm_wrap, (input_embeds, position_ids), dynamic_shapes=dyn_shapes, @@ -142,7 +234,7 @@ def _compile_eagle2_lm( with torch_tensorrt.logging.debug() if args.debug else nullcontext(): trt_mod = torch_tensorrt.dynamo.compile( - exported, + exported_program, inputs=[input_embeds, position_ids], enabled_precisions=enabled_precisions, use_explicit_typing=use_explicit_typing, @@ -157,31 +249,37 @@ def _compile_eagle2_lm( return trt_mod -def compile_torchtrt( +def compile_lm_torchtrt( model: torch.nn.Module, args: argparse.Namespace ) -> torch.nn.Module: """ - Front-end dispatcher mirroring *run_llm.py*’s `compile_torchtrt`. + Compiles the Language Model (LLM) component of the VLM using Torch-TensorRT. - Depending on the target VLM, delegates to the appropriate compile routine. + This function acts as a dispatcher, delegating to the appropriate routine + (e.g., `_compile_eagle2_lm`) based on the target model. """ torch_dtype = { "FP16": torch.float16, "BF16": torch.bfloat16, }.get(args.precision, torch.float32) + model_constants = MODEL_CONSTANTS[args.model] + example_seq_len = model_constants["EXAMPLE_SEQLEN"] + example_embeds = torch.randn( 1, - 2560, + example_seq_len, model.language_model.config.hidden_size, dtype=torch_dtype, device=DEVICE, ) - if args.model.lower() == "eagle2": + if args.model == "nvidia/Eagle2-2B": return _compile_eagle2_lm(model.language_model, example_embeds, args) - msg = f"Unsupported model for compilation: {args.model}" + msg = ( + f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B']" + ) raise ValueError(msg) @@ -206,7 +304,7 @@ def _compile_eagle2_vision( enabled_precisions = {torch.float32} with torch.inference_mode(): - exported = torch.export.export( + exported_program = torch.export.export( vision_model, (example_pixel_values,), strict=False, @@ -214,7 +312,7 @@ def _compile_eagle2_vision( with torch_tensorrt.logging.debug() if args.debug else nullcontext(): trt_mod = torch_tensorrt.dynamo.compile( - exported, + exported_program, inputs=[example_pixel_values], enabled_precisions=enabled_precisions, use_explicit_typing=use_explicit_typing, @@ -237,7 +335,7 @@ def compile_vision_torchtrt( """ Dispatcher function for vision model compilation. """ - if args.model.lower() == "eagle2": + if args.model == "nvidia/Eagle2-2B": return _compile_eagle2_vision(model.vision_model, example_pixel_values, args) else: raise ValueError(f"Unsupported model: {args.model}") @@ -265,7 +363,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): parser = argparse.ArgumentParser( description="Run VLM inference (PyTorch & TensorRT back-ends)" ) - parser.add_argument("--model", default="eagle2", help="VLM model name") + parser.add_argument("--model", default="nvidia/Eagle2-2B", help="VLM model name") parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") parser.add_argument( "--precision", @@ -306,39 +404,12 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): "BF16": torch.bfloat16, }.get(args.precision, torch.float32) - model, processor, emb_layer = _load_model(args.model, DEVICE, dtype) + model, processor, emb_layer = load_model(args.model, DEVICE, dtype) # -------------------------------------------------------------------------# # 2. Input construction (image + text prompt) # -------------------------------------------------------------------------# - url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" - image = Image.open(requests.get(url, stream=True).raw) - - if args.benchmark: - prompt_len = args.isl - 1792 - 26 - prompt_txt = " ".join(["token"] * max(prompt_len, 0)) - else: - prompt_txt = args.prompt - - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image": image}, - {"type": "text", "text": prompt_txt}, - ], - } - ] - - txt = [ - processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - ] - img_in, vid_in = processor.process_vision_info(messages) - inputs = processor( - text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True - ).to(DEVICE) + inputs = load_inputs(args, processor, DEVICE) max_output_len = inputs["input_ids"].shape[1] + args.num_tokens @@ -394,7 +465,7 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): if args.cache == "static_v1": import static_cache_v1 # noqa: F401 - trt_lm = compile_torchtrt(model, args) + trt_lm = compile_lm_torchtrt(model, args) trt_model.language_model = trt_lm emb_layer = emb_layer.to(DEVICE) From 5d98dc47581c01c42fb1418f4d632f37ac1a05e8 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 24 Jul 2025 13:12:32 +0000 Subject: [PATCH 4/5] support qwen2.5_vl with cache --- tools/llm/run_vlm.py | 257 +++++++++++++------ tools/llm/static_cache_v1.py | 16 +- tools/llm/test_qwen2.5_components.py | 2 +- tools/llm/utils.py | 354 ++++++++++++++++++++++++++- 4 files changed, 537 insertions(+), 92 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index 0bafd9ecc9..02ffec7875 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -38,6 +38,8 @@ from transformers import AutoConfig, AutoModel, AutoProcessor from utils import ( generate_mm, + generate_mm_qwen2_5_vl, + generate_mm_qwen2_5_vl_with_static_cache, generate_mm_with_static_cache, record_stats, time_generate_mm, @@ -69,7 +71,12 @@ "EXAMPLE_SEQLEN": 2560, # A fixed sequence length for creating the example tensor for TRT compilation. "IMAGE_TOKENS": 1792, # Number of special tokens used to represent the image patch embeddings in the input sequence for Eagle2-2B VLM. "PROMPT_WRAPPER_TOKENS": 26, # The number of special/processing tokens added by the processor's chat template in benchmark mode. - } + }, + "Qwen/Qwen2.5-VL-3B-Instruct": { + "EXAMPLE_SEQLEN": 2560, + "IMAGE_TOKENS": 391, + "PROMPT_WRAPPER_TOKENS": 21, + }, } # --- END Model-specific constants --- @@ -110,15 +117,30 @@ def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): return model, processor, emb_layer +def _load_qwen2_5_vl(device, torch_dtype: torch.dtype): + """ + Load Qwen2.5-VL model and processor. + """ + from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + model_id = "Qwen/Qwen2.5-VL-3B-Instruct" + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch_dtype, device_map=device + ).eval() + processor = AutoProcessor.from_pretrained(model_id) + emb_layer = model.model.get_input_embeddings().to(torch_dtype).to(device) + return model, processor, emb_layer + + def load_model( model_name: str, device: torch.device, torch_dtype: torch.dtype ) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: """Dispatch helper for supported VLMs.""" if model_name == "nvidia/Eagle2-2B": return _load_eagle2(device, torch_dtype) - msg = ( - f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B']" - ) + elif model_name == "Qwen/Qwen2.5-VL-3B-Instruct": + return _load_qwen2_5_vl(device, torch_dtype) + msg = f"Unsupported model: '{model_name}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" raise ValueError(msg) @@ -127,11 +149,11 @@ def load_model( # -----------------------------------------------------------------------------# -def _load_inputs_eagle2(args: argparse.Namespace, processor, device: torch.device): +def load_inputs(args: argparse.Namespace, processor, device: torch.device): """ - Loads the input dictionary for the Eagle2 model. + Loads and constructs the input dictionary for the specified VLM model. """ - url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg" + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" image = Image.open(requests.get(url, stream=True).raw) if args.benchmark: @@ -142,7 +164,7 @@ def _load_inputs_eagle2(args: argparse.Namespace, processor, device: torch.devic prompt_len = args.isl - image_tokens - wrapper_tokens prompt_txt = " ".join(["token"] * max(prompt_len, 0)) else: - prompt_txt = args.prompt + prompt_txt = args.prompt or "Describe this image." messages = [ { @@ -154,25 +176,28 @@ def _load_inputs_eagle2(args: argparse.Namespace, processor, device: torch.devic } ] - txt = [ + text = [ processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) ] - img_in, vid_in = processor.process_vision_info(messages) - inputs = processor( - text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True - ).to(device) - return inputs + # --- Model-specific vision processing --- + if "qwen" in args.model.lower(): + from qwen_vl_utils import process_vision_info -def load_inputs(args: argparse.Namespace, processor, device: torch.device): - """Dispatch helper for input loading for supported VLMs.""" - if args.model == "nvidia/Eagle2-2B": - return _load_inputs_eagle2(args, processor, device) + image_inputs, video_inputs = process_vision_info(messages) + else: # eagle2 + image_inputs, video_inputs = processor.process_vision_info(messages) - msg = f"Unsupported model for input loading: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B']" - raise ValueError(msg) + inputs = processor( + text=text, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ).to(device) + return inputs # -----------------------------------------------------------------------------# @@ -191,28 +216,39 @@ def __init__(self, lm): def forward(self, inputs_embeds, position_ids): out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids) - return out.logits if hasattr(out, "logits") else out + return ( + out.logits + if hasattr(out, "logits") + else out.last_hidden_state if hasattr(out, "last_hidden_state") else out + ) -def _compile_eagle2_lm( +def _compile_lm( language_model: torch.nn.Module, input_embeds: torch.Tensor, args: argparse.Namespace, ) -> torch.nn.Module: """ - Compile Eagle2 language model with Torch-TensorRT. - - The function follows the same precision-specific flag logic used in - *run_llm.py* for consistency. + Compile the language model component of a VLM with Torch-TensorRT """ lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() max_seq_len = input_embeds.shape[1] + args.num_tokens - seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) - position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + # --- Model-specific dynamic shape definition --- + if "qwen" in args.model.lower(): + _seq = torch.export.Dim("_seq", min=1, max=512) + seq_len = 8 * _seq + position_ids = ( + torch.arange(input_embeds.shape[1], device=DEVICE, dtype=torch.long) + .unsqueeze(0) + .expand(input_embeds.size(0), input_embeds.size(1)) + ) + else: # eagle2 + seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} - # Precision-specific flags --------------------------------------------------# use_fp32_acc = False use_explicit_typing = False if args.precision == "FP16": @@ -254,33 +290,33 @@ def compile_lm_torchtrt( ) -> torch.nn.Module: """ Compiles the Language Model (LLM) component of the VLM using Torch-TensorRT. - - This function acts as a dispatcher, delegating to the appropriate routine - (e.g., `_compile_eagle2_lm`) based on the target model. """ torch_dtype = { "FP16": torch.float16, "BF16": torch.bfloat16, }.get(args.precision, torch.float32) - model_constants = MODEL_CONSTANTS[args.model] + lm_model = model.model if "qwen" in args.model.lower() else model.language_model + + model_constants = MODEL_CONSTANTS.get( + args.model, {"EXAMPLE_SEQLEN": args.num_tokens} + ) example_seq_len = model_constants["EXAMPLE_SEQLEN"] example_embeds = torch.randn( - 1, + args.batch_size, example_seq_len, - model.language_model.config.hidden_size, + lm_model.config.hidden_size, dtype=torch_dtype, device=DEVICE, ) - if args.model == "nvidia/Eagle2-2B": - return _compile_eagle2_lm(model.language_model, example_embeds, args) - - msg = ( - f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B']" - ) - raise ValueError(msg) + # All supported models use the same compilation helper. + if args.model in ["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"]: + return _compile_lm(lm_model, example_embeds, args) + else: + msg = f"Unsupported model: '{args.model}'. Supported models are: ['nvidia/Eagle2-2B', 'Qwen/Qwen2.5-VL-3B-Instruct']" + raise ValueError(msg) def _compile_eagle2_vision( @@ -337,6 +373,13 @@ def compile_vision_torchtrt( """ if args.model == "nvidia/Eagle2-2B": return _compile_eagle2_vision(model.vision_model, example_pixel_values, args) + elif args.model == "Qwen/Qwen2.5-VL-3B-Instruct": + # TODO: Vision model compilation for Qwen2.5-VL is currently skipped. + # The model's `get_window_index` method uses dynamic Python list operations + # (e.g., .tolist(), .extend()) to process variable-sized image grids for + # windowed attention. These operations are incompatible with torch.export's + # static graph tracing, preventing successful compilation. + return model.visual else: raise ValueError(f"Unsupported model: {args.model}") @@ -363,7 +406,12 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): parser = argparse.ArgumentParser( description="Run VLM inference (PyTorch & TensorRT back-ends)" ) - parser.add_argument("--model", default="nvidia/Eagle2-2B", help="VLM model name") + parser.add_argument( + "--model", + default="Qwen/Qwen2.5-VL-3B-Instruct", + choices=["nvidia/Eagle2-2B", "Qwen/Qwen2.5-VL-3B-Instruct"], + help="VLM model name", + ) parser.add_argument("--prompt", default="Describe this image.", help="Prompt text") parser.add_argument( "--precision", @@ -418,25 +466,46 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): # -------------------------------------------------------------------------# pyt_gen_tokens = pyt_timings = pyt_stats = None if args.enable_pytorch_run: - pyt_gen_tokens = generate_mm( - model, - inputs["pixel_values"], - inputs["input_ids"], - max_output_len, - processor.tokenizer.eos_token_id, - emb_layer, - ) - print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) - if args.benchmark: - pyt_timings = time_generate_mm( - generate_mm, + if "qwen" in args.model.lower(): + pyt_gen_tokens = generate_mm_qwen2_5_vl( model, - inputs["pixel_values"].clone(), - inputs["input_ids"].clone(), + inputs["pixel_values"], + inputs["input_ids"], + inputs["image_grid_thw"], max_output_len, processor.tokenizer.eos_token_id, emb_layer, - iterations=args.iterations, + ) + else: # eagle2 + pyt_gen_tokens = generate_mm( + model, + inputs["pixel_values"], + inputs["input_ids"], + max_output_len, + processor.tokenizer.eos_token_id, + emb_layer, + ) + print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer) + if args.benchmark: + # Prepare args for the timing function + time_generate_args = { + "model": model, + "pixel_values": inputs["pixel_values"].clone(), + "input_ids": inputs["input_ids"].clone(), + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + + # Select the correct generation function and add model-specific args + if "qwen" in args.model.lower(): + generate_fn_for_timing = generate_mm_qwen2_5_vl + time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] + else: # eagle2 + generate_fn_for_timing = generate_mm + + pyt_timings = time_generate_mm( + generate_fn_for_timing, iterations=args.iterations, **time_generate_args ) pyt_stats = record_stats( "PyTorch", @@ -455,7 +524,10 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): # --- Add vision model compilation --- # example_pixel_values = inputs["pixel_values"] trt_vision = compile_vision_torchtrt(model, args, example_pixel_values) - trt_model.vision_model = trt_vision + if "qwen" in args.model.lower(): + trt_model.visual = trt_vision + else: + trt_model.vision_model = trt_vision # -------------------------------------------------------------------------# # 4.2. Language model compilation @@ -466,36 +538,63 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer): import static_cache_v1 # noqa: F401 trt_lm = compile_lm_torchtrt(model, args) - trt_model.language_model = trt_lm + if "qwen" in args.model.lower(): + trt_model.model = trt_lm + else: + trt_model.language_model = trt_lm emb_layer = emb_layer.to(DEVICE) + if "qwen" in args.model.lower(): + trt_model.lm_head = trt_model.lm_head.to(DEVICE) if args.cache == "static_v1": - trt_generate = generate_mm_with_static_cache + if "qwen" in args.model.lower(): + trt_generate = generate_mm_qwen2_5_vl_with_static_cache + else: # eagle2 + trt_generate = generate_mm_with_static_cache else: - trt_generate = generate_mm - - trt_gen_tokens = trt_generate( - trt_model, - inputs["pixel_values"], - inputs["input_ids"], - max_output_len, - processor.tokenizer.eos_token_id, - emb_layer, - DEVICE if args.cache == "static_v1" else None, # device arg only for static_v1 - ) + if "qwen" in args.model.lower(): + trt_generate = generate_mm_qwen2_5_vl + else: # eagle2 + trt_generate = generate_mm + + # Prepare args for generate function + generate_args = { + "model": trt_model, + "pixel_values": inputs["pixel_values"], + "input_ids": inputs["input_ids"], + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + if "qwen" in args.model.lower(): + generate_args["image_grid_thw"] = inputs["image_grid_thw"] + if args.cache == "static_v1": + generate_args["device"] = DEVICE + + trt_gen_tokens = trt_generate(**generate_args) if args.benchmark: + # Prepare args for the timing function + time_generate_args = { + "model": trt_model, + "pixel_values": inputs["pixel_values"].clone(), + "input_ids": inputs["input_ids"].clone(), + "max_output_seq_length": max_output_len, + "eos_token_id": processor.tokenizer.eos_token_id, + "emb_layer": emb_layer, + } + + # Add model-specific args + if "qwen" in args.model.lower(): + time_generate_args["image_grid_thw"] = inputs["image_grid_thw"] + if args.cache == "static_v1": + time_generate_args["device"] = DEVICE + trt_timings = time_generate_mm( trt_generate, - trt_model, - inputs["pixel_values"].clone(), - inputs["input_ids"].clone(), - max_output_len, - processor.tokenizer.eos_token_id, - emb_layer, iterations=args.iterations, - device=DEVICE if args.cache == "static_v1" else None, + **time_generate_args, ) trt_stats = record_stats( "TensorRT", diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py index b60396c08b..161d02fe14 100644 --- a/tools/llm/static_cache_v1.py +++ b/tools/llm/static_cache_v1.py @@ -202,11 +202,25 @@ def insert_kv_slicing_before_sdpa( kwargs={}, ) # =============================================== # + # This prevents the cache tensor from growing when padded inputs are used. + update_window_size = gm.graph.create_node( + "call_function", + torch.ops.aten.sub.Tensor, + args=(end_idx_input, start_idx_input), + kwargs={}, + ) + sliced_new_kv = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(current_key_or_value_node, 2, 0, update_window_size), + kwargs={}, + ) + # Concatenate the sliced tensors to build KV cache cat = gm.graph.create_node( "call_function", torch.ops.aten.cat.default, - args=([slice_4, current_key_or_value_node, slice_8], 2), + args=([slice_4, sliced_new_kv, slice_8], 2), kwargs={}, ) # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py index 60482bf22d..1c1366bd0c 100644 --- a/tools/llm/test_qwen2.5_components.py +++ b/tools/llm/test_qwen2.5_components.py @@ -16,7 +16,7 @@ # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from register_sdpa import * +from torchtrt_ext import register_sdpa ATOL = 1e-5 RTOL = 1e-5 diff --git a/tools/llm/utils.py b/tools/llm/utils.py index a5b8662f27..ee8c852d6f 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -679,26 +679,358 @@ def generate_mm_with_static_cache_timing( def time_generate_mm( generate_fn, - model, - pixel_values, - input_ids, - output_seq_length, - eos_token_id, - emb_layer, iterations=10, - device="cuda:0", + **kwargs, ): """ - Measure the time for generating a sentence over certain number of iterations + Measure the time for generating a sentence over certain number of iterations. + Accepts generation function arguments via kwargs. """ timings = [] for _ in range(iterations): start_time = timeit.default_timer() - _ = generate_fn( - model, pixel_values, input_ids, output_seq_length, eos_token_id, emb_layer - ) + _ = generate_fn(**kwargs) torch.cuda.synchronize() end_time = timeit.default_timer() timings.append(end_time - start_time) return timings + + +def generate_mm_qwen2_5_vl( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model. + Performs greedy decoding without caching, using inputs_embeds instead of input_ids. + """ + # 1. Calculate image embeddings (if pixel_values are provided) + image_embeds = None + if pixel_values is not None: + image_embeds = model.visual(pixel_values, image_grid_thw) + + # 2. Create initial sequence embeddings + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + # 3. Insert image embeddings at image token positions + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + # 5. Greedy generation loop + generated = 0 + while generated < max_output_seq_length: + # 5.1. Calculate position_ids + position_ids = ( + torch.arange( + 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device + ) + .unsqueeze(0) + .expand(seq_embeds.size(0), seq_embeds.size(1)) + ) + + # 5.2. Call the language model + with torch.no_grad(): + outputs = model.model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + ) + hidden_states = ( + outputs + if isinstance(outputs, torch.Tensor) + else outputs.last_hidden_state + ) + + # 5.3. Calculate logits for the last token + logits = model.lm_head(hidden_states[:, -1, :]) + + # 5.4. Select the next token (greedy decoding) + next_tok = torch.argmax(logits, dim=-1) + + # 5.5. Append token and embedding to the sequence + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + next_emb = emb_layer(next_tok)[:, None, :] + seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + # 6. Return generated tokens (only the part after the prompt) + return seq_tokens[:, input_ids.shape[1] :] + + +def generate_mm_qwen2_5_vl_with_static_cache( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: + """ + Greedy Decoder for Qwen-2.5-VL using static KV-cache. + Identical to `generate_mm_with_static_cache` but adapted for Qwen-2.5-VL's + specific architecture (e.g., separate visual encoder call, lm_head). + """ + # 1. Vision encoding + image_embeds = None + if pixel_values is not None: + image_embeds = model.visual(pixel_values, image_grid_thw) + + # 2. Text embedding & image token replacement + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match " + f"number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + # 3. KV-cache initialization + kv_cache = get_zeroed_static_cache_inputs(model.model) + start_idx = 0 + end_idx = seq_embeds.size(1) + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + # 4. Greedy loop + while output_tokens.size(1) < max_total_len: + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + # For the prefill step, the relevant logit is the very last one. + logit_pos = -1 + else: + # --- RUNTIME PADDING FIX for KV Cache Decode --- + # The compiled TensorRT engine has a minimum sequence length requirement (e.g., 16), + # as determined by its optimization profile. The decode step uses a sequence length + # of 1, which violates this profile. + # To resolve this, we manually pad the input tensors to the minimum length (16) + # at runtime before feeding them to the engine. + pad_len = 15 # Pad from 1 to 16 (1 + 15) + + # Pad cur_embeds tensor + padding_tensor_embeds = torch.zeros( + cur_embeds.size(0), + pad_len, + cur_embeds.size(2), + dtype=cur_embeds.dtype, + device=cur_embeds.device, + ) + cur_embeds = torch.cat([cur_embeds, padding_tensor_embeds], dim=1) + + # Pad position_ids tensor + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + padding_tensor_ids = torch.zeros( + position_ids.size(0), + pad_len, + dtype=position_ids.dtype, + device=position_ids.device, + ) + position_ids = torch.cat([position_ids, padding_tensor_ids], dim=1) + + # Since we padded the sequence, the logit for our actual token is now at position 0. + logit_pos = 0 + + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + ) + + outputs_and_kv = model.model(*input_signature) + # With the fix in static_cache_v1.py, the model output is now clean: + # (hidden_state, updated_kv_cache[72]) + hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:] + + # Use logit_pos to get the correct logit based on whether we padded or not. + logits = model.lm_head(hidden_states[:, logit_pos, :]) + + next_tok = logits.argmax(dim=-1) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed + + generated += 1 + start_idx = end_idx + end_idx += 1 + + if (next_tok == eos_token_id).all(): + break + + return output_tokens + + +def generate_mm_paligemma( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, +): + vit_embeds = None + if pixel_values is not None: + vit_out = model.vision_tower(pixel_values) + vit_embeds = model.multi_modal_projector(vit_out.last_hidden_state) + vit_embeds = vit_embeds / (model.config.text_config.hidden_size**0.5) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.config.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + B = seq_tokens.size(0) + cache_position = torch.arange(seq_tokens.size(1), device=seq_tokens.device) + position_ids = cache_position.unsqueeze(0) + 1 + + generated = 0 + while generated < max_output_seq_length: + causal_mask = model.model._update_causal_mask( + attention_mask=None, + token_type_ids=None, + past_key_values=None, + cache_position=cache_position, + input_tensor=seq_embeds, + is_training=False, + ) + + with torch.no_grad(): + out = model.language_model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + attention_mask=causal_mask, + use_cache=False, + ) + logits = out.last_hidden_state if hasattr(out, "last_hidden_state") else out + + next_tok = torch.argmax(logits[:, -1, :], dim=-1) + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) + + position_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=1) + cache_position = torch.arange(seq_tokens.size(1), device=seq_tokens.device) + + generated += 1 + if (next_tok == eos_token_id).all(): + break + + return seq_tokens + + +@torch.inference_mode() +def generate_mm_paligemma_with_static_cache( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + max_output_seq_length: int, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + device: str = "cuda:0", +) -> torch.LongTensor: + vit_embeds = None + if pixel_values is not None: + vit_latent = model.vision_tower(pixel_values) + vit_embeds = ( + vit_latent.last_hidden_state + if hasattr(vit_latent, "last_hidden_state") + else vit_latent + ) + vit_embeds = model.multi_modal_projector(vit_embeds) + vit_embeds = vit_embeds / (model.config.text_config.hidden_size**0.5) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if vit_embeds is not None: + B, N, C = seq_embeds.shape + flat = seq_embeds.view(B * N, C) + mask = seq_tokens.view(B * N) == model.image_token_index + flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()] + seq_embeds = flat.view(B, N, C) + + kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device) + start_idx = 0 + end_idx = seq_embeds.size(1) + generated = 0 + max_total_len = max_output_seq_length + output_tokens = seq_tokens.clone() + + while output_tokens.size(1) < max_total_len: + cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] + if generated == 0: + position_ids = ( + torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) + ) + else: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( + cur_embeds.device + ) + is_causal = True if cur_embeds.shape[1] > 1 else False + input_signature = ( + cur_embeds, + position_ids, + *kv_cache, + start_idx, + end_idx, + is_causal, + ) + + logits_and_kv = model.language_model(*input_signature) + logits, kv_cache = logits_and_kv[0], logits_and_kv[1:] + + next_tok = logits[:, -1, :].argmax(dim=-1) + output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) + + next_embed = emb_layer(next_tok)[:, None, :] + seq_embeds = next_embed + + generated += 1 + start_idx = end_idx + end_idx += 1 + is_causal = True + + if (next_tok == eos_token_id).all(): + break + + return output_tokens From cfe1b237291001653669857782fa06a3497d7bef Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 28 Jul 2025 15:37:42 +0000 Subject: [PATCH 5/5] fix: align ISL/OSL with arguments and remove padding in language model --- tools/llm/run_vlm.py | 30 +++----- tools/llm/static_cache_v1.py | 16 +--- tools/llm/utils.py | 140 ++++++++++++++++++++++++----------- 3 files changed, 107 insertions(+), 79 deletions(-) diff --git a/tools/llm/run_vlm.py b/tools/llm/run_vlm.py index 02ffec7875..be91ac6efe 100644 --- a/tools/llm/run_vlm.py +++ b/tools/llm/run_vlm.py @@ -37,6 +37,7 @@ from torchtrt_ext import register_sdpa from transformers import AutoConfig, AutoModel, AutoProcessor from utils import ( + export_llm, generate_mm, generate_mm_qwen2_5_vl, generate_mm_qwen2_5_vl_with_static_cache, @@ -74,7 +75,7 @@ }, "Qwen/Qwen2.5-VL-3B-Instruct": { "EXAMPLE_SEQLEN": 2560, - "IMAGE_TOKENS": 391, + "IMAGE_TOKENS": 1426, "PROMPT_WRAPPER_TOKENS": 21, }, } @@ -153,7 +154,7 @@ def load_inputs(args: argparse.Namespace, processor, device: torch.device): """ Loads and constructs the input dictionary for the specified VLM model. """ - url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" + url = "https://www.ilankelman.org/stopsigns/australia.jpg" image = Image.open(requests.get(url, stream=True).raw) if args.benchmark: @@ -197,6 +198,7 @@ def load_inputs(args: argparse.Namespace, processor, device: torch.device): padding=True, return_tensors="pt", ).to(device) + return inputs @@ -234,18 +236,8 @@ def _compile_lm( lm_wrap = _LMNoCache(language_model).to(DEVICE).eval() max_seq_len = input_embeds.shape[1] + args.num_tokens - # --- Model-specific dynamic shape definition --- - if "qwen" in args.model.lower(): - _seq = torch.export.Dim("_seq", min=1, max=512) - seq_len = 8 * _seq - position_ids = ( - torch.arange(input_embeds.shape[1], device=DEVICE, dtype=torch.long) - .unsqueeze(0) - .expand(input_embeds.size(0), input_embeds.size(1)) - ) - else: # eagle2 - seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) - position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) + seq_len = torch.export.Dim("seq", min=1, max=max_seq_len) + position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE) dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}} @@ -260,13 +252,9 @@ def _compile_lm( else: # FP32 enabled_precisions = {torch.float32} - with torch.inference_mode(): - exported_program = torch.export.export( - lm_wrap, - (input_embeds, position_ids), - dynamic_shapes=dyn_shapes, - strict=False, - ) + exported_program = export_llm( + lm_wrap, input_embeds, min_seq_len=1, max_seq_len=2560 + ) with torch_tensorrt.logging.debug() if args.debug else nullcontext(): trt_mod = torch_tensorrt.dynamo.compile( diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py index 161d02fe14..58daacedf5 100644 --- a/tools/llm/static_cache_v1.py +++ b/tools/llm/static_cache_v1.py @@ -201,26 +201,12 @@ def insert_kv_slicing_before_sdpa( args=(slice_7, 3), kwargs={}, ) - # =============================================== # - # This prevents the cache tensor from growing when padded inputs are used. - update_window_size = gm.graph.create_node( - "call_function", - torch.ops.aten.sub.Tensor, - args=(end_idx_input, start_idx_input), - kwargs={}, - ) - sliced_new_kv = gm.graph.create_node( - "call_function", - torch.ops.aten.slice.Tensor, - args=(current_key_or_value_node, 2, 0, update_window_size), - kwargs={}, - ) # Concatenate the sliced tensors to build KV cache cat = gm.graph.create_node( "call_function", torch.ops.aten.cat.default, - args=([slice_4, sliced_new_kv, slice_8], 2), + args=([slice_4, current_key_or_value_node, slice_8], 2), kwargs={}, ) # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph diff --git a/tools/llm/utils.py b/tools/llm/utils.py index ee8c852d6f..09fa662299 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -448,9 +448,9 @@ def generate_mm_with_timing( model, pixel_values: torch.Tensor | None, input_ids: torch.Tensor, - max_output_seq_length: int, eos_token_id: int, emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, use_cache: bool = False, ): # Create timing events @@ -505,7 +505,7 @@ def generate_mm_with_timing( generated = 0 past_key_values = None - while generated < max_output_seq_length: + while generated < max_new_tokens: lm_start.record() cur_embeds = seq_embeds position_ids = ( @@ -527,8 +527,6 @@ def generate_mm_with_timing( seq_embeds = torch.cat([seq_embeds, emb_layer(next_tok)[:, None, :]], dim=1) generated += 1 - if (next_tok == eos_token_id).all(): - break overall_end.record() torch.cuda.synchronize() @@ -732,9 +730,10 @@ def generate_mm_qwen2_5_vl( mask_expanded, image_embeds.to(seq_embeds.dtype) ) + osl = max_output_seq_length - seq_tokens.shape[1] # 5. Greedy generation loop generated = 0 - while generated < max_output_seq_length: + while generated < osl: # 5.1. Calculate position_ids position_ids = ( torch.arange( @@ -768,8 +767,6 @@ def generate_mm_qwen2_5_vl( seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) generated += 1 - if (next_tok == eos_token_id).all(): - break # 6. Return generated tokens (only the part after the prompt) return seq_tokens[:, input_ids.shape[1] :] @@ -817,52 +814,22 @@ def generate_mm_qwen2_5_vl_with_static_cache( start_idx = 0 end_idx = seq_embeds.size(1) generated = 0 - max_total_len = max_output_seq_length + osl = max_output_seq_length - seq_tokens.shape[1] output_tokens = seq_tokens.clone() # 4. Greedy loop - while output_tokens.size(1) < max_total_len: + while generated < osl: cur_embeds = seq_embeds if generated == 0 else seq_embeds[:, -1:, :] if generated == 0: position_ids = ( torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device) ) - # For the prefill step, the relevant logit is the very last one. - logit_pos = -1 else: - # --- RUNTIME PADDING FIX for KV Cache Decode --- - # The compiled TensorRT engine has a minimum sequence length requirement (e.g., 16), - # as determined by its optimization profile. The decode step uses a sequence length - # of 1, which violates this profile. - # To resolve this, we manually pad the input tensors to the minimum length (16) - # at runtime before feeding them to the engine. - pad_len = 15 # Pad from 1 to 16 (1 + 15) - - # Pad cur_embeds tensor - padding_tensor_embeds = torch.zeros( - cur_embeds.size(0), - pad_len, - cur_embeds.size(2), - dtype=cur_embeds.dtype, - device=cur_embeds.device, - ) - cur_embeds = torch.cat([cur_embeds, padding_tensor_embeds], dim=1) - # Pad position_ids tensor position_ids = torch.tensor([[start_idx]], dtype=torch.int64).to( cur_embeds.device ) - padding_tensor_ids = torch.zeros( - position_ids.size(0), - pad_len, - dtype=position_ids.dtype, - device=position_ids.device, - ) - position_ids = torch.cat([position_ids, padding_tensor_ids], dim=1) - - # Since we padded the sequence, the logit for our actual token is now at position 0. - logit_pos = 0 input_signature = ( cur_embeds, @@ -878,7 +845,7 @@ def generate_mm_qwen2_5_vl_with_static_cache( hidden_states, kv_cache = outputs_and_kv[0], outputs_and_kv[1:] # Use logit_pos to get the correct logit based on whether we padded or not. - logits = model.lm_head(hidden_states[:, logit_pos, :]) + logits = model.lm_head(hidden_states[:, -1, :]) next_tok = logits.argmax(dim=-1) output_tokens = torch.cat([output_tokens, next_tok[:, None]], dim=-1) @@ -890,9 +857,6 @@ def generate_mm_qwen2_5_vl_with_static_cache( start_idx = end_idx end_idx += 1 - if (next_tok == eos_token_id).all(): - break - return output_tokens @@ -1034,3 +998,93 @@ def generate_mm_paligemma_with_static_cache( break return output_tokens + + +@torch.inference_mode() +def generate_mm_qwen2_5_vl_with_timing( + model, + pixel_values: torch.Tensor | None, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + eos_token_id: int, + emb_layer: torch.nn.Embedding, + max_new_tokens: int = 64, +): + """ + Custom generation function for the Qwen2_5_VLForConditionalGeneration model with timing. + """ + overall_start = torch.cuda.Event(enable_timing=True) + overall_end = torch.cuda.Event(enable_timing=True) + vision_start = torch.cuda.Event(enable_timing=True) + vision_end = torch.cuda.Event(enable_timing=True) + lm_start = torch.cuda.Event(enable_timing=True) + lm_end = torch.cuda.Event(enable_timing=True) + + overall_start.record() + + vision_time = 0.0 + image_embeds = None + if pixel_values is not None: + vision_start.record() + image_embeds = model.visual(pixel_values, image_grid_thw) + vision_end.record() + torch.cuda.synchronize() + vision_time = vision_start.elapsed_time(vision_end) + + seq_tokens = input_ids.clone() + seq_embeds = emb_layer(seq_tokens) + + if image_embeds is not None: + mask = seq_tokens == model.config.image_token_id + num_image_tokens = mask.sum().item() + if num_image_tokens != image_embeds.shape[0]: + raise ValueError( + f"Number of image tokens ({num_image_tokens}) does not match number of image embeddings ({image_embeds.shape[0]})." + ) + mask_expanded = mask.unsqueeze(-1).expand_as(seq_embeds) + seq_embeds = seq_embeds.masked_scatter( + mask_expanded, image_embeds.to(seq_embeds.dtype) + ) + + step_times = [] + generated = 0 + while generated < max_new_tokens: + lm_start.record() + position_ids = ( + torch.arange( + 0, seq_tokens.size(1), dtype=torch.long, device=seq_tokens.device + ) + .unsqueeze(0) + .expand(seq_embeds.size(0), seq_embeds.size(1)) + ) + + with torch.no_grad(): + outputs = model.model( + inputs_embeds=seq_embeds, + position_ids=position_ids, + ) + hidden_states = ( + outputs + if isinstance(outputs, torch.Tensor) + else outputs.last_hidden_state + ) + + logits = model.lm_head(hidden_states[:, -1, :]) + next_tok = torch.argmax(logits, dim=-1) + + lm_end.record() + torch.cuda.synchronize() + step_times.append(lm_start.elapsed_time(lm_end)) + + seq_tokens = torch.cat([seq_tokens, next_tok[:, None]], dim=1) + next_emb = emb_layer(next_tok)[:, None, :] + seq_embeds = torch.cat([seq_embeds, next_emb], dim=1) + + generated += 1 + + overall_end.record() + torch.cuda.synchronize() + overall_time = overall_start.elapsed_time(overall_end) + + # For Qwen, there is no separate MLP part like in Eagle, so mlp_time is 0. + return seq_tokens, step_times, overall_time, vision_time, 0.0