# add inductor lite mode #2005

## Conversation
Some comments: maybe test on dsv3 as well, and guard the lite mode under some config. Currently CI will fail because it uses
Also FYI, I have this WIP PR to config the passes in the compiler: #2006
@yiming0416 @SherlockNoMad I verified bitwise equivalence with aot_eager on both llama3-8b and deepseek v3. I will update this PR with a config after Yiming's #2006 lands.
Also want to check this against SimpleFSDP.
@BoyuanFeng To compare with SimpleFSDP, I think you can run the following |
ruisizhang123 left a comment:
Curious: 1. Does this inductor lite mode incorporate the bucketing & overlapping passes for perf? 2. More generally: we have many graph passes for region compile, AC, bucketing/overlapping, and lite inductor mode. How would we recommend users compose all of those passes together when using torch.compile?
By default, inductor lite mode turns off all graph passes. The purpose is to give users FULL control over which passes to use. Specifically, users can turn on just the bucketing & overlapping pass (e.g., enable it for post_grad). It should compose well.
All of these are fx passes (except inductor lite). In the compiler toolkit, users can run as many fx passes as they want, and the last pass would be inductor lite, which does codegen and wraps the result with cudagraph.
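The pipeline shape described above (arbitrary user fx passes first, inductor lite last) can be sketched with plain `torch.fx`. This is a minimal illustration, not the toolkit's actual API: `fn` and the `tag_nodes` pass are hypothetical, and the final inductor-lite codegen step is only indicated by a comment.

```python
import torch
import torch.fx


def fn(x):
    return torch.cos(torch.sin(x))


# Trace the function into an fx GraphModule that passes can rewrite.
gm = torch.fx.symbolic_trace(fn)


def tag_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hypothetical user fx pass: attach metadata to every call_function node.
    # Real passes (bucketing, overlapping, reordering) would rewrite the graph here.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            node.meta["user_tag"] = "checked"
    gm.recompile()
    return gm


# Users can chain as many fx passes like this as they want.
gm = tag_nodes(gm)

# In the flow described above, the *last* step would hand this GraphModule to
# inductor lite, which performs codegen and wraps the result with cudagraph.
x = torch.randn(8)
assert torch.allclose(gm(x), fn(x))  # passes must preserve numerics
```

Because each pass takes and returns a `GraphModule`, composing them is just function composition, with inductor lite as the terminal consumer of the graph.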
## Description

This PR introduces inductor lite mode for opt-in optimizations and numeric-correctness guarantees.
Different from the default mode, which applies all possible fusions, lite mode gives control back to the user and provides guarantees on numeric correctness. Specifically, this mode:
- **Fallback by default**: fall back for ALL nodes by default, unless users explicitly mark a node for inductor fusion.
- **Selective decomposition**: skip decomposition for all nodes except user-marked nodes.
- **Regional inductor compile**: compile only user-annotated regions with inductor.
- Skip dead code elimination.
- Skip buffer reuse.
- Skip reorder passes, such as reorder for peak memory, reorder for compute/comm overlap, and reorder_for_reducing_graph_partitions.
- Skip all pre-grad, joint-graph, and post-grad passes.
## Example: Flex Attention
```python
import torch
import torch.fx.traceback as fx_traceback
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def _squared(score, b, h, m, n):
    return score * score


def mask_mod(b, h, q, k):
    return q >= 0


a, b = 12, 64
block_mask = create_block_mask(mask_mod, None, None, a * b, a * b, device="cuda")


def fn(x):
    x = torch.sin(x)
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
    return torch.cos(x)


x = torch.randn(1, 1, a * b, b, dtype=torch.bfloat16, device="cuda", requires_grad=True)
opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
opt_fn(x)
```
[code diff](https://www.internalfb.com/intern/diffing/?paste_number=2027441476)
[default mode tlp](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpYAzDxX/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs [lite mode tlp](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpnnuh1W/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)
## Numerics
Inductor lite mode provides bitwise equivalence with the `aot_eager` backend on torchtitan llama3-8b and DeepSeek v3. pytorch/torchtitan#2005
Closes #167012
Pull Request resolved: #167115
Approved by: https://github.com/ezyang
This PR adds inductor lite mode. It needs pytorch/pytorch#167115.
This feature guarantees bitwise equivalence with aot_eager. Tested on both llama3-8b and DeepSeek v3.
Command:
- llama3-8b
- DeepSeek V3