Commits (23)
c7ecb1e  Add torch2.9 in regression tests (jainapurva, Nov 7, 2025)
e9f94ba  Update torch version to 2.9.1 in regression tests (jainapurva, Nov 12, 2025)
886f0a6  Update torch version from 2.7.0 to 2.7.1 (jainapurva, Nov 12, 2025)
1a9a13f  Move dyn_int8_act_int4_wei_cpu_layout to prototype/dtypes (#3299) (jainapurva, Nov 7, 2025)
677ed0c  Skip quantization when channels_out / channels_in are not multiple of… (jerryzh168, Nov 7, 2025)
02ecbb7  [mxfp8 moe training][BE] add docs showing equivalent convergence to b… (danielvegamyhre, Nov 8, 2025)
e4ecec0  Move marlin_qqq_tensor to prototype/dtypes (#3307) (jainapurva, Nov 8, 2025)
865583b  Enable `PerRow(axis)` to support axes other than `-1` (#3303) (vkuzo, Nov 10, 2025)
2c10943  Remove old TORCH_VERSION variables (#3146) (andrewor14, Nov 10, 2025)
36e8d0b  Add per tensor fp8 conv2d support (#3315) (jerryzh168, Nov 10, 2025)
bab6ce5  Pin pytest==8.4.2 (#3321) (huydhn, Nov 11, 2025)
8bce9b1  Update common used toy linear model (#3275) (namgyu-youn, Nov 11, 2025)
4a102c2  Use conda libgcc-ng 11.2 (#3327) (atalman, Nov 11, 2025)
5c3e652  Move gemlite layout to prototype/dtypes (#3313) (jainapurva, Nov 12, 2025)
7213f81  Move uintx_layout to prototype/dtypes (#3316) (jainapurva, Nov 12, 2025)
726607d  Add __str__ to FqnToConfig to make printing more readable (#3323) (jcaip, Nov 12, 2025)
42fc6bd  Add support for e2e benchmark for conv2d/conv3d (#3329) (jerryzh168, Nov 12, 2025)
8c37568  Move floatx_tensor_core_layout to prototype/dtypes (#3317) (jainapurva, Nov 12, 2025)
d7b537b  Use conda libgcc-ng 11.2 for nightly tests (#3326) (jainapurva, Nov 13, 2025)
9ba0a3f  Fix tests (jainapurva, Nov 13, 2025)
d2192b1  Merge origin/main into add_torch2.9_tests (jainapurva, Nov 13, 2025)
3884806  Add a condition to run only if torch 2.9 (jainapurva, Nov 14, 2025)
a543b2a  Update utils.py (jainapurva, Nov 14, 2025)
16 changes: 14 additions & 2 deletions .github/workflows/regression_test.yml
@@ -67,7 +67,7 @@ jobs:
             dev-requirements-overrides: ""
           - name: CUDA 2.7
             runs-on: linux.g5.12xlarge.nvidia.gpu
-            torch-spec: 'torch==2.7.0'
+            torch-spec: 'torch==2.7.1'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.6"
             dev-requirements-overrides: ""
@@ -77,6 +77,12 @@
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.6"
             dev-requirements-overrides: ""
+          - name: CUDA 2.9
+            runs-on: linux.g5.12xlarge.nvidia.gpu
+            torch-spec: 'torch==2.9.1'
+            gpu-arch-type: "cuda"
+            gpu-arch-version: "12.6"
+            dev-requirements-overrides: ""

           - name: CPU 2.6
             runs-on: linux.4xlarge
@@ -86,7 +92,7 @@
             dev-requirements-overrides: ""
           - name: CPU 2.7
             runs-on: linux.4xlarge
-            torch-spec: 'torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu'
+            torch-spec: 'torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu'
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
             dev-requirements-overrides: ""
@@ -96,6 +102,12 @@
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
             dev-requirements-overrides: ""
+          - name: CPU 2.9
+            runs-on: linux.4xlarge
+            torch-spec: 'torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu'
+            gpu-arch-type: "cpu"
+            gpu-arch-version: ""
+            dev-requirements-overrides: ""

     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
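As a sanity check for the new matrix entries, a job could verify that the pinned build was actually installed. The snippet below is only an illustrative sketch, not part of this PR; the EXPECTED_TORCH_VERSION variable is hypothetical and not defined by this workflow:

# Hypothetical check, not part of this PR: confirm the job picked up the torch
# build pinned by its matrix entry, e.g. 2.9.1 for the new CUDA/CPU 2.9 jobs.
import os

import torch

expected = os.environ.get("EXPECTED_TORCH_VERSION", "2.9.1")  # assumed variable
installed = torch.__version__.split("+")[0]  # drop a local suffix such as +cu126
assert installed == expected, f"expected torch {expected}, got torch {installed}"
print(f"torch {installed}, CUDA available: {torch.cuda.is_available()}")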
10 changes: 10 additions & 0 deletions torchao/quantization/pt2e/utils.py
@@ -859,6 +859,16 @@ def _get_aten_graph_module_for_pattern(
         ):
             aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]

+    if torch.__version__.startswith("2.9"):
+        # PyTorch 2.9 adds _guards_fn nodes to exported graphs.
+        # These nodes cause errors only on torch 2.9.0 and 2.9.1, so strip
+        # them from the pattern graph.
+        for node in list(aten_pattern.graph.nodes):  # type: ignore[union-attr]
+            if node.op == "call_module" and node.name == "_guards_fn":
+                aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]
+        # Also remove the _guards_fn submodule from the graph module if it exists
+        if hasattr(aten_pattern, "_guards_fn"):
+            delattr(aten_pattern, "_guards_fn")
+
     aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]
     aten_pattern.recompile()  # type: ignore[operator]
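For context, the same cleanup can be written as a standalone FX helper. This is only an illustrative sketch (the strip_guards_fn name is hypothetical, not part of this PR) mirroring what the patch does to the pattern graph inside _get_aten_graph_module_for_pattern:

import torch
from torch import fx


def strip_guards_fn(gm: fx.GraphModule) -> fx.GraphModule:
    # Hypothetical helper: drop the _guards_fn call_module node that
    # torch.export emits on torch 2.9.x, then clean up and recompile.
    if torch.__version__.startswith("2.9"):
        for node in list(gm.graph.nodes):
            if node.op == "call_module" and node.name == "_guards_fn":
                gm.graph.erase_node(node)
        # Drop the now-unused submodule attribute as well, as the patch does.
        if hasattr(gm, "_guards_fn"):
            delattr(gm, "_guards_fn")
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm

Erasing the node before eliminate_dead_code() and recompile() follows the same ordering as the patched function, so the regenerated forward() no longer calls the guard module.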