Skip to content

Commit 5b3e714

Browse files
committed
Added warning and changed the test cases
1 parent 75dc7bb commit 5b3e714

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from torch._dynamo.backends.common import aot_autograd
1111
from torch._dynamo.utils import detect_fake_mode
1212
from torch._functorch.aot_autograd import aot_export_joint_simple
13-
from torch.distributed.tensor import DTensor
1413
from torch_tensorrt.dynamo import CompilationSettings
1514
from torch_tensorrt.dynamo._compiler import compile_module
1615
from torch_tensorrt.dynamo.lowering import (
@@ -89,6 +88,11 @@ def aot_torch_tensorrt_aten_backend(
8988
logger.warning(
9089
"It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple"
9190
)
91+
92+
if settings.offload_module_to_cpu:
93+
logger.warning(
94+
"`offload_module_to_cpu` is not supported for `torch_compile` backend."
95+
)
9296
return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
9397

9498

tests/py/dynamo/models/test_models.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,12 @@ def test_resnet18_cpu_offload(ir):
7979
}
8080

8181
trt_mod = torchtrt.compile(model, **compile_spec)
82-
assertions.assertTrue(
83-
get_model_device(model).type == "cpu",
84-
msg="Model should be offloaded to CPU",
85-
)
86-
model.cuda()
82+
if ir == "dynamo":
83+
assertions.assertTrue(
84+
get_model_device(model).type == "cpu",
85+
msg="Model should be offloaded to CPU",
86+
)
87+
model.cuda()
8788
cos_sim = cosine_similarity(model(input), trt_mod(input))
8889
assertions.assertTrue(
8990
cos_sim > COSINE_THRESHOLD,
@@ -287,11 +288,12 @@ def test_bert_base_uncased_cpu_offload(ir):
287288
"offload_module_to_cpu": True,
288289
}
289290
trt_mod = torchtrt.compile(model, **compile_spec)
290-
assertions.assertTrue(
291-
get_model_device(model).type == "cpu",
292-
msg="Model should be offloaded to CPU",
293-
)
294-
model.cuda()
291+
if ir == "dynamo":
292+
assertions.assertTrue(
293+
get_model_device(model).type == "cpu",
294+
msg="Model should be offloaded to CPU",
295+
)
296+
model.cuda()
295297

296298
model_outputs = model(input, input2)
297299
trt_model_outputs = trt_mod(input, input2)

0 commit comments

Comments (0)