
Commit 10ff9e1

feat: Refactor LLM model zoo and add KV cache support (#3527)
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent b6cf8e5 commit 10ff9e1

21 files changed: +3082 -350 lines changed

docsrc/index.rst

Lines changed: 2 additions & 4 deletions

@@ -140,11 +140,10 @@ Model Zoo
 * :ref:`torch_compile_resnet`
 * :ref:`torch_compile_transformer`
 * :ref:`torch_compile_stable_diffusion`
+* :ref:`compile_hf_models`
 * :ref:`torch_compile_gpt2`
 * :ref:`torch_export_gpt2`
-* :ref:`torch_export_llama2`
 * :ref:`torch_export_sam2`
-* :ref:`torch_export_flux_dev`
 * :ref:`notebooks`
 
 .. toctree::
@@ -155,11 +154,10 @@ Model Zoo
 tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
 tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
 tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
+tutorials/compile_hf_models
 tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
 tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
 tutorials/_rendered_examples/dynamo/torch_compile_gpt2
-tutorials/_rendered_examples/dynamo/torch_export_gpt2
-tutorials/_rendered_examples/dynamo/torch_export_llama2
 tutorials/_rendered_examples/dynamo/torch_export_sam2
 tutorials/_rendered_examples/dynamo/torch_export_flux_dev
 tutorials/notebooks

docsrc/tutorials/compile_hf_models.rst

Lines changed: 218 additions & 0 deletions

.. _compile_hf_models:

Compiling LLM models from Huggingface
======================================

This tutorial walks you through how to compile LLM models from Huggingface using Torch-TensorRT. We also introduce KV caching in Torch-TensorRT, which can greatly improve the performance of LLM inference.
The code is available in the `tools/llm <https://github.com/pytorch/TensorRT/tree/main/tools/llm>`_ directory. We use the ``run_llm.py`` script to compile the model, generate outputs, and measure the performance.

.. note::
    This is an **experimental release** and APIs may change in future versions.

.. note::
    The compilation scripts and tutorials for the Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified ``run_llm.py`` script located in the `tools/llm <https://github.com/pytorch/TensorRT/tree/main/tools/llm>`_ directory.

Overview of tools/llm Directory
-------------------------------

The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface:

* **run_llm.py**: Main entry point for model compilation, output generation, and benchmarking
* **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization
* **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering the scaled dot-product attention converter and lowering pass
* **Testing Components**: Model-specific test files for validation
* **Utility Functions**: ``utils.py`` and ``cache_utils.py`` for common operations

Supported Models
----------------

We have officially verified support for the following LLM families:

.. list-table::
    :widths: 20 40 20 20
    :header-rows: 1

    * - Model Series
      - HuggingFace Model Card
      - Precision
      - KV Cache Support?
    * - GPT-2
      - gpt2
      - FP16, FP32
      - Yes
    * - LLaMA 2
      - meta-llama/Llama-2-7b-chat-hf
      - FP16, FP32
      - Yes
    * - LLaMA 3.1
      - meta-llama/Llama-3.1-8B-Instruct
      - FP16, FP32
      - Yes
    * - LLaMA 3.2
      - | meta-llama/Llama-3.2-1B-Instruct
        | meta-llama/Llama-3.2-3B-Instruct
      - FP16, FP32
      - Yes
    * - Qwen 2.5
      - | Qwen/Qwen2.5-0.5B-Instruct
        | Qwen/Qwen2.5-1.5B-Instruct
        | Qwen/Qwen2.5-3B-Instruct
        | Qwen/Qwen2.5-7B-Instruct
      - FP16, FP32
      - Yes
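
For orientation, the snippet below shows how one of the supported models can be loaded in FP16 with ``transformers`` before any compilation takes place. This is only a minimal sketch of the eager-mode starting point; the exact loading logic used by ``run_llm.py`` may differ.

.. code-block:: python

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Minimal sketch: load a supported model in FP16 (eager PyTorch, pre-compilation).
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = (
        AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        .eval()
        .cuda()
    )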

Getting Started with run_llm.py
-------------------------------

The main entry point is ``run_llm.py``, which provides a complete workflow for model compilation and benchmarking.

Basic Usage
^^^^^^^^^^^

.. code-block:: bash

    python tools/llm/run_llm.py \
      --model meta-llama/Llama-3.2-1B-Instruct \
      --prompt "What is parallel programming?" \
      --precision FP16 \
      --num_tokens 128 \
      --cache static_v2 \
      --benchmark

Key Arguments
^^^^^^^^^^^^^

* ``--model``: Name or path of the HuggingFace LLM
* ``--tokenizer``: (Optional) Tokenizer name; defaults to the model name
* ``--prompt``: Input prompt for text generation
* ``--precision``: Precision mode (``FP16``, ``FP32``)
* ``--num_tokens``: Number of output tokens to generate
* ``--cache``: KV cache type (``static_v1``, ``static_v2``, or empty for no KV caching)
* ``--benchmark``: Enable benchmarking mode for performance comparison
* ``--enable_pytorch_run``: Also run and compare against the PyTorch baseline

Other Usage Examples
^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

    # Compare the performance of different models
    python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run

    # Generate outputs (benchmarking disabled) by specifying the number of tokens to generate (default: 128)
    python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128

    # Test different caching approaches
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2

    # Compare FP16 vs. FP32 performance
    python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark
    python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark

KV Caching in Torch-TensorRT
----------------------------

We provide two versions of static KV caching: `static_cache_v1 <https://github.com/pytorch/TensorRT/blob/main/tools/llm/static_cache_v1.py>`_ and `static_cache_v2 <https://github.com/pytorch/TensorRT/blob/main/tools/llm/static_cache_v2.py>`_.
In both implementations, we add static KV cache tensors as model inputs/outputs rather than storing them as external memory.
The length of the KV cache is the input sequence length plus the output sequence length (specified by ``--num_tokens``). The number of heads and the head dimension are determined by the model config.
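
The sketch below shows how these cache shapes can be derived from the model config. It is only a schematic, assuming the usual HuggingFace decoder config attributes (``num_key_value_heads``, ``num_attention_heads``, ``hidden_size``, ``num_hidden_layers``); the actual allocation logic in ``tools/llm`` may differ.

.. code-block:: python

    import torch
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    batch_size = 1
    input_seq_len = 32          # tokenized prompt length
    num_output_tokens = 128     # --num_tokens
    max_cache_len = input_seq_len + num_output_tokens

    # GQA models expose num_key_value_heads; otherwise fall back to num_attention_heads.
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    head_dim = config.hidden_size // config.num_attention_heads

    # One zero-initialized key/value cache pair per decoder layer.
    kv_caches = [
        (
            torch.zeros(batch_size, num_kv_heads, max_cache_len, head_dim, dtype=torch.float16),
            torch.zeros(batch_size, num_kv_heads, max_cache_len, head_dim, dtype=torch.float16),
        )
        for _ in range(config.num_hidden_layers)
    ]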

Static Cache v1
^^^^^^^^^^^^^^^^

``static_cache_v1.py`` implements the KV cache in the model graph as follows:

.. code-block:: python

    class StaticCacheV1Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
            # Concatenate new key/value pairs with the existing cache
            new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
            new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)

            # Compute attention using the updated cache
            attn_output = torch._C._nn.scaled_dot_product_attention(
                q,
                new_key_cache[:, :, :end_idx, :],
                new_value_cache[:, :, :end_idx, :],
                dropout_p=0.0,
                is_causal=is_causal
            )

            return attn_output, new_key_cache, new_value_cache

In the above code, we concatenate the new key/value pairs with the existing cache to update it. To compute the attention, we use the updated cache and gather the corresponding keys/values from it up to and including the current token index.
The above code is implemented as an FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the ``@_aten_lowering_pass`` decorator when the ``static_cache_v1.py`` module is imported.

.. note::
    ``start_idx`` and ``end_idx`` are the start and end indices of the current token in the cache. In the prefill phase, ``start_idx`` is 0 and ``end_idx`` is the input sequence length.
    In the decode phase, ``start_idx`` begins at the input sequence length and ``end_idx`` equals ``start_idx + 1``. ``start_idx`` is incremented by 1 for each generated token until the end of the sequence or the maximum number of tokens to generate is reached.
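
To make the indexing concrete, here is an illustrative walk-through of how ``start_idx`` and ``end_idx`` advance across the prefill and decode phases (the sequence lengths below are made up for illustration):

.. code-block:: python

    input_seq_len = 16        # length of the tokenized prompt
    max_new_tokens = 4        # tokens to generate after the prompt

    # Prefill: the whole prompt's keys/values are written into the cache at once.
    start_idx, end_idx = 0, input_seq_len

    # Decode: each step writes exactly one new key/value slot.
    for step in range(max_new_tokens):
        start_idx = input_seq_len + step
        end_idx = start_idx + 1
        # The model is called with the current token plus (start_idx, end_idx)
        # and returns the updated caches for the next step.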

Static Cache v2
^^^^^^^^^^^^^^^^

``static_cache_v2.py`` is similar to ``static_cache_v1.py``, but it uses fewer slice operations. It implements the KV cache in the model graph as follows:

.. code-block:: python

    class StaticCacheV2Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
            # Append the current key/value to the populated part of the cache
            concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
            concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
            # Re-attach the unused tail so the cache keeps its full static shape
            new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
            new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
            # Attention is computed directly over the concatenated keys/values
            attn_output = torch._C._nn.scaled_dot_product_attention(
                q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
            )

            return attn_output, new_key_cache, new_value_cache

In the above code, we concatenate the existing key/value cache with the current token's key/value. These concatenated tensors are used directly to compute the attention, and the full-size caches are rebuilt by appending the unused tail, which inserts the current key/value into the cache.
The above code is implemented as an FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the ``@_aten_lowering_pass`` decorator when the ``static_cache_v2.py`` module is imported.
The definitions of ``start_idx`` and ``end_idx`` are the same as in ``static_cache_v1.py``.
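
As a sanity check, the two variants can be compared in eager mode. The snippet below reuses the ``StaticCacheV1Model`` and ``StaticCacheV2Model`` classes sketched above and verifies that, for a single decode step, they produce the same attention output and updated caches (shapes and indices are made up for illustration):

.. code-block:: python

    import torch

    torch.manual_seed(0)
    B, H, D = 1, 8, 64                  # batch, KV heads, head dim (illustrative)
    max_cache_len, start_idx = 32, 4    # cache capacity and current fill level
    end_idx = start_idx + 1             # decode step: one new token

    q = torch.randn(B, H, 1, D)
    k = torch.randn(B, H, 1, D)
    v = torch.randn(B, H, 1, D)
    key_cache = torch.randn(B, H, max_cache_len, D)
    value_cache = torch.randn(B, H, max_cache_len, D)

    out1, kc1, vc1 = StaticCacheV1Model()(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=False)
    out2, kc2, vc2 = StaticCacheV2Model()(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=False)

    assert torch.allclose(out1, out2)
    assert torch.equal(kc1, kc2) and torch.equal(vc1, vc2)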

After the model is compiled with a static KV cache, its input signature changes. The new input signature is ``(input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx)``.
The number of key/value cache tensor pairs equals the number of attention layers in the model. We can use the ``generate_with_static_cache`` function to generate the outputs.

Generating Outputs
-------------------

We use a custom `generate <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L112>`_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching.
There is also a `generate_with_static_cache <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L141>`_ function that performs autoregressive decoding with KV caching.

The ``generate_with_static_cache`` function takes care of preparing the inputs to a model compiled with a static KV cache.
The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ..., ``start_idx``, ``end_idx``.
We initialize the key/value cache tensors with zeros, and for every generated token the model outputs the updated key/value cache tensors, which are passed back in at the next step.
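
The sketch below shows the general shape of such a decoding loop. It is a simplified, illustrative version only: the input ordering follows the signature described above, the model is assumed to return the logits followed by the updated caches, greedy sampling is used, and details such as stopping criteria, scalar-vs-tensor index handling, and batching are omitted, so the real ``generate_with_static_cache`` implementation in ``utils.py`` may differ.

.. code-block:: python

    import torch

    def greedy_decode_with_static_cache(trt_model, input_ids, kv_caches, max_new_tokens):
        """Illustrative greedy decoding against a model compiled with a static KV cache."""
        seq_len = input_ids.shape[1]
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        # Prefill: process the whole prompt and fill cache slots [0, seq_len).
        start_idx, end_idx = 0, seq_len
        logits, *kv_caches = trt_model(input_ids, position_ids, *kv_caches, start_idx, end_idx)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = [next_token]

        # Decode: one token at a time, writing cache slot [start_idx, start_idx + 1).
        for step in range(max_new_tokens - 1):
            start_idx = seq_len + step
            end_idx = start_idx + 1
            position_ids = torch.tensor([[start_idx]], device=input_ids.device)
            logits, *kv_caches = trt_model(next_token, position_ids, *kv_caches, start_idx, end_idx)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated.append(next_token)

        return torch.cat(generated, dim=1)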

SDPA Converter (sdpa_converter.py)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* Converts the scaled dot-product attention operation using the TensorRT Python API.
* Supports causal and standard self-attention.

SDPA Registration (register_sdpa.py)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* This is a Torch-TensorRT lowering pass that replaces variants of SDPA with ``torch.nn.functional.scaled_dot_product_attention``.
* Registers the SDPA converter, which is used to convert the ``torch.nn.functional.scaled_dot_product_attention`` operation.
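
Both registrations happen as an import side effect, which is how ``run_llm.py`` picks them up. The sketch below shows that pattern in isolation; it assumes the ``tools/llm`` directory is on ``sys.path``, and the compile call is abbreviated rather than copied from ``run_llm.py``.

.. code-block:: python

    import torch
    import torch_tensorrt

    # Importing these modules registers the SDPA lowering pass/converter and the
    # static KV cache transformation as side effects (via @_aten_lowering_pass).
    import register_sdpa        # noqa: F401
    import static_cache_v2      # noqa: F401

    # ... then export the HuggingFace model with torch.export and compile it, e.g.:
    # trt_model = torch_tensorrt.dynamo.compile(
    #     exported_program, inputs=example_inputs, enabled_precisions={torch.float16}
    # )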

Limitations and Known Issues
----------------------------

* Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported.
* Some model architectures (e.g., Phi-4) have issues with exporting the torch model.

Requirements
^^^^^^^^^^^^

* Torch-TensorRT 2.8.0 or later
* Transformers v4.52.3

examples/dynamo/torch_export_gpt2.py

Lines changed: 0 additions & 98 deletions
This file was deleted.
