adding rotary embedding example, with graph rewrite for complex subgraph #3570


Merged: 6 commits, Jul 3, 2025
106 changes: 106 additions & 0 deletions examples/distributed_inference/rotary_embedding.py
@@ -0,0 +1,106 @@
import time

import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

"""
This example covers the rotary embedding and rotary attention case for tensor parallel
"""


def precompute_freqs_cis(
dim: int, end: int, theta: float = 10000.0, n_parallel=1
) -> torch.Tensor:
"""Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
end (int): End index for precomputing frequencies.
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
n_parallel (int, optional): Number of GPUs for parallel computation. Defaults to 1.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim // n_parallel, 2).float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
return torch.polar(torch.ones_like(freqs), freqs)


def rotary_embedding(xq, xk, dim, freqs_cis=None):
"""This calculates the rotary embedding for the query and key tensors.
Args:
xq (torch.Tensor): Query tensor.
xk (torch.Tensor): Key tensor.
dim (int): Dimension of the query and key tensors.
freqs_cis (torch.Tensor, optional): Precomputed frequency tensor. Defaults to None.
Returns:
tuple: Tuple containing the rotated query and key tensors.
"""

xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return (xq_out.type_as(xq), xk_out.type_as(xk))


########Tensor Parallel########
def parallel_rotary_block(rotary_block, tp_mesh):
"""Parallel rotary block for tensor parallel
Args:
rotary_block: Rotary block to parallelize
tp_mesh: Tensor parallel mesh
"""
if tp_mesh.size() <= 1:
return

plan = {
"wq": ColwiseParallel(),
"wk": ColwiseParallel(),
"wo": RowwiseParallel(output_layouts=Shard(0)),
}
rotary_block.n_parallel = 1  # TODO: remove this hardcode; this is the single-GPU setting

parallelize_module(rotary_block, tp_mesh, plan)


class RotaryAttention(nn.Module):
def __init__(self, dim: int, seq_len: int):
super().__init__()
self.dim = dim
self.wq = nn.Linear(dim, dim)
self.wk = nn.Linear(dim, dim)
self.wo = nn.Linear(dim, dim)
self.seq_len = seq_len
self.n_parallel = 1
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
self.init_weights()

def _precompute_freqs_cis(self) -> torch.Tensor:
theta = 10000.0
return precompute_freqs_cis(self.dim, self.seq_len, theta, self.n_parallel)

def init_weights(self):
with torch.device(self.freqs_cis.device):
self.freqs_cis = self.freqs_cis

def forward(self, x):
q = self.wq(x)
k = self.wk(x)
freqs_cis = self._precompute_freqs_cis().to(q.device)
q, k = rotary_embedding(q, k, self.dim, freqs_cis=freqs_cis)
return self.wo(q)
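A quick sanity check of precompute_freqs_cis above (a usage sketch, not part of the PR; it assumes you run from examples/distributed_inference so the module import resolves): each of the end positions gets dim // 2 complex rotation factors in complex64.

```python
# hypothetical usage sketch, not part of the PR
from rotary_embedding import precompute_freqs_cis

freqs_cis = precompute_freqs_cis(dim=128, end=128)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([128, 64]) torch.complex64
```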
examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -65,3 +65,9 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger


def cleanup_distributed_env():
"""Clean up distributed process group to prevent resource leaks."""
if dist.is_initialized():
dist.destroy_process_group()
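For context, a minimal sketch of how these two helpers are meant to be paired in the example scripts (it mirrors tensor_parallel_rotary_embedding.py below; the log-file name is illustrative and a multi-process launch environment is assumed):

```python
# hypothetical pairing sketch, not part of the PR
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

device_mesh, world_size, rank, logger = initialize_distributed_env("./example_log")
try:
    pass  # build, parallelize, compile, and run the model here
finally:
    # destroy the process group so every rank exits cleanly
    cleanup_distributed_env()
```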
58 changes: 58 additions & 0 deletions examples/distributed_inference/tensor_parallel_rotary_embedding.py
@@ -0,0 +1,58 @@
import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_rotary_embedding"
)


"""
This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
"""

BATCH = 2
SEQ_LEN = 128
HEADS = 4
DIM = 128

with torch.no_grad():
model = RotaryAttention(DIM, SEQ_LEN)
parallel_rotary_block(model, device_mesh)
device = torch.device("cuda", device_mesh.get_rank())
model.to(device)
x = torch.randn(BATCH, SEQ_LEN, HEADS, DIM).to(device)

python_result = model(x)

logger.info("Torch-tensorrt compilation for rotary embedding")

model = torch.compile(model, backend="torch_tensorrt", options={"debug": True})
Collaborator:
remove debug=True as it is deprecated
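A minimal sketch of that suggestion, simply dropping the deprecated option (an assumed form, not necessarily what the PR ends up using):

```python
model = torch.compile(model, backend="torch_tensorrt")
```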


try:
for i in range(15):
# seeding with dp_rank to ensure identical inputs for TP groups
torch.manual_seed(i)
start = time.time()
output = model(x)
end = time.time()
if i == 0:
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
logger.info(f"Inference time is {end-start}")
except Exception as e:
logger.error(f"Error: {e}")
raise e
finally:
cleanup_distributed_env()
Collaborator:
Do we need a try/except and finally here for the example? And why are we looping? If you want to display any results of this block, please use the right formatting: https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/torch_export_sam2.py#L282-L297

Collaborator (author):
The try/except block is there so we can measure the inference-time improvement over a number of iterations after graph compilation, i.e. confirm the performance gain once there are no graph breaks. I am not clear about the formatting part pointed out in the link above. Did you point it out because the try block won't be rendered? The rendering looks correct to me.

I could also just remove the loop.

Collaborator:
Why do we need to loop here? Is the goal of this example to compare the output of the rotary embedding block between TRT and the PyTorch model?

@@ -5,7 +5,10 @@
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -97,5 +100,4 @@ def forward(self, x):
logger.info(f"Inference time is {end-start}")
finally:
# This cleans up the distributed process group
if dist.is_initialized():
dist.destroy_process_group()
cleanup_distributed_env()
10 changes: 8 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -23,7 +23,10 @@ def getitem_validator(getitem_node: Node, settings: CompilationSettings = None)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

# Getitem nodes can only be converted if their parent node also can
return getitem_node.args[0] in DYNAMO_CONVERTERS
return (
getitem_node.args[0] in DYNAMO_CONVERTERS
or getitem_node.args[0].op == "get_attr"
Collaborator:
Why is this needed?

Collaborator (author):
This is needed because the complex tensor is currently a buffer, from which we extract the real and imaginary parts via input[..., 0] and input[..., 1]. This produces a getitem node whose argument is a get_attr node.
If I don't include the change above, this part causes a graph break and falls back to eager execution on the GPU, which is unnecessary if we support get_attr.

graph():
   %_frozen_param3_reshaped : [num_users=2] = get_attr[target=_frozen_param3_reshaped]
   %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_frozen_param3_reshaped, (Ellipsis, 0)), kwargs = {})
   %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%_frozen_param3_reshaped, (Ellipsis, 1)), kwargs = {})
   return (getitem, getitem_1)

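A minimal, hypothetical sketch of that pattern built by hand with torch.fx (the buffer name is illustrative, not the frozen parameter from the PR):

```python
import operator

import torch
from torch import fx

# Build a tiny FX graph that mirrors the pattern above: a get_attr node
# (a complex buffer stored as [..., 2] real/imag pairs) feeding two
# operator.getitem calls that split out the real and imaginary parts.
graph = fx.Graph()
freqs = graph.get_attr("freqs_cis_reshaped")
real = graph.call_function(operator.getitem, (freqs, (Ellipsis, 0)))
imag = graph.call_function(operator.getitem, (freqs, (Ellipsis, 1)))
graph.output((real, imag))


class Holder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("freqs_cis_reshaped", torch.randn(128, 64, 2))


gm = fx.GraphModule(Holder(), graph)
print(gm.graph)       # get_attr -> call_function[target=operator.getitem] nodes
print(gm()[0].shape)  # torch.Size([128, 64]), the real component
```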
Collaborator:
Okay. What does getitem_node.args[0] return? Does it return the op name?

Collaborator (author):
getitem_node.args[0] returns the node, which is the key in the ConverterRegistry:

```python
def __getitem__(
    self, node: Node
)
```

with the value being the converter, its calling convention, and a dictionary containing supports_dynamic_shapes and requires_output_allocator.

)


# TODO: Subsequent evaluators should be registered here with their own validators
@@ -43,7 +46,10 @@ def generic_evaluator(
_LOGGER.debug(
f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
)
return target(*args)
from torch._subclasses.fake_tensor import unset_fake_temporarily

with unset_fake_temporarily():
return target(*args)


def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool:
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -1,4 +1,3 @@
from ._aten_lowering_pass import *
from ._modify_reshape_complex_nodes import modify_reshape_complex_nodes
from .remove_sym_nodes import remove_sym_nodes
from .repair_input_aliasing import repair_input_aliasing
py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -6,6 +6,7 @@
from torch_tensorrt.dynamo.utils import is_tegra_platform

from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .complex_graph_rewrite import complex_graph_detection
from .constant_folding import constant_fold
from .fuse_distributed_ops import fuse_distributed_ops
from .fuse_prims_broadcast import fuse_prims_broadcast
@@ -26,6 +27,7 @@
remove_assert_nodes,
accumulate_fp32_matmul,
remove_num_users_is_0_nodes,
complex_graph_detection,
]

pre_lowering_pass_list = [