
Commit d3ccda8

add blog post
1 parent 3b079ec commit d3ccda8

1 file changed: +24 -9 lines changed


docs/source/en/optimization/fp16.md

Lines changed: 24 additions & 9 deletions
@@ -174,11 +174,18 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a

### Regional compilation

+[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
+For many diffusion architectures, this delivers the same runtime speed-ups as full-graph compilation and reduces compile time by 8–10x.
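To make the idea concrete, here is a minimal sketch of regional compilation on a toy transformer-style model (the `Block` module, its sizes, and the input shapes are illustrative, not part of Diffusers): each repeated block is compiled on its own, so every identical block reuses the same compiled artifact.

```py
# Minimal sketch of the idea behind regional compilation (toy model, not the Diffusers API).
import torch
import torch.nn as nn

class Block(nn.Module):
    """A small transformer-style block that repeats many times in the model."""
    def __init__(self, dim=256):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        x = x + self.attn(x, x, x)[0]
        return x + self.mlp(x)

model = nn.Sequential(*[Block() for _ in range(8)])

# Compile each repeated block instead of the full model. Because every Block runs
# the same code, torch.compile reuses the compiled artifact across all of them,
# which is what cuts cold-start compile time.
for block in model:
    block.compile()

out = model(torch.randn(2, 16, 256))
```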

-[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
-For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10 ×**.
+There are two implementations of regional compilation.

-To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
+- The Diffusers version, [`~ModelMixin.compile_repeated_blocks`], is more explicit and is easier to customize.
+- The Accelerate version, [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78), automatically selects which regions to compile and is less customizable. It is ideal for fast experiments.
+
+<hfoptions id="regional-compilation">
+<hfoption id="compile_repeated_blocks">
+
+Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.

```py
# pip install -U diffusers
@@ -194,19 +201,20 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
pipe.unet.compile_repeated_blocks(fullgraph=True)
```
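The hunk above only shows the tail of this example; a self-contained sketch of the full flow, assuming the standard SDXL checkpoint name and an illustrative prompt, might look like:

```py
# Sketch of the full example; the checkpoint name and prompt are assumptions.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# Compile only the repeated blocks declared by the model's `_repeated_blocks` attribute.
pipe.unet.compile_repeated_blocks(fullgraph=True)

image = pipe("an astronaut riding a horse on the moon").images[0]
```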

-To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
-
+To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.

```py
class MyUNet(ModelMixin):
    _repeated_blocks = ("Transformer2DModel",)  # ← compiled by default
```

-For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
-
-**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
+> [!TIP]
+> For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).

+</hfoption>
+<hfoption id="compile_regions">

+There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.

```py
# pip install -U accelerate
@@ -219,8 +227,11 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
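Likewise, only part of this example appears in the hunk; a self-contained sketch, assuming the `accelerate.utils` import path and the standard SDXL checkpoint name, might look like:

```py
# Sketch of regional compilation via Accelerate; import path and checkpoint are assumptions.
import torch
from accelerate.utils import compile_regions  # assumed public import path
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# compile_regions picks the repeated candidate blocks automatically and
# compiles the remaining graph separately.
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```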
-`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.

+[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
+
+</hfoption>
+</hfoptions>

### Graph breaks

@@ -296,3 +307,7 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
+
+## Resources
+
+Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup.
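As a rough illustration of combining the optimizations shown in this diff (half precision, fused QKV projections, and regional compilation), a sketch might look like the following; the checkpoint name is an assumption and other techniques from the blog post are omitted.

```py
# Sketch combining the optimizations discussed above; checkpoint name is an assumption.
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint
    torch_dtype=torch.float16,  # half precision
).to("cuda")

pipeline.fuse_qkv_projections()                        # fused q/k/v projections
pipeline.unet.compile_repeated_blocks(fullgraph=True)  # regional compilation

image = pipeline("an astronaut riding a horse on the moon").images[0]
```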
