Skip to content

Commit 647793c

Browse files
committed
Added warning and changed the test cases
1 parent 68ca2c8 commit 647793c

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from torch._dynamo.backends.common import aot_autograd
1111
from torch._dynamo.utils import detect_fake_mode
1212
from torch._functorch.aot_autograd import aot_export_joint_simple
13-
from torch.distributed.tensor import DTensor
1413
from torch_tensorrt.dynamo import CompilationSettings
1514
from torch_tensorrt.dynamo._compiler import compile_module
1615
from torch_tensorrt.dynamo.lowering import (
@@ -89,6 +88,11 @@ def aot_torch_tensorrt_aten_backend(
8988
logger.warning(
9089
"It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple"
9190
)
91+
92+
if settings.offload_module_to_cpu:
93+
logger.warning(
94+
"`offload_module_to_cpu` is not supported for `torch_compile` backend."
95+
)
9296
return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
9397

9498

tests/py/dynamo/models/test_models.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,12 @@ def test_resnet18_cpu_offload(ir):
7979
}
8080

8181
trt_mod = torchtrt.compile(model, **compile_spec)
82-
assertions.assertTrue(
83-
get_model_device(model).type == "cpu",
84-
msg="Model should be offloaded to CPU",
85-
)
86-
model.cuda()
82+
if ir == "dynamo":
83+
assertions.assertTrue(
84+
get_model_device(model).type == "cpu",
85+
msg="Model should be offloaded to CPU",
86+
)
87+
model.cuda()
8788
cos_sim = cosine_similarity(model(input), trt_mod(input))
8889
assertions.assertTrue(
8990
cos_sim > COSINE_THRESHOLD,
@@ -286,11 +287,12 @@ def test_bert_base_uncased_cpu_offload(ir):
286287
"offload_module_to_cpu": True,
287288
}
288289
trt_mod = torchtrt.compile(model, **compile_spec)
289-
assertions.assertTrue(
290-
get_model_device(model).type == "cpu",
291-
msg="Model should be offloaded to CPU",
292-
)
293-
model.cuda()
290+
if ir == "dynamo":
291+
assertions.assertTrue(
292+
get_model_device(model).type == "cpu",
293+
msg="Model should be offloaded to CPU",
294+
)
295+
model.cuda()
294296

295297
model_outputs = model(input, input2)
296298
trt_model_outputs = trt_mod(input, input2)

0 commit comments

Comments (0)