14 changes: 12 additions & 2 deletions torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -21,6 +21,8 @@

 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
+    get_inductor_lite_bw_compiler,
+    get_inductor_lite_fw_compiler,
     joint_graph_builder,
 )

@@ -43,11 +45,19 @@ def compiler(name: str, gm: torch.fx.GraphModule, example_inputs):


 def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("fwd_gm", gm, example_inputs)
+    gm = compiler("fwd_gm", gm, example_inputs)
+
+    # TODO: fix inductor size assertion for all_reduce
+    extra_inductor_config = {"size_asserts": False}
+    return get_inductor_lite_fw_compiler(extra_inductor_config)(gm, example_inputs)


 def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("bwd_gm", gm, example_inputs)
+    gm = compiler("bwd_gm", gm, example_inputs)
+
+    # TODO: fix inductor size assertion for all_reduce
+    extra_inductor_config = {"size_asserts": False}
+    return get_inductor_lite_bw_compiler(extra_inductor_config)(gm, example_inputs)


 def annotate_deepseekv3() -> None:
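
Note on the size_asserts override above: torch._inductor.config.patch scopes a dict of config overrides to the with block and restores the previous values on exit, so disabling size asserts for one compilation does not change the global default. A minimal sketch of that behavior, independent of this PR (it assumes TORCHINDUCTOR_SIZE_ASSERTS is unset, so size_asserts defaults to True):

import torch
import torch._inductor.config

overrides = {"size_asserts": False}
with torch._inductor.config.patch(overrides):
    # Override visible only inside the patched scope.
    assert torch._inductor.config.size_asserts is False
# Previous value restored on exit (True by default).
assert torch._inductor.config.size_asserts is True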
57 changes: 56 additions & 1 deletion torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -106,7 +106,9 @@ def joint_graph_builder(
     if joint_custom_pass is not None:
         joint_custom_pass(joint_with_descriptors)
 
-    with tracing(tracing_context):
+    with tracing(tracing_context), torch._functorch.config.patch(
+        selective_decompose=True
+    ):
         fn = aot_compile_joint_with_descriptors(
             joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
         )
@@ -122,6 +124,59 @@ def wrapper_fn(args, kwargs):
     return wrapper_fn


+def get_inductor_lite_fw_compiler(extra_config: Optional[dict] = None):
+    from torch._inductor import lite_mode_options
+    from torch._inductor.compile_fx import compile_fx_inner
+
+    context = torch._guards.TracingContext.try_get()
+
+    if not context or not context.fw_metadata:
+        logger.warning("No context or fw_metadata available")
+        static_input_idxs = ()
+    else:
+        static_input_idxs = context.fw_metadata.static_input_indices
+
+    inductor_config = dict(lite_mode_options)  # copy so overrides don't leak into the shared defaults
+    if extra_config:
+        inductor_config.update(extra_config)
+
+    def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
+        with torch._inductor.config.patch(inductor_config):
+            compiled_fn = compile_fx_inner(
+                gm,
+                example_inputs,
+                static_input_idxs=static_input_idxs,
+                is_backward=False,
+            )
+        return compiled_fn
+
+    return fw_compiler


+def get_inductor_lite_bw_compiler(extra_config: Optional[dict] = None):
+    from torch._inductor import lite_mode_options
+    from torch._inductor.compile_fx import compile_fx_inner
+    from torch._inductor.utils import count_tangents
+
+    inductor_config = dict(lite_mode_options)  # copy so overrides don't leak into the shared defaults
+    if extra_config:
+        inductor_config.update(extra_config)
+
+    def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
+        fixed = count_tangents(gm)  # leading non-tangent inputs are static
+
+        with torch._inductor.config.patch(inductor_config):
+            compiled_fn = compile_fx_inner(
+                gm,
+                example_inputs,
+                static_input_idxs=list(range(fixed)),
+                is_backward=True,
+            )
+        return compiled_fn
+
+    return bw_compiler


 class CompiledModule(torch.nn.Module):
     def __init__(
         self,
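
The backward factory above derives static_input_idxs from count_tangents, which counts the leading placeholders of a backward graph that are not tangents (saved activations and parameters), so list(range(fixed)) marks exactly those inputs as static. A hedged sketch on a toy graph, assuming the private helper torch._inductor.utils.count_tangents keeps its current name-based behavior:

import torch
from torch._inductor.utils import count_tangents

# Toy "backward-style" graph: one saved input followed by one tangent input.
g = torch.fx.Graph()
primals = g.placeholder("primals_1")
tangents = g.placeholder("tangents_1")
g.output((primals,))
gm = torch.fx.GraphModule({}, g)

fixed = count_tangents(gm)        # leading non-tangent placeholders
assert fixed == 1
assert list(range(fixed)) == [0]  # only primals_1 is treated as static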
8 changes: 6 additions & 2 deletions torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -23,6 +23,8 @@

 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
+    get_inductor_lite_bw_compiler,
+    get_inductor_lite_fw_compiler,
     joint_graph_builder,
 )
 from torchtitan.experiments.simple_fsdp.llama3.parallelize import (
@@ -53,11 +55,13 @@ def compiler(name: str, gm: torch.fx.GraphModule, example_inputs):


 def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("fwd_gm", gm, example_inputs)
+    gm = compiler("fwd_gm", gm, example_inputs)
+    return get_inductor_lite_fw_compiler()(gm, example_inputs)


 def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("bwd_gm", gm, example_inputs)
+    gm = compiler("bwd_gm", gm, example_inputs)
+    return get_inductor_lite_bw_compiler()(gm, example_inputs)


 def validate_flex_attention_annotation(joint_with_descriptors):
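
The llama3 compilers call the factories with no extra_config, so they rely on lite_mode_options reaching inductor unmodified; that is why the factories copy the options dict before applying per-call overrides. A minimal, inductor-independent sketch of the aliasing pitfall the copy avoids (defaults and the make_compiler_* helpers are hypothetical stand-ins):

# Shared module-level defaults, standing in for lite_mode_options.
defaults = {"size_asserts": True}

def make_compiler_aliasing(extra=None):
    cfg = defaults            # aliases the shared dict
    if extra:
        cfg.update(extra)     # mutates `defaults` for every later caller
    return cfg

def make_compiler_copying(extra=None):
    cfg = dict(defaults)      # private copy, defaults stay untouched
    if extra:
        cfg.update(extra)
    return cfg

make_compiler_aliasing({"size_asserts": False})
assert defaults["size_asserts"] is False   # global default silently changed
defaults["size_asserts"] = True            # reset for the demonstration

make_compiler_copying({"size_asserts": False})
assert defaults["size_asserts"] is True    # copy kept the defaults intact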