Add Flux fp4 support #3657

Open · wants to merge 14 commits into base: main
5 changes: 5 additions & 0 deletions examples/apps/README.md
@@ -23,6 +23,11 @@ python flux_demo.py

### Using Different Precision Modes

- FP4 mode:
```bash
python flux_demo.py --dtype fp4
```

- FP8 mode:
```bash
python flux_demo.py --dtype fp8
56 changes: 42 additions & 14 deletions examples/apps/flux_demo.py
@@ -12,10 +12,6 @@
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
from register_sdpa import *

DEVICE = "cuda:0"


@@ -24,8 +20,23 @@ def compile_model(
) -> tuple[
FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
]:

if args.dtype == "fp8":
use_explicit_typing = False
if args.use_sdpa:
# Currently SDPA does not work correctly with the Flux model, so it is disabled by default.
# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
import register_sdpa

if args.dtype == "fp4":
use_explicit_typing = True
enabled_precisions = {torch.float4_e2m1fn_x2}
ptq_config = mtq.NVFP4_DEFAULT_CFG
if args.fp4_mha:
from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG

ptq_config = NVFP4_FP8_MHA_CONFIG

elif args.dtype == "fp8":
enabled_precisions = {torch.float8_e4m3fn, torch.float16}
ptq_config = mtq.FP8_DEFAULT_CFG

@@ -107,26 +118,33 @@ def forward_loop(mod):
"enabled_precisions": enabled_precisions,
"truncate_double": True,
"min_block_size": 1,
"use_python_runtime": True,
"use_python_runtime": False,
"immutable_weights": False,
"offload_module_to_cpu": True,
"offload_module_to_cpu": args.low_vram_mode,
"use_explicit_typing": use_explicit_typing,
}
if args.low_vram_mode:
pipe.remove_all_hooks()
pipe.enable_sequential_cpu_offload()
remove_hook_from_module(pipe.transformer, recurse=True)
pipe.transformer.to(DEVICE)

trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
if dynamic_shapes:
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
pipe.transformer = trt_gm

seed = 42
image = pipe(
"Test",
[
"enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow"
],
output_type="pil",
num_inference_steps=2,
num_inference_steps=30,
num_images_per_prompt=batch_size,
generator=torch.Generator("cuda").manual_seed(seed),
).images
print(f"generated {len(image)} images")
image[0].save("forest.png")

torch.cuda.empty_cache()

@@ -242,12 +260,22 @@ def main(args):
parser = argparse.ArgumentParser(
description="Run Flux quantization with different dtypes"
)

parser.add_argument(
"--use_sdpa",
action="store_true",
help="Use sdpa",
default=False,
)
parser.add_argument(
"--dtype",
choices=["fp8", "int8", "fp16"],
choices=["fp4", "fp8", "int8", "fp16"],
default="fp16",
help="Select the data type to use (fp8 or int8 or fp16)",
help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
)
parser.add_argument(
"--fp4_mha",
action="store_true",
help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
)
parser.add_argument(
"--low_vram_mode",
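As a point of reference, here is a minimal sketch of how the selected `ptq_config` is typically applied with ModelOpt before the quantized backbone is wrapped in `MutableTorchTensorRTModule`. The helper name, calibration prompt, and step count below are illustrative and not part of this PR:

```python
import modelopt.torch.quantization as mtq


def calibrate_and_quantize(pipe, backbone, ptq_config, batch_size=1):
    # Hypothetical helper: run a short calibration pass so ModelOpt can collect
    # activation statistics, then insert quantize ops into the backbone in-place.
    def forward_loop(mod):
        pipe.transformer = mod
        pipe(
            ["a calibration prompt"],
            output_type="pil",
            num_inference_steps=4,
            num_images_per_prompt=batch_size,
        )

    # NVFP4_DEFAULT_CFG, NVFP4_FP8_MHA_CONFIG, or FP8_DEFAULT_CFG all plug in here.
    return mtq.quantize(backbone, ptq_config, forward_loop)
```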
10 changes: 6 additions & 4 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -258,10 +258,11 @@ def cross_compile_for_windows(

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions for x in {torch.float32, dtype.f32}
x in enabled_precisions
for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
):
raise AssertionError(
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
)

if use_fp32_acc:
@@ -591,10 +592,11 @@ def compile(

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions for x in {torch.float32, dtype.f32}
x in enabled_precisions
for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
):
raise AssertionError(
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
)

if use_fp32_acc:
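The relaxed check above means an FP4 precision entry may now accompany `use_explicit_typing=True`. A minimal, self-contained sketch of the accepted combination (the toy module is illustrative, and `torch.float4_e2m1fn_x2` requires a recent PyTorch build):

```python
import torch
import torch_tensorrt


class TinyMLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


example_input = torch.randn(8, 64).cuda()
exported = torch.export.export(TinyMLP().cuda().eval(), (example_input,))

# Previously this combination raised the AssertionError above; with the change, FP4 is allowed.
trt_module = torch_tensorrt.dynamo.compile(
    exported,
    inputs=[example_input],
    enabled_precisions={torch.float4_e2m1fn_x2},  # dtype.f4 is accepted as well
    use_explicit_typing=True,
    min_block_size=1,
)
```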
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -106,7 +106,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
import modelopt.torch.quantization as mtq

assert torch.ops.tensorrt.quantize_op.default
assert torch.ops.tensorrt.dynamic_block_quantize_op.default
self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
self.quantization_ops.add(
torch.ops.tensorrt.dynamic_block_quantize_op.default
)
except Exception as e:
pass

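For readers unfamiliar with this pass: registering the extra op here should make constant folding treat NVFP4 dynamic block quantize nodes like the existing quantize nodes, i.e. never fold them (or the weights feeding them) away, so ModelOpt's Q/DQ pattern survives lowering. A rough sketch of that check under that assumption; the standalone function below is not the module's actual code:

```python
import torch

# Assumes torch_tensorrt / ModelOpt have registered these ops, as in the try block above.
quantization_ops = {
    torch.ops.tensorrt.quantize_op.default,
    torch.ops.tensorrt.dynamic_block_quantize_op.default,
}


def is_impure(node: torch.fx.Node) -> bool:
    # Quantization nodes are treated as impure so the constant folder skips them.
    return node.target in quantization_ops
```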
@@ -334,7 +334,7 @@ def export_fn() -> torch.export.ExportedProgram:
# Check if any quantization precision is enabled
if self.enabled_precisions and any(
precision in self.enabled_precisions
for precision in (torch.float8_e4m3fn, torch.int8)
for precision in (torch.float8_e4m3fn, torch.int8, torch.float4_e2m1fn_x2)
):
try:
from modelopt.torch.quantization.utils import export_torch_mode
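Extending this tuple routes FP4-quantized graphs through ModelOpt's export context the same way FP8/INT8 already are. A short, hedged sketch of that export path (the wrapper function is illustrative; the module's surrounding code is not shown in this diff):

```python
import torch
from modelopt.torch.quantization.utils import export_torch_mode


def export_quantized(module: torch.nn.Module, sample_args: tuple) -> torch.export.ExportedProgram:
    # Under export_torch_mode(), ModelOpt's quantize / dynamic block quantize ops
    # are exported as regular custom ops that the TensorRT converters can pick up.
    with export_torch_mode():
        return torch.export.export(module, sample_args, strict=False)
```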
60 changes: 53 additions & 7 deletions tools/perf/Flux/flux_perf.py
@@ -3,12 +3,47 @@
import sys
from time import time

import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../examples/apps"))
from flux_demo import compile_model


def profile(pipe, prompt, inference_step, batch_size=1):
print(f"Running torch profiler with {inference_step=} {batch_size=}")
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True,
) as prof:
with torch.profiler.record_function("model_inference"):
pipe(
prompt,
output_type="pil",
num_inference_steps=inference_step,
num_images_per_prompt=batch_size,
).images
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))


def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
print(f"Running warmup with {batch_size=} {inference_step=} iterations=10")
# warmup
for i in range(10):
start = time()
images = pipe(
prompt,
output_type="pil",
num_inference_steps=inference_step,
num_images_per_prompt=batch_size,
).images
print(
f"Warmup {i} done in {time() - start} seconds, with {batch_size=} {inference_step=}, generated {len(images)} images"
)

# actual benchmark
print(f"Running benchmark with {batch_size=} {inference_step=} {iterations=}")
start = time()
for i in range(iterations):
image = pipe(
@@ -18,32 +53,43 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
num_images_per_prompt=batch_size,
).images
end = time()

print(f"Batch Size: {batch_size}")
print("Time Elapse for", iterations, "iterations:", end - start)
print(
"Average Latency Per Step:",
(end - start) / inference_step / iterations / batch_size,
)
return image
return


def main(args):
print(f"Running flux_perfwith args: {args}")
pipe, backbone, trt_gm = compile_model(args)
for batch_size in range(1, args.max_batch_size + 1):
benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3)

benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3)
# profile(pipe, ["enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow"], 20, batch_size=args.max_batch_size)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run Flux quantization with different dtypes"
)

parser.add_argument(
"--use_sdpa",
action="store_true",
help="Use sdpa",
default=False,
)
parser.add_argument(
"--dtype",
choices=["fp8", "int8", "fp16"],
choices=["fp4", "fp8", "int8", "fp16"],
default="fp16",
help="Select the data type to use (fp8 or int8 or fp16)",
help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
)
parser.add_argument(
"--fp4_mha",
action="store_true",
help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
)
parser.add_argument(
"--low_vram_mode",