torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py (14 changes: 12 additions & 2 deletions)
@@ -21,6 +21,8 @@
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
+    get_inductor_lite_bw_compiler,
+    get_inductor_lite_fw_compiler,
     joint_graph_builder,
 )
 
@@ -43,11 +45,19 @@ def compiler(name: str, gm: torch.fx.GraphModule, example_inputs):
 
 
 def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("fwd_gm", gm, example_inputs)
+    gm = compiler("fwd_gm", gm, example_inputs)
+
+    # TODO: fix inductor size assertion for all_reduce
+    extra_inductor_config = {"size_asserts": False}
+    return get_inductor_lite_fw_compiler(extra_inductor_config)(gm, example_inputs)
 
 
 def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("bwd_gm", gm, example_inputs)
+    gm = compiler("bwd_gm", gm, example_inputs)
+
+    # TODO: fix inductor size assertion for all_reduce
+    extra_inductor_config = {"size_asserts": False}
+    return get_inductor_lite_bw_compiler(extra_inductor_config)(gm, example_inputs)
 
 
 def annotate_deepseekv3() -> None:
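For readers unfamiliar with the pattern, the change composes two steps: the existing `compiler(...)` pass (which dumps and post-processes the graph) runs first, and its output is then handed to an inductor "lite mode" compiler under a temporarily patched inductor config, with `size_asserts` disabled to work around the all_reduce size-assertion issue noted in the TODO. The snippet below is a hypothetical, standalone sketch of that composition, not code from this PR; `inspect_graph` is a made-up stand-in for torchtitan's `compiler(...)` helper, and `compile_fx_inner` is the private inductor entry point that the new helpers in `graph_utils.py` wrap.

```python
# Hypothetical sketch (not part of this PR): run an inspection pass, then
# compile the graph with inductor while "size_asserts" is disabled, mirroring
# the extra_inductor_config override used above.
import torch
from torch._inductor.compile_fx import compile_fx_inner


def inspect_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Stand-in for torchtitan's compiler(...): just print the readable graph.
    gm.print_readable()
    return gm


def compile_forward(gm: torch.fx.GraphModule, example_inputs):
    gm = inspect_graph(gm)
    overrides = {"size_asserts": False}  # same override the PR applies
    with torch._inductor.config.patch(overrides):  # restored on exit
        return compile_fx_inner(gm, list(example_inputs), is_backward=False)
```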
torchtitan/experiments/compiler_toolkit/graph_utils.py (57 changes: 56 additions & 1 deletion)
@@ -106,7 +106,9 @@ def joint_graph_builder(
     if joint_custom_pass is not None:
         joint_custom_pass(joint_with_descriptors)
 
-    with tracing(tracing_context):
+    with tracing(tracing_context), torch._functorch.config.patch(
+        selective_decompose=True
+    ):
         fn = aot_compile_joint_with_descriptors(
             joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
         )
@@ -122,6 +124,59 @@ def wrapper_fn(args, kwargs):
     return wrapper_fn
 
 
+def get_inductor_lite_fw_compiler(extra_config: Optional[dict] = None):
+    from torch._inductor import lite_mode_options
+    from torch._inductor.compile_fx import compile_fx_inner
+
+    context = torch._guards.TracingContext.try_get()
+
+    if not context or not context.fw_metadata:
+        logger.warning("No context or fw_metadata available")
+        static_input_idxs = ()
+    else:
+        static_input_idxs = context.fw_metadata.static_input_indices
+
+    inductor_config = dict(lite_mode_options)
+    if extra_config:
+        inductor_config.update(extra_config)
+
+    def fw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
+        with torch._inductor.config.patch(inductor_config):
+            compiled_fn = compile_fx_inner(
+                gm,
+                example_inputs,
+                static_input_idxs=static_input_idxs,
+                is_backward=False,
+            )
+        return compiled_fn
+
+    return fw_compiler
+
+
+def get_inductor_lite_bw_compiler(extra_config: Optional[dict] = None):
+    from torch._inductor import lite_mode_options
+    from torch._inductor.compile_fx import compile_fx_inner
+    from torch._inductor.utils import count_tangents
+
+    inductor_config = dict(lite_mode_options)
+    if extra_config:
+        inductor_config.update(extra_config)
+
+    def bw_compiler(gm: torch.fx.GraphModule, example_inputs: tuple):
+        fixed = count_tangents(gm)
+
+        with torch._inductor.config.patch(inductor_config):
+            compiled_fn = compile_fx_inner(
+                gm,
+                example_inputs,
+                static_input_idxs=list(range(fixed)),
+                is_backward=True,
+            )
+        return compiled_fn
+
+    return bw_compiler
+
+
 class CompiledModule(torch.nn.Module):
     def __init__(
         self,
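The two factories differ mainly in how they pick static inputs: the forward helper reads `static_input_indices` from the ambient `TracingContext` (typically parameters and buffers whose storage stays fixed across steps), while the backward helper uses `count_tangents` to treat the inputs that precede the incoming gradients (the saved tensors) as static. Outside the `joint_graph_builder` path, the same helpers could in principle be plugged into a plain `aot_autograd` backend; the snippet below is a hedged sketch under that assumption, not a usage documented by this PR.

```python
# Hypothetical sketch (not from this PR): wire the lite compilers into a
# torch.compile backend via aot_autograd. The lambdas defer the factory call
# so it runs during compilation, when the TracingContext (and fw_metadata)
# is expected to be populated.
import torch
from torch._dynamo.backends.common import aot_autograd

from torchtitan.experiments.compiler_toolkit.graph_utils import (
    get_inductor_lite_bw_compiler,
    get_inductor_lite_fw_compiler,
)

lite_backend = aot_autograd(
    fw_compiler=lambda gm, inputs: get_inductor_lite_fw_compiler()(gm, inputs),
    bw_compiler=lambda gm, inputs: get_inductor_lite_bw_compiler()(gm, inputs),
)

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=lite_backend)
out = compiled(torch.randn(4, 8))
out.sum().backward()  # exercises the backward compiler as well
```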
torchtitan/experiments/compiler_toolkit/llama3/parallelize.py (8 changes: 6 additions & 2 deletions)
@@ -23,6 +23,8 @@
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
+    get_inductor_lite_bw_compiler,
+    get_inductor_lite_fw_compiler,
     joint_graph_builder,
 )
 from torchtitan.experiments.simple_fsdp.llama3.parallelize import (
@@ -53,11 +55,13 @@ def compiler(name: str, gm: torch.fx.GraphModule, example_inputs):
 
 
 def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("fwd_gm", gm, example_inputs)
+    gm = compiler("fwd_gm", gm, example_inputs)
+    return get_inductor_lite_fw_compiler()(gm, example_inputs)
 
 
 def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("bwd_gm", gm, example_inputs)
+    gm = compiler("bwd_gm", gm, example_inputs)
+    return get_inductor_lite_bw_compiler()(gm, example_inputs)
 
 
 def validate_flex_attention_annotation(joint_with_descriptors):
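Unlike the DeepSeek-V3 variant, the Llama 3 compilers call the lite-compiler factories with their default options (no `size_asserts` override). If the same all_reduce size-assertion issue surfaced here, the override could presumably be threaded through in the same way; a hypothetical variant, not part of this PR, reusing the helpers already imported in this file:

```python
# Hypothetical variant (not in this PR): pass the same inductor override used
# by the DeepSeek-V3 path if Llama 3 hits the all_reduce size-assertion issue.
def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
    gm = compiler("fwd_gm", gm, example_inputs)
    return get_inductor_lite_fw_compiler({"size_asserts": False})(gm, example_inputs)
```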