Commit 640e96d

revert linear converter
1 parent 962fb48 commit 640e96d

7 files changed: +141 −10 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 0 deletions

@@ -3579,3 +3579,23 @@ def aten_ops_nonzero(
         name,
         args[0],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.linear, supports_dynamic_shapes=True)
+def aten_ops_linear(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.linear.linear(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args_bounds_check(args, 2, None),
+    )
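
Note: a minimal usage sketch of the restored conversion path (assuming a CUDA build of torch_tensorrt; the module and shapes are illustrative). Once the converter is registered, aten.linear nodes produced by nn.Linear are handed to aten_ops_linear instead of being decomposed first:

import torch
import torch_tensorrt

class TinyMLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(64, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)  # exports as torch.ops.aten.linear.default

model = TinyMLP().eval().cuda()
x = torch.randn(8, 64).cuda()

# The dynamo frontend now routes the aten.linear node through aten_ops_linear.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])
print(trt_model(x).shape)  # torch.Size([8, 32])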

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -31,4 +31,5 @@
     unary,
     unsqueeze,
     upsample,
+    linear,
 )
py/torch_tensorrt/dynamo/conversion/impl/linear.py (new file)

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def linear(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    weight: Union[TRTTensor, torch.Tensor, np.ndarray],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+) -> TRTTensor:
+    # Process weight terms
+    if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
+        raise RuntimeError(
+            f"Linear layer {name} has weight of type {type(weight)}, expected Union[TRTTensor, torch.Tensor, np.ndarray]"
+        )
+    elif isinstance(weight, (torch.Tensor, np.ndarray)):
+        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+
+    # Process bias terms
+    if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
+        raise RuntimeError(
+            f"Linear layer {name} has bias of type {type(bias)}, expected Union[TRTTensor, torch.Tensor, np.ndarray]"
+        )
+    elif isinstance(bias, (torch.Tensor, np.ndarray)):
+        bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+
+    # add IMatrixMultiplyLayer
+    out = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name,
+        input,
+        weight,
+        input_matrix_op=trt.MatrixOperation.NONE,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+
+    if bias is not None:
+        # add bias
+        out = impl.elementwise.add(ctx, target, source_ir, name, out, bias)
+
+    return out
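
Note: a plain-PyTorch sanity check (a sketch, independent of TensorRT) of the identity this converter relies on. aten.linear stores the weight as (out_features, in_features), so linear(x, W, b) == x @ W.T + b, which is exactly the matrix multiply with other_matrix_op=trt.MatrixOperation.TRANSPOSE followed by the elementwise bias add above:

import torch

x = torch.randn(8, 64)
W = torch.randn(32, 64)  # (out_features, in_features), as aten.linear expects
b = torch.randn(32)

ref = torch.nn.functional.linear(x, W, b)
manual = x @ W.transpose(0, 1) + b  # matmul with transposed weight, then bias add
print(torch.allclose(ref, manual, atol=1e-5))  # True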

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 1 addition & 0 deletions

@@ -171,6 +171,7 @@
     aten.upsample_bilinear2d.vec,
     aten.upsample_trilinear3d.vec,
     aten.upsample_bicubic2d.vec,
+    aten.linear.default,
 }
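
Note: assuming this group is the set of ops excluded from decomposition, listing aten.linear.default keeps linear nodes intact so the converter above can claim them. A sketch of what the exported graph then contains (the module and shapes are illustrative):

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

ep = torch.export.export(M(), (torch.randn(3, 4),))
# Before decompositions run, the graph holds a single aten.linear.default node;
# keeping it out of the decomposition table preserves it for the TRT converter.
print(ep.graph)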

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
 from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .lower_linear import lower_linear
 
 post_lowering_pass_list = [
     remove_input_alias_fixing_clones,
@@ -28,6 +29,7 @@
     accumulate_fp32_matmul,
     remove_num_users_is_0_nodes,
     complex_graph_detection,
+    lower_linear,
 ]
 
 pre_lowering_pass_list = [
py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py (new file)

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+from torch_tensorrt.dynamo.utils import get_metadata, set_metadata
+
+logger = logging.getLogger(__name__)
+
+
+def lower_linear(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Replace decomposed (permute + addmm) patterns with aten.linear, which can be directly converted to TRT"""
+    orig_op = torch.ops.aten.addmm.default
+    replacement_op = torch.ops.aten.linear.default
+
+    # Original graph
+    def orig(
+        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        W_T = torch.ops.aten.permute.default(weight, [1, 0])
+        out = orig_op(bias, input, W_T)
+        return out
+
+    # Replacement graph
+    def replacement(
+        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        return replacement_op(input, weight, bias)
+
+    metadata = get_metadata(gm, orig_op)
+    replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement)
+
+    if len(replaced_nodes) > 0:
+        gm = clean_up_graph_after_modifications(gm)
+        set_metadata(gm, replacement_op, metadata)
+        logger.debug(f"Graph after lowering linear:\n{gm.graph}")
+
+    return gm
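
Note: a standalone sketch of the rewrite this pass performs, using torch.fx.subgraph_rewriter directly (the metadata and graph-cleanup helpers from pass_utils are omitted; the toy module is illustrative):

import torch
import torch.fx
from torch.fx import subgraph_rewriter

def pattern(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # The decomposed form: transpose the weight, then addmm
    W_T = torch.ops.aten.permute.default(weight, [1, 0])
    return torch.ops.aten.addmm.default(bias, input, W_T)

def replacement(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    return torch.ops.aten.linear.default(input, weight, bias)

class Decomposed(torch.nn.Module):
    def forward(self, x, w, b):
        w_t = torch.ops.aten.permute.default(w, [1, 0])
        return torch.ops.aten.addmm.default(b, x, w_t)

gm = torch.fx.symbolic_trace(Decomposed())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(len(matches))  # 1: the permute + addmm pair was matched
print(gm.graph)      # now contains a single aten.linear.default call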

tools/perf/perf_run.py

Lines changed: 21 additions & 10 deletions

@@ -282,7 +282,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         cache_built_engines=params.get("cache_built_engines", False),
         reuse_cached_engines=params.get("reuse_cached_engines", False),
         use_python_runtime=params.get("use_python_runtime", False),
-        optimization_level=params.get("optimization_level", 5),
+        optimization_level=params.get("optimization_level", 3),
     )
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
@@ -441,21 +441,26 @@ def run_tensorrt(
     if params["is_trt_engine"]:
         serialized_engine = model
     else:
-        # Export an ONNX model and convert to TRT
-        torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
+        if params["onnx"]:
+            onnx_path = params["onnx"]
+        else:
+            # Export an ONNX model and convert to TRT
+            onnx_path = "./onnx-trt.onnx"
+            exp_program = torch.export.export(model.eval().cuda(), tuple(input_tensors))
+            torch.onnx.export(exp_program, tuple(input_tensors), onnx_path)
         builder = trt.Builder(logger)
         network = builder.create_network(
             1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         )
         parser = trt.OnnxParser(network, logger)
-        success = parser.parse_from_file("./tmp.onnx")
+        success = parser.parse_from_file(onnx_path)
         if not success:
             raise ValueError("ONNX conversion failed")
 
     config = builder.create_builder_config()
     if precision == "fp16":
         config.set_flag(trt.BuilderFlag.FP16)
-    config.builder_optimization_level = params.get("optimization_level", 5)
+    config.builder_optimization_level = params.get("optimization_level", 3)
     start_compile = timeit.default_timer()
     serialized_engine = builder.build_serialized_network(network, config)
     end_compile = timeit.default_timer()
@@ -561,7 +566,7 @@ def run(
         print("int8 precision expects calibration cache file for inference")
         return False
 
-    if (model is None) and (backend in ("tensorrt", "ts_trt", "all")):
+    if (model is None) and (backend in ("ts_trt", "all")):
         warnings.warn(
             f"Requested backend {backend} without specifying a TorchScript Model, "
             + "skipping this backend"
@@ -585,7 +590,7 @@
             batch_size,
         )
         run_tensorrt(
-            model,
+            model_torch,
             input_tensors,
             params,
             precision,
@@ -606,7 +611,7 @@
         )
     elif backend == "tensorrt":
         run_tensorrt(
-            model,
+            model_torch,
             input_tensors,
             params,
             precision,
@@ -641,6 +646,12 @@
         default="",
         help="Name of torch model file",
     )
+    arg_parser.add_argument(
+        "--onnx",
+        type=str,
+        default="",
+        help="Path to an ONNX model file, which bypasses the ONNX export step; if provided, the ONNX model is converted directly to a TRT engine",
+    )
     arg_parser.add_argument(
         "--inputs",
         type=str,
@@ -683,7 +694,7 @@
     arg_parser.add_argument(
         "--optimization_level",
         type=int,
-        default=5,
+        default=3,
         help="Builder optimization level for TensorRT",
     )
     arg_parser.add_argument(
@@ -767,7 +778,7 @@
     )
 
     backends = parse_backends(params["backends"])
-    if ("dynamo" in backends or "torch_compile" in backends) and (model_torch is None):
+    if any(backend in ["dynamo", "torch_compile", "tensorrt"] for backend in backends) and (model_torch is None):
         raise ValueError(
             "No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument"
         )
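
Note: a hypothetical invocation exercising the new flag (paths and input shapes are illustrative; all flags shown appear elsewhere in the script). With --onnx, the script parses the given file directly instead of exporting one, and the tensorrt backend now takes its model from --model_torch:

python tools/perf/perf_run.py \
    --backends tensorrt \
    --model_torch ./model.pt \
    --onnx ./model.onnx \
    --inputs "(1, 3, 224, 224)" \
    --optimization_level 3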
