
Commit f539b55

Author: Chengzhe Xu
chore: updates
1 parent 6cbb1bd · commit f539b55

10 files changed: +676 −240 lines changed


examples/dynamo/cache_utils.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+import torch
+from torch.fx import Graph, GraphModule, Node
+from typing import Optional, Union, Iterable, List, Tuple
+from torch._ops import OpOverloadPacket
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+from torch.utils._pytree import _LEAF_SPEC
+from torch._export.utils import _detect_fake_mode_from_gm
+
+def get_kv_nodes(gm):
+    """
+    Get the key and value nodes from the graph.
+    """
+    kv_nodes = []
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+            q_node, k_node, v_node = node.args[:3]
+            kv_nodes.append((k_node, v_node))
+    return kv_nodes
+
+def get_random_tensor_from_node(node: Node) -> torch.Tensor:
+    """
+    Creates a random tensor based on the shape information in a node's metadata.
+    For symbolic dimensions, extracts the maximum value from the shape environment.
+
+    Args:
+        node: A torch.fx.Node object with metadata containing tensor information
+
+    Returns:
+        A random tensor with shape matching the node's metadata, or None if no valid
+        tensor information is found
+    """
+    if "val" not in node.meta:
+        raise ValueError(f"No tensor information found in node metadata for node: {node}")
+
+    fake_tensor = node.meta["val"]
+    shape = []
+
+    # Iterate through each dimension and handle symbolic dimensions
+    for dim in fake_tensor.shape:
+        if isinstance(dim, torch.SymInt):
+            # Extract the maximum value from the shape environment
+            max_val = dim.node.hint
+            shape.append(max_val)
+        else:
+            shape.append(dim)
+
+    # Create a random tensor with the determined shape
+    dtype = fake_tensor.dtype
+    device = fake_tensor.device
+    random_tensor = torch.rand(shape, dtype=dtype, device=device)
+
+    return random_tensor
+
+def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]:
+    """
+    Creates random tensors based on the shape information in node metadata.
+    For symbolic dimensions, extracts the maximum value from the shape environment.
+
+    Args:
+        nodes: List of torch.fx.Node objects with metadata
+
+    Returns:
+        List of random tensors with shapes matching the nodes' metadata
+    """
+    random_tensors = []
+
+    for node in nodes:
+        if isinstance(node, Node):
+            node_tensor = get_random_tensor_from_node(node)
+        elif isinstance(node, tuple):
+            node_tensor_list = []
+            for n in node:
+                random_tensor = get_random_tensor_from_node(n)
+                node_tensor_list.append(random_tensor)
+            node_tensor = tuple(node_tensor_list)
+
+        random_tensors.append(node_tensor)
+
+    return random_tensors
+
+def add_graph_input(
+    gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
+) -> Node:
+    """Add a graph input to the given GraphModule and return the newly created node.
+
+    NOTE: function does NOT do any graph canonicalization. This is left to the user!
+
+    Args:
+        gm (GraphModule): The GraphModule to add the input to.
+        name (str): The name of the input.
+        val (torch.Tensor): An example tensor to use for the input.
+        dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET]
+    """
+    # check that no dynamic shape is provided...
+    if dynamic_shape:
+        raise NotImplementedError("Dynamic shape not supported for adding graph inputs")
+
+    # extract graph and input spec
+    graph: Graph = gm.graph
+
+    in_spec = graph._codegen.pytree_info.in_spec
+    in_spec_for_args = in_spec.children_specs[0]
+    orig_args = graph._codegen.pytree_info.orig_args
+    assert in_spec_for_args.type is tuple
+
+    # insert input node after currently last input node
+    node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1]
+    with graph.inserting_after(node_last_input):
+        in_node = graph.placeholder(name)
+        in_spec_for_args.children_specs.append(_LEAF_SPEC)
+        orig_args.append(f"arg_{name}")
+
+    # update pytree info recursively with __post_init__ starting at leaves
+    def call_post_init(spec):
+        for child_spec in spec.children_specs:
+            call_post_init(child_spec)
+        spec.__post_init__()
+
+    call_post_init(in_spec)
+
+    # set fake tensor information if all required information is available
+    fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
+    if fake_mode and val is not None and isinstance(val, torch.Tensor):
+        if isinstance(val, FakeTensor):
+            fake_tensor = val
+        else:
+            fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True)
+        in_node.meta["val"] = fake_tensor
+        in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor)
+
+    # return new node...
+    return in_node
+
+def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool:
+    """Check if the node is a call to one of the ops."""
+    if node.op != "call_function":
+        return False
+    # check if it's a single op that's provided
+    if isinstance(ops, OpOverloadPacket):
+        ops = [ops]
+
+    # check if it's the op itself instead of an overload
+    if any(node.target == op for op in ops):
+        return True
+
+    return False
+
+def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]:
+    input_nodes: List[Node] = graph.find_nodes(op="placeholder")
+    output_nodes: List[Node] = graph.find_nodes(op="output")
+    return (input_nodes, output_nodes)
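
For orientation, a minimal sketch (not part of this commit) of how these helpers might be exercised. It assumes a recent PyTorch with torch.export and Graph.find_nodes, a toy module chosen for illustration, and that it is run from examples/dynamo so cache_utils is importable:

import torch
from torch.export import export
from cache_utils import create_random_output_tensors, get_all_input_output_nodes

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# Export a toy module; exported graphs carry FakeTensor "val" metadata on each node.
gm = export(Toy(), (torch.rand(2, 4),)).module()

# Placeholder and output nodes of the exported graph.
inputs, outputs = get_all_input_output_nodes(gm.graph)

# Random tensors shaped like the graph inputs, built from each node's "val" metadata.
samples = create_random_output_tensors(inputs)
print([t.shape for t in samples])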

examples/dynamo/dynamic_cache.py

Lines changed: 16 additions & 29 deletions
@@ -14,7 +14,7 @@
     clean_up_graph_after_modifications,
 )

-from .cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op
 import tensorrt
 import torch.utils._pytree as pytree
 logger = logging.getLogger(__name__)
@@ -146,23 +146,7 @@ def get_static_tensor(tensor: torch.Tensor):
         v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
         kv_inputs.append((k_input, v_input))

-    # Add start_idx and end_idx as inputs
-    start_idx_input = add_graph_input(gm, "start_idx")
-    end_idx_input = add_graph_input(gm, "end_idx")
-    return kv_inputs, start_idx_input, end_idx_input
-
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]):
-    """
-    Insert slicing operations before each scaled_dot_product_attention operation.
-    """
-    pass
-    # Find all nodes with scaled_dot_product_attention
-    sdpa_nodes = []
-    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
-            sdpa_nodes.append(node)
-
-    for idx, sdpa_node in enumerate(sdpa_nodes):
+    return kv_inputs


 def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]):
@@ -181,41 +165,44 @@ def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]):
         if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
             sdpa_nodes.append(node)

+    # Get the is_causal input node
+    is_causal_node = next((node for node in gm.graph.nodes if node.op == "placeholder" and node.name == "is_causal"), None)
+
     # For each SDPA node, insert a torch.cond operation before it
     for idx, sdpa_node in enumerate(sdpa_nodes):

         with gm.graph.inserting_before(sdpa_node):
-            pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool))
+            # pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool))
             q_node, k_node, v_node = sdpa_node.args[:3]
             incoming_key, incoming_value = incoming_keys_values[idx]
             # Create nodes for concatenating k with incoming_key and v with incoming_value
             concatenated_k_node = gm.graph.create_node(
                 "call_function",
                 torch.ops.aten.cat.default,
-                args=([k_node, incoming_key], 2),  # Concatenate along sequence length dimension
+                args=([incoming_key, k_node], 2),  # Concatenate along sequence length dimension
                 kwargs={}
             )
             concatenated_v_node = gm.graph.create_node(
                 "call_function",
                 torch.ops.aten.cat.default,
-                args=([v_node, incoming_value], 2),  # Concatenate along sequence length dimension
+                args=([incoming_value, v_node], 2),  # Concatenate along sequence length dimension
                 kwargs={}
             )

             # Create the torch.cond node
             cond_k_node = gm.graph.create_node(
                 "call_function",
                 torch.ops.higher_order.cond,
-                args=(pred_node, concatenated_k_node, k_node),
+                args=(is_causal_node, concatenated_k_node, k_node),
             )

             cond_v_node = gm.graph.create_node(
                 "call_function",
                 torch.ops.higher_order.cond,
-                args=(pred_node, concatenated_v_node, v_node),
+                args=(is_causal_node, concatenated_v_node, v_node),
             )

-            sdpa_node.args = (q_node, cond_k_node, cond_v_node)
+            sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:]

     return gm

@@ -229,13 +216,13 @@ def insert_dynamic_kv_cache(
     """Perform insertion of kv-caches and attention kernel."""

     # Add static key and value as inputs to the graph
-    kv_inputs, start_idx_input, end_idx_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True)
+    kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True)

-    # Call the function to add QKV as outputs
-    logits_keys_values = add_kv_as_outputs(gm, start_idx_input, end_idx_input)
+    # Call the function to add KV as outputs
+    logits_keys_values = add_kv_as_outputs(gm)

-    gm = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input)
-    # gm = insert_torch_cond_before_sdpa(gm, kv_inputs)
+    # Insert torch.cond before each SDPA node, which toggles between prefill and generate phases
+    gm = insert_torch_cond_before_sdpa(gm, kv_inputs)

     gm = clean_up_graph_after_modifications(gm)
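
For context: the pass above now reuses the existing is_causal placeholder as the predicate of torch.ops.higher_order.cond, selecting between the concatenated KV (generate phase) and the freshly computed KV alone (prefill phase). Below is a minimal eager-mode sketch of the user-facing torch.cond wrapper behind that op; it is not part of this commit, assumes a recent PyTorch that exposes torch.cond, and uses illustrative names and shapes:

import torch

def pick(pred, x, y):
    # torch.cond runs exactly one branch based on the boolean predicate; both
    # branches receive the same operands and here return tensors of matching shape.
    return torch.cond(pred, lambda a, b: a + 1.0, lambda a, b: b - 1.0, (x, y))

out = pick(torch.tensor(True), torch.zeros(2, 3), torch.ones(2, 3))
print(out)  # all ones: the True branch (x + 1) was taken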

examples/dynamo/llama3_trt.py

Lines changed: 60 additions & 10 deletions
@@ -19,7 +19,7 @@
 import torch_tensorrt
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache
+from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs


 DEVICE = torch.device("cuda:0")
@@ -43,7 +43,7 @@ def get_model(args):
                 args.model,
                 use_cache=False,
                 attn_implementation="sdpa",
-                # num_hidden_layers=1
+                num_hidden_layers=1
             )
             .eval()
             .cuda()
@@ -194,9 +194,10 @@ def measure_perf(trt_model, input_signature, backend_name):
         help="Enable pytorch run (default: False)"
     )
     arg_parser.add_argument(
-        "--kv_cache",
-        action="store_true",
-        help="Enable kv_cache (default: False)"
+        "--cache",
+        type=str,
+        default="static",
+        help="Type of KV cache to use",
     )
     arg_parser.add_argument(
         "--cudagraph",
@@ -220,9 +221,9 @@ def measure_perf(trt_model, input_signature, backend_name):
     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

     prompt = "What is parallel programming ?"
+    # prompt = "What is the capital of France ?"
     model_inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = model_inputs["input_ids"].to(DEVICE)
-
     # Prepare input prompt
     # word = "What"
     # word_ids = tokenizer(word, return_tensors="pt").input_ids[0]  # Get the first (and only) sequence
@@ -252,18 +253,67 @@ def measure_perf(trt_model, input_signature, backend_name):
     )

     # TRT
+    pyt_logits_tok1 = model.cuda()(input_ids)
+    next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1)
+    input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+    pyt_logits_tok2 = model.cuda()(input_seq)
     from lower_sdpa import *
-    if args.kv_cache:
-        # This import is required to register static/dynamic KV cache transformations as lowering passes
-        from static_cache import *
+    if args.cache == "static":
+        # This import is required to register static KV cache transformations as lowering passes
+        from static_cache2 import *
+        trt_model = compile_torchtrt(model, input_ids, args)
+        kv_cache = get_zeroed_kv_cache_inputs(trt_model)
+
+        # First token generation
+        pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt")
+        trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1])
+        print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}")
+        print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}")
+        print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}")
+        print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}")
+        print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}")
+        next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1)
+
+        # Second token generation
+        trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1]+1)
+        pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt")
+        print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}")
+        print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}")
+        print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}")
+        print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}")
+        print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}")
+        breakpoint()
+    elif args.cache == "dynamic":
+        from dynamic_cache import *
         trt_model = compile_torchtrt(model, input_ids, args)
+        breakpoint()
+        kv_cache = get_zeroed_kv_cache_inputs(trt_model)
     else:
         # pyt_logits = model.cuda()(input_ids.clone())
         trt_model = compile_torchtrt(model, input_ids, args)
         # trt_logits = trt_model(input_ids.clone(), True)
         # print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}")
         # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}")
-    if args.kv_cache:
+    if args.cache == "static":
+        if args.cudagraph:
+            # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
+            # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
+            torch_tensorrt.runtime.set_cudagraphs_mode(True)
+
+        trt_gen_tokens = generate_with_kv_cache(
+            trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
+        )
+
+        if args.benchmark:
+            trt_timings = time_generate(
+                generate_with_kv_cache,
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
+                iterations=args.iterations,
+            )
+    elif args.cache == "dynamic":
         if args.cudagraph:
             # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
             # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
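
As a reading aid, a hedged sketch of the two-phase decode loop this driver exercises: one prefill call over the full prompt, then per-token generate calls that feed the returned caches back in. It assumes a simplified TRT module signature of (input_ids, is_prefill, key_cache, value_cache, start_idx, end_idx) -> (logits, key_cache, value_cache), loosely following the calls shown in the diff but without the extra debug outputs; it is not the utils.generate_with_kv_cache implementation.

import torch

def greedy_decode(trt_model, input_ids, key_cache, value_cache, max_new_tokens):
    seq_len = input_ids.shape[1]
    # Prefill: run the whole prompt once, filling the KV cache from position 0.
    logits, key_cache, value_cache = trt_model(input_ids, True, key_cache, value_cache, 0, seq_len)
    tokens = [torch.argmax(logits[:, -1, :], dim=-1)]
    # Generate: one token per call, advancing the (start_idx, end_idx) window by one.
    for step in range(1, max_new_tokens):
        logits, key_cache, value_cache = trt_model(
            tokens[-1][:, None], False, key_cache, value_cache,
            seq_len + step - 1, seq_len + step,
        )
        tokens.append(torch.argmax(logits[:, -1, :], dim=-1))
    return torch.stack(tokens, dim=1)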

0 commit comments
