.. _compile_hf_models:

Compiling LLM models from Huggingface
======================================

This tutorial walks you through how to compile LLM models from Huggingface using Torch-TensorRT. We also introduce KV caching in Torch-TensorRT, which can greatly improve the performance of LLM inference.
The code is available in the `tools/llm <https://github.com/pytorch/TensorRT/tree/main/tools/llm>`_ directory. We use the ``run_llm.py`` script to compile the model, generate outputs, and measure performance.

.. note::
   This is an **experimental release** and APIs may change in future versions.

.. note::
   The compilation scripts and tutorials for the Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified ``run_llm.py`` script located in the `tools/llm <https://github.com/pytorch/TensorRT/tree/main/tools/llm>`_ directory.

Overview of tools/llm Directory
-------------------------------

The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface:

* **run_llm.py**: Main entry point for model compilation, output generation, and benchmarking
* **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization
* **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering the scaled dot-product attention converter and lowering pass
* **Testing Components**: Model-specific test files for validation
* **Utility Functions**: ``utils.py`` and ``cache_utils.py`` for common operations

Supported Models
----------------

We have officially verified support for the following LLM families:

.. list-table::
   :widths: 20 40 20 20
   :header-rows: 1

   * - Model Series
     - HuggingFace Model Card
     - Precision
     - KV Cache Support
   * - GPT-2
     - gpt2
     - FP16, FP32
     - Yes
   * - LLaMA 2
     - meta-llama/Llama-2-7b-chat-hf
     - FP16, FP32
     - Yes
   * - LLaMA 3.1
     - meta-llama/Llama-3.1-8B-Instruct
     - FP16, FP32
     - Yes
   * - LLaMA 3.2
     - | meta-llama/Llama-3.2-1B-Instruct
       | meta-llama/Llama-3.2-3B-Instruct
     - FP16, FP32
     - Yes
   * - Qwen 2.5
     - | Qwen/Qwen2.5-0.5B-Instruct
       | Qwen/Qwen2.5-1.5B-Instruct
       | Qwen/Qwen2.5-3B-Instruct
       | Qwen/Qwen2.5-7B-Instruct
     - FP16, FP32
     - Yes

Getting Started with run_llm.py
-------------------------------

The main entry point is ``run_llm.py``, which provides a complete workflow for model compilation and benchmarking.

Basic Usage
^^^^^^^^^^^

.. code-block:: bash

    python tools/llm/run_llm.py \
      --model meta-llama/Llama-3.2-1B-Instruct \
      --prompt "What is parallel programming?" \
      --precision FP16 \
      --num_tokens 128 \
      --cache static_v2 \
      --benchmark

Key Arguments
^^^^^^^^^^^^^

* ``--model``: Name or path of the HuggingFace LLM
* ``--tokenizer``: (Optional) Tokenizer name; defaults to the model name
* ``--prompt``: Input prompt for text generation
* ``--precision``: Precision mode (``FP16``, ``FP32``)
* ``--num_tokens``: Number of output tokens to generate
* ``--cache``: KV cache type (``static_v1``, ``static_v2``, or empty for no KV caching)
* ``--benchmark``: Enable benchmarking mode for performance comparison
* ``--enable_pytorch_run``: Also run and compare against a PyTorch baseline

Other Usage Examples
^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

    # Compare the performance of different models
    python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run

    # Generate outputs (benchmarking disabled) by specifying the number of tokens to generate. Default = 128
    python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128

    # Test different caching approaches
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1
    python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2

    # Compare FP16 vs FP32 performance
    python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark
    python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark

KV Caching in Torch-TensorRT
----------------------------

We provide two versions of static KV caching: `static_cache_v1 <https://github.com/pytorch/TensorRT/blob/main/tools/llm/static_cache_v1.py>`_ and `static_cache_v2 <https://github.com/pytorch/TensorRT/blob/main/tools/llm/static_cache_v2.py>`_.
In both implementations, we add static KV cache tensors as model inputs/outputs without storing them as external memory.
The length of the KV cache equals the input sequence length plus the output sequence length (specified by ``--num_tokens``). The number of heads and the head dimension are determined by the model config.

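
To make these dimensions concrete, the sketch below derives the static cache shape from a Hugging Face model config. This is an illustration only: the config attribute names (``num_hidden_layers``, ``num_key_value_heads``, ``num_attention_heads``, ``hidden_size``) are those of common Llama/Qwen-style configs, and the sequence lengths are example values rather than ``run_llm.py`` defaults.

.. code-block:: python

    import torch
    from transformers import AutoConfig

    # Example values, not run_llm.py defaults: input sequence length and --num_tokens.
    isl, num_tokens = 2048, 128
    max_seq_len = isl + num_tokens  # cache length = input sequence length + output sequence length

    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    head_dim = config.hidden_size // config.num_attention_heads

    # Assumed layout: one key cache and one value cache tensor per attention layer.
    cache_shape = (1, num_kv_heads, max_seq_len, head_dim)
    kv_caches = [torch.zeros(cache_shape, dtype=torch.float16) for _ in range(2 * config.num_hidden_layers)]
    print(f"{len(kv_caches)} cache tensors, each of shape {cache_shape}")
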
Static Cache v1
^^^^^^^^^^^^^^^^

``static_cache_v1.py`` implements the KV cache in the model graph as follows:

.. code-block:: python

    class StaticCacheV1Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
            # Concatenate new key/value pairs with the existing cache
            new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
            new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)

            # Compute attention using the updated cache
            attn_output = torch._C._nn.scaled_dot_product_attention(
                q,
                new_key_cache[:, :, :end_idx, :],
                new_value_cache[:, :, :end_idx, :],
                dropout_p=0.0,
                is_causal=is_causal
            )

            return attn_output, new_key_cache, new_value_cache

In the above code, we concatenate the new key/value pairs with the existing cache to update it. To compute the attention, we use the updated cache and gather the keys/values up to and including the current token index.
The above code is implemented as an FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the ``@_aten_lowering_pass`` decorator, which takes effect when the ``static_cache_v1.py`` module is imported.

.. note::
   ``start_idx`` and ``end_idx`` are the start and end indices of the current tokens in the cache. During the prefill phase, ``start_idx`` is 0 and ``end_idx`` is the input sequence length.
   During the decode phase, ``start_idx`` begins at the input sequence length and ``end_idx`` equals ``start_idx + 1``. ``start_idx`` is incremented by 1 for each generated token until the end of the sequence is reached or the maximum number of tokens has been generated.

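
The short illustration below (not part of the tools) prints how these indices evolve for a hypothetical prompt of 6 tokens followed by 4 generated tokens.

.. code-block:: python

    # Illustration of the start_idx / end_idx schedule described in the note above.
    isl, num_new_tokens = 6, 4  # hypothetical input length and number of generated tokens

    # Prefill: the whole prompt is written into the cache in one step.
    start_idx, end_idx = 0, isl
    print(f"prefill      : start_idx={start_idx}, end_idx={end_idx}")

    # Decode: one token per step, advancing the cache window by one position.
    for step in range(num_new_tokens):
        start_idx = isl + step
        end_idx = start_idx + 1
        print(f"decode step {step}: start_idx={start_idx}, end_idx={end_idx}")
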
Static Cache v2
^^^^^^^^^^^^^^^^

``static_cache_v2.py`` is similar to ``static_cache_v1.py``, but it uses fewer slice operations. It implements the KV cache in the model graph as follows:

.. code-block:: python

    class StaticCacheV2Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
            # Concatenate the existing cache (up to start_idx) with the current key/value
            concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
            concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
            # Append the remaining cache slots to form the updated cache
            new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
            new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
            # Compute attention directly on the concatenated keys/values
            attn_output = torch._C._nn.scaled_dot_product_attention(
                q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
            )

            return attn_output, new_key_cache, new_value_cache

In the above code, we concatenate the existing key/value cache with the current token's key/value. We use this concatenated tensor directly to compute the attention, and produce the updated key/value cache by inserting the current key/value.
The above code is implemented as an FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the ``@_aten_lowering_pass`` decorator, which takes effect when the ``static_cache_v2.py`` module is imported.
The definitions of ``start_idx`` and ``end_idx`` are the same as in ``static_cache_v1.py``.

After the model is compiled with a static KV cache, the input signature of the model changes. The new input signature is ``(input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx)``.
The number of key/value cache tensor pairs equals the number of attention layers in the model. We can use the ``generate_with_static_cache`` function to generate the outputs, as sketched below.

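
For illustration, a single prefill call against the new signature could be assembled as sketched below. ``trt_model`` and ``kv_caches`` are placeholders (the compiled module and its flat, ordered list of zero-initialized cache tensors), and whether the scalar indices are passed as Python ints or tensors follows the actual helper in ``utils.py``.

.. code-block:: python

    import torch

    # Sketch only: `trt_model` and `kv_caches` are placeholders for the compiled
    # module and the flat list [key_cache_0, value_cache_0, key_cache_1, ...].
    input_ids = torch.tensor([[101, 2054, 2003]], device="cuda")  # hypothetical token ids
    position_ids = torch.arange(input_ids.shape[1], device="cuda").unsqueeze(0)
    start_idx, end_idx = 0, input_ids.shape[1]  # prefill covers the whole prompt

    # New signature: (input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx)
    logits, *kv_caches = trt_model(input_ids, position_ids, *kv_caches, start_idx, end_idx)
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
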
Generating Outputs
-------------------

We use a custom `generate <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L112>`_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching.
There is also a `generate_with_static_cache <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L141>`_ function that performs autoregressive decoding with KV caching.

The ``generate_with_static_cache`` function takes care of preparing the inputs for the model compiled with a static KV cache.
The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ..., ``start_idx``, ``end_idx``.
We initialize the key/value cache tensors with zeros, and for every generated token the updated key/value cache tensors are returned as model outputs and used for the next decoding step.

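
A simplified greedy-decoding loop in the spirit of ``generate_with_static_cache`` is sketched below. ``trt_model`` and the zero-initialized ``kv_caches`` list are placeholders, and the real helper in ``utils.py`` additionally handles end-of-sequence tokens and timing.

.. code-block:: python

    import torch

    def generate_greedy_sketch(trt_model, input_ids, kv_caches, max_new_tokens):
        # Simplified sketch: `trt_model` and the zero-initialized `kv_caches`
        # list are placeholders for the compiled module and its cache tensors.
        isl = input_ids.shape[1]
        start_idx, end_idx = 0, isl
        cur_ids, generated = input_ids, input_ids

        for _ in range(max_new_tokens):
            position_ids = torch.arange(start_idx, end_idx, device=cur_ids.device).unsqueeze(0)
            # The updated key/value caches come back as outputs and feed the next step.
            logits, *kv_caches = trt_model(cur_ids, position_ids, *kv_caches, start_idx, end_idx)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            # After prefill, each subsequent step processes exactly one new token.
            cur_ids = next_token
            start_idx, end_idx = end_idx, end_idx + 1

        return generated
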
SDPA Converter (sdpa_converter.py)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* Converts the scaled dot-product attention operation using the TensorRT Python API.
* Supports causal and standard self-attention.

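
The converter lowers this operation to TensorRT layers; as a reference for the semantics it must reproduce, the sketch below shows the causal and standard variants at the PyTorch level (plain PyTorch, not the TensorRT conversion itself).

.. code-block:: python

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    # Standard (non-causal) self-attention.
    out_standard = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

    # Causal self-attention: position i may only attend to positions <= i.
    out_causal = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

    # The causal flag is equivalent to an explicit lower-triangular boolean mask.
    mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))
    print(torch.allclose(out_causal, F.scaled_dot_product_attention(q, k, v, attn_mask=mask), atol=1e-6))
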
SDPA Registration (register_sdpa.py)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

* This is a Torch-TensorRT lowering pass that replaces variants of SDPA with ``torch.nn.functional.scaled_dot_product_attention``.
* Registers the SDPA converter, which is used to convert the ``torch.nn.functional.scaled_dot_product_attention`` operation.

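
In ``run_llm.py``, both the SDPA lowering pass and the chosen static cache pass take effect simply by importing the corresponding modules before compilation. The sketch below illustrates that pattern; ``model`` and ``example_inputs`` are placeholders, and the export/compile settings are assumptions rather than the verbatim ``run_llm.py`` code.

.. code-block:: python

    import torch
    import torch_tensorrt

    # Importing these modules registers the lowering passes / SDPA converter with
    # Torch-TensorRT (import paths assume the script is run from tools/llm).
    import register_sdpa      # replaces SDPA variants and registers the SDPA converter
    import static_cache_v2    # inserts the static KV cache into the graph

    # `model` and `example_inputs` are placeholders for the HF model and its sample inputs.
    exported_program = torch.export.export(model, args=example_inputs)
    trt_model = torch_tensorrt.dynamo.compile(
        exported_program,
        inputs=list(example_inputs),
        enabled_precisions={torch.float16},
    )
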
Limitations and Known Issues
----------------------------

* Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported.
* Some model architectures (e.g. Phi-4) have issues when exporting the torch model.

Requirements
^^^^^^^^^^^^

* Torch-TensorRT 2.8.0 or later
* Transformers v4.52.3