
Commit 1d597b3

adding rotary embedding example, with graph rewrite for complex subgraph (#3570)
1 parent 85637b9 commit 1d597b3

File tree: 10 files changed (+692, -26 lines)
examples/distributed_inference/rotary_embedding.py

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@

```python
"""
.. _rotary_embedding:

Rotary Embedding Implementation for Tensor Parallel Attention
==============================================================

This module provides an implementation of rotary positional embeddings (RoPE) for transformer models
with support for tensor parallel distributed inference. Rotary embeddings are used to encode positional
information in transformer attention mechanisms.
"""

import time

import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

"""
This example covers the rotary embedding and rotary attention case for tensor parallelism.
"""


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, n_parallel=1
) -> torch.Tensor:
    """Precompute the frequency tensor for complex exponentials (cis) with the given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        n_parallel (int, optional): Number of GPUs for parallel computation. Defaults to 1.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim // n_parallel, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)


def rotary_embedding(xq, xk, dim, freqs_cis=None):
    """Calculate the rotary embedding for the query and key tensors.

    Args:
        xq (torch.Tensor): Query tensor.
        xk (torch.Tensor): Key tensor.
        dim (int): Dimension of the query and key tensors.
        freqs_cis (torch.Tensor, optional): Precomputed frequency tensor. Defaults to None.

    Returns:
        tuple: Tuple containing the rotated query and key tensors.
    """
    freqs_cis = freqs_cis[None, :, None, :]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return (xq_out.type_as(xq), xk_out.type_as(xk))


######## Tensor Parallel ########
def parallel_rotary_block(rotary_block, tp_mesh):
    """Parallelize the rotary block for tensor parallelism.

    Args:
        rotary_block: Rotary block to parallelize
        tp_mesh: Tensor parallel device mesh
    """
    if tp_mesh.size() <= 1:
        return

    plan = {
        "wq": ColwiseParallel(),
        "wk": ColwiseParallel(),
        "wo": RowwiseParallel(output_layouts=Shard(0)),
    }
    rotary_block.n_parallel = 1  # this is for a single GPU; TODO: remove this hardcode

    parallelize_module(rotary_block, tp_mesh, plan)


class RotaryAttention(nn.Module):
    def __init__(self, dim: int, seq_len: int):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)
        self.seq_len = seq_len
        self.n_parallel = 1
        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
        self.init_weights()

    def _precompute_freqs_cis(self) -> torch.Tensor:
        theta = 10000.0
        return precompute_freqs_cis(self.dim, self.seq_len, theta, self.n_parallel)

    def init_weights(self):
        with torch.device(self.freqs_cis.device):
            self.freqs_cis = self.freqs_cis

    def forward(self, x):
        q = self.wq(x)
        k = self.wk(x)
        freqs_cis = self._precompute_freqs_cis().to(q.device)
        q, k = rotary_embedding(q, k, self.dim, freqs_cis=freqs_cis)
        return self.wo(q)
```
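The complex-number formulation above (`torch.polar`, `torch.view_as_complex`, `torch.view_as_real`) is exactly the subgraph that the commit's new complex-graph rewrite has to deal with during lowering. As a quick shape sanity check, here is a minimal sketch (not part of the commit) that exercises `precompute_freqs_cis` and `rotary_embedding` on dummy tensors; the batch/sequence/head sizes are illustrative only:

```python
import torch

from rotary_embedding import precompute_freqs_cis, rotary_embedding

# Illustrative sizes; any values with an even head dimension would work.
BATCH, SEQ_LEN, HEADS, HEAD_DIM = 2, 128, 4, 128

# freqs_cis has shape (SEQ_LEN, HEAD_DIM // 2) and dtype complex64.
freqs_cis = precompute_freqs_cis(HEAD_DIM, SEQ_LEN)

xq = torch.randn(BATCH, SEQ_LEN, HEADS, HEAD_DIM)
xk = torch.randn(BATCH, SEQ_LEN, HEADS, HEAD_DIM)

# The rotated outputs keep the input shape and dtype.
q_rot, k_rot = rotary_embedding(xq, xk, HEAD_DIM, freqs_cis=freqs_cis)
assert q_rot.shape == xq.shape and k_rot.shape == xk.shape
```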

examples/distributed_inference/tensor_parallel_initialize_dist.py

Lines changed: 14 additions & 0 deletions

```diff
@@ -1,3 +1,11 @@
+"""
+.. _tensor_parallel_initialize_dist:
+Tensor Parallel Initialize Distributed Environment
+==================================================
+
+This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
+"""
+
 import logging
 import os
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -65,3 +73,9 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
     torch.cuda.set_device(device_id)

     return device_mesh, world_size, rank, logger
+
+
+def cleanup_distributed_env():
+    """Clean up distributed process group to prevent resource leaks."""
+    if dist.is_initialized():
+        dist.destroy_process_group()
```
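The new `cleanup_distributed_env` helper pairs with `initialize_distributed_env`. A minimal sketch of the intended call pattern (the logger file name and the try/finally wrapper are my additions, not part of the commit):

```python
# Run under mpirun/torchrun; the logger file name is illustrative.
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

device_mesh, world_size, rank, logger = initialize_distributed_env("./my_tp_example")
try:
    logger.info(f"Initialized TP mesh with {world_size} ranks; this is rank {rank}")
    # ... build, shard, compile, and run the model here ...
finally:
    cleanup_distributed_env()  # destroys the process group even if inference raises
```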
examples/distributed_inference/tensor_parallel_rotary_embedding.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@

```python
"""
.. _tensor_parallel_rotary_embedding:
Tensor Parallel Rotary Embedding Example
========================================

This example demonstrates how to use Torch-TensorRT with tensor parallel distributed inference
for models that use rotary positional embeddings (RoPE). It lowers the complex-valued
operations in rotary-embedding attention blocks and shards the model across multiple GPUs.

"""

import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_rotary_embedding"
)


"""
This example covers the rotary embedding in the Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
Command to run with a single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
"""

BATCH = 2
SEQ_LEN = 128
HEADS = 4
DIM = 128

with torch.no_grad():
    model = RotaryAttention(DIM, SEQ_LEN)
    parallel_rotary_block(model, device_mesh)
    device = torch.device("cuda", device_mesh.get_rank())
    model.to(device)
    x = torch.randn(BATCH, SEQ_LEN, HEADS, DIM).to(device)

    python_result = model(x)

    logger.info("Torch-TensorRT compilation for rotary embedding")

    model = torch.compile(model, backend="torch_tensorrt")

    torch.manual_seed(0)
    start = time.time()
    output = model(x)
    end = time.time()
    logger.info(f"Compilation time is {end - start}")
    assert (python_result - output).std() < 0.01, "Compilation result is not correct."

    cleanup_distributed_env()
```
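The timing above covers the first call through `torch.compile`, which includes the TensorRT engine build. If a steady-state inference number is also wanted, a hedged sketch (not part of the example) is to time a few further calls on the already-compiled model, keeping inputs identical across TP ranks the same way the simple example does:

```python
# Sketch: additional timed iterations after compilation (iteration count is illustrative).
# This would run inside the torch.no_grad() block, before cleanup_distributed_env().
for i in range(5):
    torch.manual_seed(i)  # same seed on every rank so all TP ranks see identical inputs
    inp = torch.randn(BATCH, SEQ_LEN, HEADS, DIM, device=device)
    start = time.time()
    out = model(inp)
    torch.cuda.synchronize()  # make the timing reflect actual GPU work
    end = time.time()
    logger.info(f"Inference time is {end - start}")
```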

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 38 additions & 22 deletions

```diff
@@ -1,11 +1,35 @@
+"""
+.. _tensor_parallel_simple_example:
+
+Torch Parallel Distributed example for a simple model
+======================================================
+
+The example below shows how to use the Torch-TensorRT backend for distributed inference with tensor parallelism.
+
+This example demonstrates:
+- Setting up a distributed environment for tensor parallelism
+- Model sharding across multiple GPUs
+- Compilation with Torch-TensorRT
+- Distributed inference execution
+
+Usage
+-----
+.. code-block:: bash
+
+    mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
+"""
+
 import time

 import tensorrt as trt
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch_tensorrt
-from tensor_parallel_initialize_dist import initialize_distributed_env
+from tensor_parallel_initialize_dist import (
+    cleanup_distributed_env,
+    initialize_distributed_env,
+)
 from torch.distributed._tensor import Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -18,7 +42,7 @@
 )

 """
-This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
+This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """


@@ -79,23 +103,15 @@ def forward(self, x):
     dynamic=None,
 )

-try:
-    for i in range(10):
-        # For TP, input needs to be same across all TP ranks.
-        # Setting the random seed is to mimic the behavior of dataloader.
-        torch.manual_seed(i)
-        inp = torch.rand(20, 10, device="cuda")
-        start = time.time()
-        output = tp_model(inp)
-        end = time.time()
-        if i == 0:
-            logger.info(f"Compilation time is {end-start}")
-            assert (
-                python_result - output
-            ).std() < 0.01, "Compilation result is not correct."
-        elif _rank == 0:
-            logger.info(f"Inference time is {end-start}")
-finally:
-    # This cleans up the distributed process group
-    if dist.is_initialized():
-        dist.destroy_process_group()
+# For TP, input needs to be same across all TP ranks.
+# Setting the random seed is to mimic the behavior of dataloader.
+torch.manual_seed(0)
+inp = torch.rand(20, 10, device="cuda")
+start = time.time()
+output = tp_model(inp)
+end = time.time()
+logger.info(f"Compilation time is {end - start}")
+assert (python_result - output).std() < 0.01, "Result is not correct."
+
+# This cleans up the distributed process group
+cleanup_distributed_env()
```
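The seeding comment matters because, with tensor parallelism, every rank feeds the same un-sharded activation into its shard of the weights; if ranks generated different inputs, the column- and row-wise shards would be combining results computed on inconsistent data. A small sketch of the idea (my illustration, standing in for a real dataloader):

```python
import torch

# Each TP rank executes this same code. Seeding the default RNG with a common value
# makes torch.rand return the same tensor on every rank, mimicking a dataloader that
# feeds all ranks the same batch; per-rank seeds would break the TP numerics.
torch.manual_seed(0)
inp = torch.rand(20, 10, device="cuda")
```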

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -23,7 +23,10 @@ def getitem_validator(getitem_node: Node, settings: CompilationSettings = None)
     from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

     # Getitem nodes can only be converted if their parent node also can
-    return getitem_node.args[0] in DYNAMO_CONVERTERS
+    return (
+        getitem_node.args[0] in DYNAMO_CONVERTERS
+        or getitem_node.args[0].op == "get_attr"
+    )


 # TODO: Subsequent evaluators should be registered here with their own validators
@@ -43,7 +46,10 @@ def generic_evaluator(
     _LOGGER.debug(
         f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
     )
-    return target(*args)
+    from torch._subclasses.fake_tensor import unset_fake_temporarily
+
+    with unset_fake_temporarily():
+        return target(*args)


 def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool:
```
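The second hunk wraps the evaluator call in `unset_fake_temporarily`, so the evaluated target produces a real tensor with concrete values even while the surrounding graph is being traced under fake-tensor mode. A minimal sketch of the behavior this escapes (my illustration, not from the commit):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.arange(4)      # FakeTensor: carries shape/dtype but no real data
    with unset_fake_temporarily():
        real = torch.arange(4)  # real tensor with concrete values

print(fake.__class__.__name__, real.__class__.__name__)  # FakeTensor Tensor
```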
py/torch_tensorrt/dynamo/lowering/passes/__init__.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -1,4 +1,3 @@
 from ._aten_lowering_pass import *
-from ._modify_reshape_complex_nodes import modify_reshape_complex_nodes
 from .remove_sym_nodes import remove_sym_nodes
 from .repair_input_aliasing import repair_input_aliasing
```

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -6,6 +6,7 @@
 from torch_tensorrt.dynamo.utils import is_tegra_platform

 from .accumulate_fp32_matmul import accumulate_fp32_matmul
+from .complex_graph_rewrite import complex_graph_detection
 from .constant_folding import constant_fold
 from .fuse_distributed_ops import fuse_distributed_ops
 from .fuse_prims_broadcast import fuse_prims_broadcast
@@ -26,6 +27,7 @@
     remove_assert_nodes,
     accumulate_fp32_matmul,
     remove_num_users_is_0_nodes,
+    complex_graph_detection,
 ]

 pre_lowering_pass_list = [
```
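The newly registered `complex_graph_detection` pass is what lets the rotary example compile: per the commit title, it detects the complex-valued subgraph (`view_as_complex` → complex multiply → `view_as_real`) and rewrites it, presumably so the graph can be expressed with the real-valued datatypes TensorRT handles. Purely as an illustration of the underlying identity (this is not the pass's implementation), the same rotation can be written with real-valued ops only:

```python
import torch


def rotate_real(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Real-valued equivalent of view_as_real(view_as_complex(x) * freqs_cis).flatten(-2).

    x:         (..., dim) where even/odd positions hold the real/imag parts of each pair.
    freqs_cis: complex tensor broadcastable to x's leading dims, last dim = dim // 2.
    """
    x_r, x_i = x[..., 0::2], x[..., 1::2]      # de-interleave (real, imag) pairs
    cos, sin = freqs_cis.real, freqs_cis.imag  # rotation angles per position
    out_r = x_r * cos - x_i * sin              # real part of the complex product
    out_i = x_r * sin + x_i * cos              # imaginary part of the complex product
    return torch.stack((out_r, out_i), dim=-1).flatten(-2)  # re-interleave


# Quick numerical check against the complex formulation used in rotary_embedding.
x = torch.randn(2, 8, 4, 16)
fc = torch.polar(torch.ones(8, 8), torch.randn(8, 8))[None, :, None, :]
ref = torch.view_as_real(torch.view_as_complex(x.reshape(2, 8, 4, 8, 2)) * fc).flatten(3)
assert torch.allclose(rotate_real(x, fc), ref, atol=1e-5)
```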
