adding rotary embedding example, with graph rewrite for complex subgraph #3570
@@ -0,0 +1,106 @@
import time

import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

"""
This example covers the rotary embedding and rotary attention case for tensor parallel
"""


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, n_parallel=1
) -> torch.Tensor:
    """Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.
    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        n_parallel (int, optional): Number of GPUs for parallel computation. Defaults to 1.
    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim // n_parallel, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)


def rotary_embedding(xq, xk, dim, freqs_cis=None):
    """Calculate the rotary embedding for the query and key tensors.
    Args:
        xq (torch.Tensor): Query tensor.
        xk (torch.Tensor): Key tensor.
        dim (int): Dimension of the query and key tensors.
        freqs_cis (torch.Tensor, optional): Precomputed frequency tensor. Defaults to None.
    Returns:
        tuple: Tuple containing the rotated query and key tensors.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return (xq_out.type_as(xq), xk_out.type_as(xk))


######## Tensor Parallel ########
def parallel_rotary_block(rotary_block, tp_mesh):
    """Parallelize the rotary block for tensor parallelism.
    Args:
        rotary_block: Rotary block to parallelize
        tp_mesh: Tensor parallel device mesh
    """
    if tp_mesh.size() <= 1:
        return

    plan = {
        "wq": ColwiseParallel(),
        "wk": ColwiseParallel(),
        "wo": RowwiseParallel(output_layouts=Shard(0)),
    }
    rotary_block.n_parallel = 1  # this is for single GPU; TODO: remove this hardcode

    parallelize_module(rotary_block, tp_mesh, plan)


class RotaryAttention(nn.Module):
    def __init__(self, dim: int, seq_len: int):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)
        self.seq_len = seq_len
        self.n_parallel = 1
        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
        self.init_weights()

    def _precompute_freqs_cis(self) -> torch.Tensor:
        theta = 10000.0
        return precompute_freqs_cis(self.dim, self.seq_len, theta, self.n_parallel)

    def init_weights(self):
        with torch.device(self.freqs_cis.device):
            self.freqs_cis = self.freqs_cis

    def forward(self, x):
        q = self.wq(x)
        k = self.wk(x)
        freqs_cis = self._precompute_freqs_cis().to(q.device)
        q, k = rotary_embedding(q, k, self.dim, freqs_cis=freqs_cis)
        return self.wo(q)
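As a quick sanity check of the helper above (a hedged sketch, not part of the diff; it assumes the file is importable, e.g. saved as rotary_embedding.py): the precomputed frequencies come out as a complex64 tensor, and its real view has the (..., 2) layout whose [..., 0] / [..., 1] slices the review discussion further below refers to.

import torch
from rotary_embedding import precompute_freqs_cis  # assumed module name for the file above

# dim=128, seq_len=128 on a single GPU (n_parallel=1)
freqs_cis = precompute_freqs_cis(128, 128)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([128, 64]) torch.complex64

# TensorRT has no complex dtype, so the complex buffer is handled through its real view,
# where [..., 0] is the real part and [..., 1] the imaginary part
as_real = torch.view_as_real(freqs_cis)
print(as_real.shape)  # torch.Size([128, 64, 2])
assert torch.equal(as_real[..., 0], freqs_cis.real)
assert torch.equal(as_real[..., 1], freqs_cis.imag)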
@@ -0,0 +1,58 @@
import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_rotary_embedding"
)


"""
This example covers the rotary embedding in the Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
"""

BATCH = 2
SEQ_LEN = 128
HEADS = 4
DIM = 128

with torch.no_grad():
    model = RotaryAttention(DIM, SEQ_LEN)
    parallel_rotary_block(model, device_mesh)
    device = torch.device("cuda", device_mesh.get_rank())
    model.to(device)
    x = torch.randn(BATCH, SEQ_LEN, HEADS, DIM).to(device)

    python_result = model(x)

    logger.info("Torch-tensorrt compilation for rotary embedding")

    model = torch.compile(model, backend="torch_tensorrt", options={"debug": True})
Review comment: remove debug=True as it is deprecated.
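If that suggestion is applied, the call above would presumably just drop the options dict (a sketch, not part of the diff):

    model = torch.compile(model, backend="torch_tensorrt")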
    try:
        for i in range(15):
            # seeding with dp_rank to ensure identical inputs for TP groups
            torch.manual_seed(i)
            start = time.time()
            output = model(x)
            end = time.time()
            if i == 0:
                logger.info(f"Compilation time is {end-start}")
                assert (
                    python_result - output
                ).std() < 0.01, "Compilation result is not correct."
            elif _rank == 0:
                logger.info(f"Inference time is {end-start}")
    except Exception as e:
        logger.error(f"Error: {e}")
        raise e
    finally:
        cleanup_distributed_env()
Review comment: Do we need a try/except and finally here for the example? And also, why are we looping? If you want to display any results of this block, please use the right formatting: https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/torch_export_sam2.py#L282-L297
Reply: The try/except block is for the case where we check the inference time improvement over a number of iterations after graph compilation, to see the performance improvement without graph breaks. I am not clear about the formatting part pointed out in the link. Did you point it out because the try loop won't be rendered? I see that the rendering is correct. I could as such remove the loop too.
Review comment: Why do we need to loop here? Is the goal of this example to compare the output of the rotary embedding block between TRT and the PyTorch model?
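On the looping question above: if the intent is to separate one-time compilation cost from steady-state latency, one common pattern (a hedged sketch with a hypothetical helper name; model and x as in the script above) is to synchronize the GPU around each timed call, since CUDA launches are asynchronous and plain wall-clock timing can under-report them:

import time
import torch

def benchmark(model, x, iters=15):
    # First call triggers the torch.compile / TensorRT engine build
    torch.cuda.synchronize()
    start = time.time()
    model(x)
    torch.cuda.synchronize()
    compile_time = time.time() - start

    # Subsequent calls measure steady-state inference latency
    latencies = []
    for _ in range(iters):
        torch.cuda.synchronize()
        start = time.time()
        model(x)
        torch.cuda.synchronize()
        latencies.append(time.time() - start)
    return compile_time, sum(latencies) / len(latencies)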
@@ -23,7 +23,10 @@ def getitem_validator(getitem_node: Node, settings: CompilationSettings = None)
    from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

    # Getitem nodes can only be converted if their parent node also can
-   return getitem_node.args[0] in DYNAMO_CONVERTERS
+   return (
+       getitem_node.args[0] in DYNAMO_CONVERTERS
+       or getitem_node.args[0].op == "get_attr"
+   )

Review comment: why is this needed?
Reply: This is needed because the complex tensor is at present a buffer, wherein we extract the real and imag parts through input[..., 0] and input[..., 1]. This leads to a getitem node whose arg is a get_attr node.
Review comment: Okay. What does getitem_node.args[0] return here?
Reply: getitem_node.args[0] returns the node which is the key in the ConverterRegistry.


# TODO: Subsequent evaluators should be registered here with their own validators
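To make the review exchange above concrete, here is a rough sketch (module and buffer names are illustrative, not taken from the PR) of the pattern the validator now has to accept: in a plain torch.fx trace, a buffer accessed in forward shows up as a get_attr node, and slicing its real view with [..., 0] / [..., 1] produces getitem nodes whose args[0] is that get_attr node rather than a converter-backed op.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class ComplexBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # complex values stored as their real view: shape (..., 2)
        freqs = torch.polar(torch.ones(8, 4), torch.rand(8, 4))
        self.register_buffer("freqs_cis", torch.view_as_real(freqs))

    def forward(self, x):
        real = self.freqs_cis[..., 0]  # getitem whose parent is a get_attr node
        imag = self.freqs_cis[..., 1]
        return x * real, x * imag

gm = symbolic_trace(ComplexBuffer())
for node in gm.graph.nodes:
    print(node.op, node.target)
# the get_attr node for freqs_cis is args[0] of the two getitem nodes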
@@ -43,7 +46,10 @@ def generic_evaluator(
    _LOGGER.debug(
        f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
    )
-   return target(*args)
+   from torch._subclasses.fake_tensor import unset_fake_temporarily
+
+   with unset_fake_temporarily():
+       return target(*args)


def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool:
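For context on unset_fake_temporarily (a minimal sketch, independent of the PR): under FakeTensorMode, factory calls produce metadata-only fake tensors, and the context manager temporarily suspends that mode so an op can be evaluated on real values, which is presumably why generic_evaluator wraps the target call with it.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.empty(2, 2)      # metadata-only FakeTensor, no real storage
    with unset_fake_temporarily():
        real = torch.arange(4.0)  # fake mode is suspended here: a real tensor with data

print(type(fake).__name__)  # FakeTensor
print(real)                 # tensor([0., 1., 2., 3.])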
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
from ._aten_lowering_pass import * | ||
from ._modify_reshape_complex_nodes import modify_reshape_complex_nodes | ||
from .remove_sym_nodes import remove_sym_nodes | ||
from .repair_input_aliasing import repair_input_aliasing |