Skip to content

Commit 150232d

Browse files
junjiang-labcopybara-github
authored andcommitted
Add impl for aten.lt.Tensor and tfl.less.
PiperOrigin-RevId: 738084860
1 parent 4d3d6d3 commit 150232d

File tree

3 files changed

+12
-0
lines changed

3 files changed

+12
-0
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,8 @@ def _aten_div_tensor_decomp(x, y):
7373
@register_decomp(torch.ops.aten.gt.Tensor)
7474
def _aten_gt_tensor_decomp(x, y):
7575
return torch.ops.tfl.greater(x, y)
76+
77+
78+
@register_decomp(torch.ops.aten.lt.Tensor)
79+
def _aten_lt_tensor_decomp(x, y):
80+
return torch.ops.tfl.less(x, y)

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def tfl_greater(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5959
return torch.gt(x, y)
6060

6161

62+
@custom_op_with_fake("tfl::less")
63+
def tfl_less(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
64+
return torch.lt(x, y)
65+
66+
6267
@custom_op_with_fake("tfl::slice")
6368
def tfl_slice(
6469
input: torch.Tensor, begin: Sequence[int], size: Sequence[int]

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def _assert_export_and_close(
9090
("aten_div_Tensor_1", torch.ops.aten.div.Tensor, (rnd(torch.float32, (1, 10)), rnd(torch.float32, (10, 1)),), dict()),
9191
("aten_gt_Tensor_0", torch.ops.aten.gt.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
9292
("aten_gt_Tensor_1", torch.ops.aten.gt.Tensor, (rnd(torch.float32, (1, 10)), rnd(torch.float32, (10, 1)),), dict()),
93+
("aten_lt_Tensor_0", torch.ops.aten.lt.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
94+
("aten_lt_Tensor_1", torch.ops.aten.lt.Tensor, (rnd(torch.float32, (1, 10)), rnd(torch.float32, (10, 1)),), dict()),
9395
# fmt: on
9496
# pyformat: enable
9597
)

0 commit comments

Comments
 (0)