
@BoyuanFeng (Contributor) commented Nov 8, 2025

This PR adds inductor lite mode; it requires pytorch/pytorch#167115.

This feature guarantees bitwise equivalence with the aot_eager backend. Tested on both llama3-8b and DeepSeek V3.

Commands:

llama3-8b

NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --debug.seed=0 --debug.deterministic

DeepSeek V3

NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --debug.seed=0 --debug.deterministic

@BoyuanFeng requested a review from yiming0416 November 8, 2025 01:54
@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Nov 8, 2025
@BoyuanFeng marked this pull request as draft November 8, 2025 01:56
@yiming0416 (Contributor) commented:

Some comments:

- Run `pre-commit run --all-files` to fix lint.
- Maybe test on dsv3 as well?
- Guard the lite mode under some config.
- CI will currently fail because it uses pytorch-nightly, so your pytorch PR needs to land first.

@yiming0416 (Contributor) commented:

Also FYI, I have a WIP PR to configure the passes in the compiler: #2006

@BoyuanFeng (Contributor, Author) commented Nov 10, 2025

@yiming0416 @SherlockNoMad I verified bitwise equivalence with aot_eager on both llama3-8b and DeepSeek V3.

I will update this PR with a config after Yiming's #2006 lands.

@ezyang (Contributor) commented Nov 10, 2025

We also want to check this against SimpleFSDP.

@ezyang requested a review from ruisizhang123 November 10, 2025 14:59
@yiming0416 (Contributor) commented:

@BoyuanFeng To compare with SimpleFSDP, I think you can run the following:

NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name simple_fsdp.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --debug.seed=0 --debug.deterministic

@ruisizhang123 (Contributor) left a comment:

Curious:

1. Does this inductor lite mode incorporate the bucketing & overlapping passes for perf?
2. More generally: we have many graph passes for regional compile, AC, bucketing/overlapping, and lite inductor mode. How would we recommend users compose all of these passes together when using torch.compile?

@BoyuanFeng (Contributor, Author) commented Nov 11, 2025

cc @ruisizhang123

> 1. Does this inductor lite mode incorporate the bucketing & overlapping passes for perf?

By default, inductor lite mode turns off all graph passes. The purpose is to give users FULL control over which passes to use. Specifically, users can turn on just the bucketing & overlapping passes (e.g., enable them for post_grad). It should compose well.
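
A minimal sketch of that idea, assuming the `reorder_for_compute_comm_overlap` flag and the `post_grad_custom_post_pass` hook in `torch._inductor.config` are still honored when `mode="lite"` disables the built-in pass pipeline (the config names exist today; whether lite mode invokes them is the assumption here, and `my_bucketing_pass` is a hypothetical placeholder):

```python
import torch
import torch._inductor.config as inductor_config

# Assumption: lite mode only turns the built-in passes off by default, so an
# explicitly enabled pass still runs.
inductor_config.reorder_for_compute_comm_overlap = True  # comm/compute overlap

def my_bucketing_pass(graph: torch.fx.Graph) -> None:
    # Hypothetical user-supplied bucketing pass; a real one would rewrite the
    # post-grad fx graph in place. Kept as a no-op placeholder here.
    return None

# post_grad hook for running a custom pass
inductor_config.post_grad_custom_post_pass = my_bucketing_pass

model = torch.nn.Linear(8, 8)
opt_model = torch.compile(model, mode="lite")
print(opt_model(torch.randn(2, 8)).shape)
```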

> 2. More generally: we have many graph passes for regional compile, AC, bucketing/overlapping, and lite inductor mode. How would we recommend users compose all of these passes together when using torch.compile?

All of these are fx passes (except inductor lite). In the compiler toolkit, users can run as many fx passes as they want, and the last pass would be inductor lite, which does codegen and wraps the result with cudagraphs.
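
A rough sketch of that pipeline, where `tag_nodes` is a hypothetical stand-in for a real fx pass (bucketing, overlap scheduling, AC, ...) applied before the graph is handed to inductor lite:

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.cos(torch.sin(x))

def tag_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hypothetical fx pass: stamp every node with custom metadata. A real pass
    # would rewrite the graph here instead.
    for node in gm.graph.nodes:
        node.meta["custom_tag"] = True
    gm.recompile()
    return gm

gm = torch.fx.symbolic_trace(Toy())
for fx_pass in (tag_nodes,):      # compose as many fx passes as needed
    gm = fx_pass(gm)

# final step: inductor lite does codegen (and cudagraph wrapping on CUDA)
opt = torch.compile(gm, mode="lite")
print(opt(torch.randn(4)).shape)
```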

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Nov 12, 2025
This PR introduces inductor lite mode for opt-in optimizations and numeric correctness guarantees.

Different from the default mode, which applies all possible fusions, lite mode gives control back to the user and provides a guarantee on numeric correctness. Specifically, this mode:

- **Fallback by Default**: Fall back for ALL nodes by default, unless users explicitly mark nodes for inductor fusion.
- **Selective Decomposition**: Skip decomposition for all nodes except user-marked nodes.
- **Regional inductor compile**
- Skip dead code elimination
- Skip buffer reuse
- Skip reorder passes, such as reorder for peak memory, reorder for compute comm overlap, and reorder_for_reducing_graph_partitions.
- Skip all pre-grad, joint-graph, and post-grad passes.

## Example: Flex Attention

```python
import torch
import torch.fx.traceback as fx_traceback
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def _squared(score, b, h, m, n):
    return score * score

def mask_mod(b, h, q, k):
    return q >= 0

a, b = 12, 64
block_mask = create_block_mask(mask_mod, None, None, a * b, a * b, device="cuda")

def fn(x):
    x = torch.sin(x)
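    # only the annotated region below is compiled by Inductor; all other nodes fall back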
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
    return torch.cos(x)

x = torch.randn(1, 1, a * b, b, dtype=torch.bfloat16, device="cuda", requires_grad=True)

opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
opt_fn(x)
```

[code diff](https://www.internalfb.com/intern/diffing/?paste_number=2027441476)

[default mode tlp](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpYAzDxX/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs [lite mode tlp](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpnnuh1W/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)

## Numerics

Inductor lite mode is bitwise equivalent to the `aot_eager` backend on torchtitan llama3-8b and DeepSeek V3; see pytorch/torchtitan#2005.
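
As a rough illustration (not the torchtitan harness; the toy model here is an assumption), such a bitwise check can be expressed with `torch.equal`:

```python
import torch
import torch._dynamo

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
inp = torch.randn(4, 16)

out_ref = torch.compile(model, backend="aot_eager")(inp)

torch._dynamo.reset()  # drop the aot_eager artifact before recompiling
out_lite = torch.compile(model, mode="lite")(inp)

# torch.equal requires exact elementwise equality, not just allclose
assert torch.equal(out_ref, out_lite)
```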

close: #167012

Pull Request resolved: #167115
Approved by: https://github.com/ezyang
