diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
index 5efd94c2d6..1d80eb3e98 100644
--- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -72,6 +72,10 @@ def validate_flex_attention_annotation(joint_with_descriptors):
 
 def annotate_llama() -> None:
     from torchtitan.models.attention import FlexAttentionWrapper
+    from torchtitan.models.llama3.model.model import TransformerBlock
+
+    # Mark TransformerBlock.forward as nested_compile_region
+    TransformerBlock.forward = torch.compiler.nested_compile_region(TransformerBlock.forward)
 
     # annotate flex_attention with compile_with_inductor
     FlexAttentionWrapper.forward = annotate_fn(
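
For context: torch.compiler.nested_compile_region marks a function as a compile region that torch.compile traces once and reuses at every subsequent invocation, so a stack of identical transformer blocks shares one compiled artifact instead of being retraced per block. Below is a minimal sketch of the same pattern, assuming a PyTorch build where this API is available and using a hypothetical Block module in place of the real TransformerBlock:

    import torch
    import torch.nn as nn

    class Block(nn.Module):
        # Hypothetical stand-in for TransformerBlock, for illustration only.
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.linear = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    # Wrap forward so torch.compile treats each Block invocation as one
    # nested compile region, traced once and reused across all blocks.
    Block.forward = torch.compiler.nested_compile_region(Block.forward)

    model = nn.Sequential(*[Block(64) for _ in range(4)])
    compiled = torch.compile(model)
    out = compiled(torch.randn(2, 64))

Patching the unbound method on the class, as the diff does for TransformerBlock.forward, applies the region marking to every instance of the block at once.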