diff --git a/examples/apps/README.md b/examples/apps/README.md
index ac63500d29..b6b77e17f1 100644
--- a/examples/apps/README.md
+++ b/examples/apps/README.md
@@ -23,6 +23,11 @@ python flux_demo.py
 
 ### Using Different Precision Modes
 
+- FP4 mode:
+```bash
+python flux_demo.py --dtype fp4
+```
+
 - FP8 mode:
 ```bash
 python flux_demo.py --dtype fp8
diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py
index 4e8aaf3a4e..761b1bafa8 100644
--- a/examples/apps/flux_demo.py
+++ b/examples/apps/flux_demo.py
@@ -12,10 +12,6 @@
 from diffusers import FluxPipeline
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 
-# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
-from register_sdpa import *
-
 DEVICE = "cuda:0"
 
 
@@ -24,8 +20,23 @@ def compile_model(
 ) -> tuple[
     FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
 ]:
-
-    if args.dtype == "fp8":
+    use_explicit_typing = False
+    if args.use_sdpa:
+        # Currently SDPA does not work correctly with the Flux model, so it is disabled by default.
+        # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+        sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
+        import register_sdpa
+
+    if args.dtype == "fp4":
+        use_explicit_typing = True
+        enabled_precisions = {torch.float4_e2m1fn_x2}
+        ptq_config = mtq.NVFP4_DEFAULT_CFG
+        if args.fp4_mha:
+            from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG
+
+            ptq_config = NVFP4_FP8_MHA_CONFIG
+
+    elif args.dtype == "fp8":
         enabled_precisions = {torch.float8_e4m3fn, torch.float16}
         ptq_config = mtq.FP8_DEFAULT_CFG
 
@@ -107,26 +118,33 @@ def forward_loop(mod):
         "enabled_precisions": enabled_precisions,
         "truncate_double": True,
         "min_block_size": 1,
-        "use_python_runtime": True,
+        "use_python_runtime": False,
         "immutable_weights": False,
-        "offload_module_to_cpu": True,
+        "offload_module_to_cpu": args.low_vram_mode,
+        "use_explicit_typing": use_explicit_typing,
     }
     if args.low_vram_mode:
         pipe.remove_all_hooks()
         pipe.enable_sequential_cpu_offload()
         remove_hook_from_module(pipe.transformer, recurse=True)
         pipe.transformer.to(DEVICE)
+
     trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
     if dynamic_shapes:
         trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
     pipe.transformer = trt_gm
-
+    seed = 42
     image = pipe(
-        "Test",
+        [
+            "enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow"
+        ],
         output_type="pil",
-        num_inference_steps=2,
+        num_inference_steps=30,
         num_images_per_prompt=batch_size,
+        generator=torch.Generator("cuda").manual_seed(seed),
     ).images
+    print(f"Generated {len(image)} images")
+    image[0].save("forest.png")
     torch.cuda.empty_cache()
 
 
@@ -242,12 +260,22 @@ def main(args):
     parser = argparse.ArgumentParser(
         description="Run Flux quantization with different dtypes"
     )
-
+    parser.add_argument(
+        "--use_sdpa",
+        action="store_true",
+        help="Register and use the standalone SDPA operator",
+        default=False,
+    )
     parser.add_argument(
         "--dtype",
-        choices=["fp8", "int8", "fp16"],
+        choices=["fp4", "fp8", "int8", "fp16"],
         default="fp16",
-        help="Select the data type to use (fp8 or int8 or fp16)",
+        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--fp4_mha",
+        action="store_true",
+        help="Quantize MHA with NVFP4_FP8_MHA_CONFIG instead of NVFP4_DEFAULT_CFG",
     )
     parser.add_argument(
"--low_vram_mode", diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index ff7d3b7a07..74cab980c4 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -258,10 +258,11 @@ def cross_compile_for_windows( if use_explicit_typing: if len(enabled_precisions) != 1 or not any( - x in enabled_precisions for x in {torch.float32, dtype.f32} + x in enabled_precisions + for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4} ): raise AssertionError( - f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True" + f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True" ) if use_fp32_acc: @@ -591,10 +592,11 @@ def compile( if use_explicit_typing: if len(enabled_precisions) != 1 or not any( - x in enabled_precisions for x in {torch.float32, dtype.f32} + x in enabled_precisions + for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4} ): raise AssertionError( - f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True" + f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True" ) if use_fp32_acc: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 928b7284fe..bd3baeebb7 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -106,7 +106,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: import modelopt.torch.quantization as mtq assert torch.ops.tensorrt.quantize_op.default + assert torch.ops.tensorrt.dynamic_block_quantize_op.default self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default) + self.quantization_ops.add( + torch.ops.tensorrt.dynamic_block_quantize_op.default + ) except Exception as e: pass diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index ba83c89dcd..b0e41f7aeb 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -334,7 +334,7 @@ def export_fn() -> torch.export.ExportedProgram: # Check if any quantization precision is enabled if self.enabled_precisions and any( precision in self.enabled_precisions - for precision in (torch.float8_e4m3fn, torch.int8) + for precision in (torch.float8_e4m3fn, torch.int8, torch.float4_e2m1fn_x2) ): try: from modelopt.torch.quantization.utils import export_torch_mode diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index e54952ea10..1d3b2acbbc 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -3,12 +3,29 @@ import sys from time import time +import torch + sys.path.append(os.path.join(os.path.dirname(__file__), "../../../examples/apps")) 
 from flux_demo import compile_model
 
 
 def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
+    print(f"Running warmup with {batch_size=} {inference_step=} iterations=10")
+    # warmup
+    for i in range(10):
+        start = time()
+        images = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+        print(
+            f"Warmup {i} done in {time() - start} seconds, with {batch_size=} {inference_step=}, generated {len(images)} images"
+        )
+    # actual benchmark
+    print(f"Running benchmark with {batch_size=} {inference_step=} {iterations=}")
     start = time()
     for i in range(iterations):
         image = pipe(
@@ -18,32 +35,42 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
             num_images_per_prompt=batch_size,
         ).images
     end = time()
-    print(f"Batch Size: {batch_size}")
     print("Time Elapse for", iterations, "iterations:", end - start)
     print(
         "Average Latency Per Step:",
         (end - start) / inference_step / iterations / batch_size,
     )
-    return image
+    return
 
 
 def main(args):
+    print(f"Running flux_perf with args: {args}")
     pipe, backbone, trt_gm = compile_model(args)
-    for batch_size in range(1, args.max_batch_size + 1):
-        benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3)
+
+    benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Run Flux quantization with different dtypes"
     )
-
+    parser.add_argument(
+        "--use_sdpa",
+        action="store_true",
+        help="Register and use the standalone SDPA operator",
+        default=False,
+    )
     parser.add_argument(
         "--dtype",
-        choices=["fp8", "int8", "fp16"],
+        choices=["fp4", "fp8", "int8", "fp16"],
         default="fp16",
-        help="Select the data type to use (fp8 or int8 or fp16)",
+        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--fp4_mha",
+        action="store_true",
+        help="Quantize MHA with NVFP4_FP8_MHA_CONFIG instead of NVFP4_DEFAULT_CFG",
     )
     parser.add_argument(
         "--low_vram_mode",
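
Reviewer note: the sketch below summarizes the FP4 flow this diff wires up in `flux_demo.py`, in one minimal, self-contained example: quantize a module to NVFP4 with ModelOpt, export it under `export_torch_mode` so the quantize/dequantize nodes survive export, then compile with `use_explicit_typing=True`, which the updated check in `_compiler.py` now accepts together with `enabled_precisions={torch.float4_e2m1fn_x2}`. This is illustrative only and assumes a GPU, TensorRT, and ModelOpt stack with NVFP4 support; the toy `Sequential` model, tensor shapes, and calibration loop are placeholders, not part of this change.

```python
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from modelopt.torch.quantization.utils import export_torch_mode

# Toy stand-in for the Flux transformer backbone (placeholder, not part of this PR).
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda().half()
example_input = torch.randn(8, 64, dtype=torch.float16, device="cuda")


def calibration_loop(m):
    # ModelOpt collects activation statistics from whatever runs inside this loop.
    m(example_input)


# Insert NVFP4 quantize/dequantize nodes; these lower to
# torch.ops.tensorrt.dynamic_block_quantize_op, which the constant-folding pass
# now treats as a quantization op and leaves in the graph.
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=calibration_loop)

# Export with the quantizers kept in the graph, then compile with explicit typing.
with export_torch_mode():
    exported = torch.export.export(model, (example_input,))

trt_model = torch_tensorrt.dynamo.compile(
    exported,
    inputs=[example_input],
    use_explicit_typing=True,
    enabled_precisions={torch.float4_e2m1fn_x2},
    min_block_size=1,
    truncate_double=True,
)
print(trt_model(example_input).shape)
```

The same path surfaces in the demo and perf scripts as `python flux_demo.py --dtype fp4` (optionally with `--fp4_mha` to switch to NVFP4_FP8_MHA_CONFIG).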