From 33972b5c9775f83067103b98a9c35f3ac3505f75 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Mon, 30 Jun 2025 16:00:03 -0700
Subject: [PATCH 01/13] initial check-in for flux fp4 example

---
 examples/apps/flux_demo.py | 33 ++++++++++++++++---
 py/torch_tensorrt/dynamo/_compiler.py | 10 +++---
 .../runtime/_MutableTorchTensorRTModule.py | 2 +-
 tools/perf/Flux/flux_perf.py | 30 ++++++++++++++---
 4 files changed, 61 insertions(+), 14 deletions(-)

diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py
index 4e8aaf3a4e..2ac1adc5fe 100644
--- a/examples/apps/flux_demo.py
+++ b/examples/apps/flux_demo.py
@@ -24,8 +24,14 @@ def compile_model(
 ) -> tuple[
     FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
 ]:
-
-    if args.dtype == "fp8":
+    use_explicit_precision = False
+    if args.dtype == "fp4":
+        use_explicit_precision = True
+        enabled_precisions = {torch.float4_e2m1fn_x2}
+        ptq_config = mtq.FP4_DEFAULT_CFG
+        if args.fp4_mha:
+            ptq_config = mtq.NVFP4_FP8_MHA_CONFIG
+    elif args.dtype == "fp8":
         enabled_precisions = {torch.float8_e4m3fn, torch.float16}
         ptq_config = mtq.FP8_DEFAULT_CFG
 
@@ -44,6 +50,12 @@
         torch_dtype=torch.float16,
     ).to(torch.float16)
 
+    # Use a small transformer for debugging
+    if args.debug:
+        pipe.transformer = FluxTransformer2DModel(
+            num_layers=1, num_single_layers=1, guidance_embeds=True
+        )
+
     if args.low_vram_mode:
         pipe.enable_model_cpu_offload()
     else:
@@ -109,7 +121,8 @@ def forward_loop(mod):
         "min_block_size": 1,
         "use_python_runtime": True,
         "immutable_weights": False,
-        "offload_module_to_cpu": True,
+        "offload_module_to_cpu": args.low_vram_mode,
+        "use_explicit_precision": use_explicit_precision,
     }
     if args.low_vram_mode:
         pipe.remove_all_hooks()
@@ -245,9 +258,19 @@ def main(args):
 
     parser.add_argument(
         "--dtype",
-        choices=["fp8", "int8", "fp16"],
+        choices=["fp4", "fp8", "int8", "fp16"],
         default="fp16",
-        help="Select the data type to use (fp8 or int8 or fp16)",
+        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--fp4_mha",
+        action="store_true",
+        help="Use NVFP4_FP8_MHA_CONFIG instead of the default FP4_DEFAULT_CFG",
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        help="Use debug mode",
     )
     parser.add_argument(
         "--low_vram_mode",
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 6434afe248..62e2472b97 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -258,10 +258,11 @@ def cross_compile_for_windows(
 
     if use_explicit_typing:
         if len(enabled_precisions) != 1 or not any(
-            x in enabled_precisions for x in {torch.float32, dtype.f32}
+            x in enabled_precisions
+            for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
         ):
             raise AssertionError(
-                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
             )
 
     if use_fp32_acc:
@@ -591,10 +592,11 @@ def compile(
 
     if use_explicit_typing:
         if len(enabled_precisions) != 1 or not any(
-            x in enabled_precisions for x in {torch.float32, dtype.f32}
+            x in enabled_precisions
+            for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
         ):
             raise AssertionError(
-                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
             )
 
     if use_fp32_acc:
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index ba83c89dcd..b0e41f7aeb 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -334,7 +334,7 @@ def export_fn() -> torch.export.ExportedProgram:
         # Check if any quantization precision is enabled
         if self.enabled_precisions and any(
             precision in self.enabled_precisions
-            for precision in (torch.float8_e4m3fn, torch.int8)
+            for precision in (torch.float8_e4m3fn, torch.int8, torch.float4_e2m1fn_x2)
         ):
             try:
                 from modelopt.torch.quantization.utils import export_torch_mode
diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py
index e54952ea10..d27bd22e76 100644
--- a/tools/perf/Flux/flux_perf.py
+++ b/tools/perf/Flux/flux_perf.py
@@ -30,8 +30,20 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
 
 def main(args):
     pipe, backbone, trt_gm = compile_model(args)
-    for batch_size in range(1, args.max_batch_size + 1):
-        benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3)
+    # warmup
+    warmup_prompt = "Beach and Kids"
+    start = time()
+    images = pipe(
+        warmup_prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        num_images_per_prompt=1,
+    ).images
+    print(f"Warmup done in {time() - start} seconds, generated {len(images)} images")
+
+    if not args.debug:
+        for batch_size in range(1, args.max_batch_size + 1):
+            benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3)
 
 
 if __name__ == "__main__":
@@ -41,9 +53,19 @@ def main(args):
 
     parser.add_argument(
         "--dtype",
-        choices=["fp8", "int8", "fp16"],
+        choices=["fp4", "fp8", "int8", "fp16"],
         default="fp16",
-        help="Select the data type to use (fp8 or int8 or fp16)",
+        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--fp4_mha",
+        action="store_true",
+        help="Use NVFP4_FP8_MHA_CONFIG instead of the default FP4_DEFAULT_CFG",
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        help="Use debug mode",
     )
     parser.add_argument(
         "--low_vram_mode",

From f3e604e25c288ba803b584f8cebe0488bb71675e Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 1 Jul 2025 08:57:54 -0700
Subject: [PATCH 02/13] test

---
 examples/apps/flux_demo.py | 48 +++++++++++++++++++++++++-----------
 tools/perf/Flux/flux_perf.py | 15 +++++++----
 2 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py
index 2ac1adc5fe..4d5cd58415 100644
--- a/examples/apps/flux_demo.py
+++ b/examples/apps/flux_demo.py
@@ -11,6 +11,7 @@ from
accelerate.hooks import remove_hook_from_module from diffusers import FluxPipeline from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo")) @@ -24,13 +25,16 @@ def compile_model( ) -> tuple[ FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule ]: - use_explicit_precision = False + use_explicit_typing = False if args.dtype == "fp4": - use_explicit_precision = True + use_explicit_typing = True enabled_precisions = {torch.float4_e2m1fn_x2} - ptq_config = mtq.FP4_DEFAULT_CFG + ptq_config = mtq.NVFP4_DEFAULT_CFG if args.fp4_mha: - ptq_config = mtq.NVFP4_FP8_MHA_CONFIG + from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG + + ptq_config = NVFP4_FP8_MHA_CONFIG + elif args.dtype == "fp8": enabled_precisions = {torch.float8_e4m3fn, torch.float16} ptq_config = mtq.FP8_DEFAULT_CFG @@ -50,11 +54,12 @@ def compile_model( torch_dtype=torch.float16, ).to(torch.float16) - # Use a small transformer for debugging - if args.debug: - pipe.transformer = FluxTransformer2DModel( - num_layers=1, num_single_layers=1, guidance_embeds=True - ) + # # Use a small transformer for debugging + # if args.debug: + # pipe.transformer = FluxTransformer2DModel( + # num_layers=1, num_single_layers=1, guidance_embeds=True + # ) + # pipe.to(torch.float16) if args.low_vram_mode: pipe.enable_model_cpu_offload() @@ -122,24 +127,39 @@ def forward_loop(mod): "use_python_runtime": True, "immutable_weights": False, "offload_module_to_cpu": args.low_vram_mode, - "use_explicit_precision": use_explicit_precision, + "use_explicit_typing": use_explicit_typing, } if args.low_vram_mode: pipe.remove_all_hooks() pipe.enable_sequential_cpu_offload() remove_hook_from_module(pipe.transformer, recurse=True) pipe.transformer.to(DEVICE) - trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) + + if args.debug: + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=DEBUG_LOGGING_DIR, + capture_fx_graph_after=["remove_num_users_is_0_nodes"], + save_engine_profile=True, + profile_format="trex", + engine_builder_monitor=True, + ): + trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) + else: + trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) if dynamic_shapes: trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) pipe.transformer = trt_gm - + seed = 42 image = pipe( - "Test", + "Beach and Kids", output_type="pil", - num_inference_steps=2, + num_inference_steps=20, num_images_per_prompt=batch_size, + generator=torch.Generator("cuda").manual_seed(seed), ).images + print(f"generated {len(image)} images") + image[0].save("warmup1.png") torch.cuda.empty_cache() diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index d27bd22e76..a4e6aa130b 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -3,6 +3,8 @@ import sys from time import time +import torch + sys.path.append(os.path.join(os.path.dirname(__file__), "../../../examples/apps")) from flux_demo import compile_model @@ -29,21 +31,24 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): def main(args): + print(f"Running flux_perfwith args: {args}") pipe, backbone, trt_gm = compile_model(args) # warmup + seed = 42 warmup_prompt = "Beach and Kids" start = 
time() images = pipe( warmup_prompt, output_type="pil", - num_inference_steps=20, - num_images_per_prompt=1, + num_inference_steps=30, + generator=torch.Generator("cuda").manual_seed(seed), ).images print(f"Warmup done in {time() - start} seconds, generated {len(images)} images") + images[0].save("warmup2.png") - if not args.debug: - for batch_size in range(1, args.max_batch_size + 1): - benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3) + # if not args.debug: + # for batch_size in range(1, args.max_batch_size + 1): + # benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3) if __name__ == "__main__": From 16ae00b5ee00954ff4670654d6ed4d74e5ae5699 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 1 Jul 2025 14:10:37 -0700 Subject: [PATCH 03/13] test --- examples/apps/flux_demo.py | 86 ++++++-- examples/apps/flux_quantization.py | 309 +++++++++++++++++++++++++++++ tools/perf/Flux/flux_perf.py | 8 +- 3 files changed, 386 insertions(+), 17 deletions(-) create mode 100644 examples/apps/flux_quantization.py diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index 4d5cd58415..40f5fdc098 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -134,25 +134,73 @@ def forward_loop(mod): pipe.enable_sequential_cpu_offload() remove_hook_from_module(pipe.transformer, recurse=True) pipe.transformer.to(DEVICE) - - if args.debug: - with torch_tensorrt.dynamo.Debugger( - "graphs", - logging_dir=DEBUG_LOGGING_DIR, - capture_fx_graph_after=["remove_num_users_is_0_nodes"], - save_engine_profile=True, - profile_format="trex", - engine_builder_monitor=True, - ): - trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) + if args.use_dynamo: + dummy_inputs = { + "hidden_states": torch.randn( + (batch_size, 4096, 64), dtype=torch.float16 + ).to(DEVICE), + "encoder_hidden_states": torch.randn( + (batch_size, 512, 4096), dtype=torch.float16 + ).to(DEVICE), + "pooled_projections": torch.randn( + (batch_size, 768), dtype=torch.float16 + ).to(DEVICE), + "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to( + DEVICE + ), + "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE), + "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to( + DEVICE + ), + "joint_attention_kwargs": {}, + "return_dict": False, + } + from modelopt.torch.quantization.utils import export_torch_mode + + with export_torch_mode(): + ep = torch.export.export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + if args.debug: + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=DEBUG_LOGGING_DIR, + capture_fx_graph_after=["remove_num_users_is_0_nodes"], + save_engine_profile=True, + profile_format="trex", + engine_builder_monitor=True, + ): + trt_gm = torch_tensorrt.dynamo.compile( + ep, inputs=dummy_inputs, **settings + ) + else: + trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=dummy_inputs, **settings) + pipe.transformer = trt_gm + pipe.transformer.config = backbone.config else: - trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) - if dynamic_shapes: - trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) - pipe.transformer = trt_gm + if args.debug: + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=DEBUG_LOGGING_DIR, + capture_fx_graph_after=["remove_num_users_is_0_nodes"], + save_engine_profile=True, + profile_format="trex", + 
engine_builder_monitor=True, + ): + trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) + else: + trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) + if dynamic_shapes: + trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) + pipe.transformer = trt_gm seed = 42 image = pipe( - "Beach and Kids", + ["Beach and Kids"], output_type="pil", num_inference_steps=20, num_images_per_prompt=batch_size, @@ -282,6 +330,12 @@ def main(args): default="fp16", help="Select the data type to use (fp4 or fp8 or int8 or fp16)", ) + parser.add_argument( + "--use_dynamo", + action="store_true", + help="Use dynamo compile", + default=False, + ) parser.add_argument( "--fp4_mha", action="store_true", diff --git a/examples/apps/flux_quantization.py b/examples/apps/flux_quantization.py new file mode 100644 index 0000000000..87c75cf918 --- /dev/null +++ b/examples/apps/flux_quantization.py @@ -0,0 +1,309 @@ +# %% +# Import the following libraries +# ----------------------------- +# Load the ModelOpt-modified model architecture and weights using Huggingface APIs +# Add argument parsing for dtype selection +import argparse +import gc +import os +import re +import sys + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import torch +import torch_tensorrt +from diffusers import FluxPipeline +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG +from modelopt.torch.quantization.utils import export_torch_mode +from torch.export._trace import _export +from transformers import AutoModelForCausalLM + +parser = argparse.ArgumentParser( + description="Run Flux quantization with different dtypes" +) +parser.add_argument( + "--debug", + action="store_true", + default=False, + help="debug mode", +) +parser.add_argument( + "--mha", + action="store_true", + default=False, + help="NVFP4_FP8_MHA_CONFIG mode", +) +parser.add_argument( + "--dtype", + choices=["fp8", "int8", "fp4", "fp16", "bf16", "fp32"], + default="fp8", + help="Quantization data type to use (fp8 or int8 or fp4 or fp16 or bf16 or fp32)", +) + +parser.add_argument( + "--sdpa", + action="store_true", + default=False, + help="Register SDPA operator", +) + +parser.add_argument( + "--strong-typing", + action="store_true", + help="string type flag", +) + +args = parser.parse_args() +if args.sdpa: + sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo")) + from register_sdpa import * + + +dtype = torch.float16 +ptq_config = None +use_explicit_typing = args.strong_typing +enabled_precisions = [ + torch.float32, +] + +# Update enabled precisions based on dtype argument +if args.dtype == "fp8": + ( + enabled_precisions.extend([torch.float8_e4m3fn, torch.float16]) + if not use_explicit_typing + else None + ) + ptq_config = mtq.FP8_DEFAULT_CFG +elif args.dtype == "int8": # int8 + ( + enabled_precisions.extend([torch.int8, torch.float16]) + if not use_explicit_typing + else None + ) + ptq_config = mtq.INT8_DEFAULT_CFG +elif args.dtype == "fp4": + if args.mha: + ptq_config = NVFP4_FP8_MHA_CONFIG + else: + ptq_config = mtq.NVFP4_DEFAULT_CFG # mtq.NVFP4_DEFAULT_CFG + use_explicit_typing = True +elif args.dtype == "fp16": + enabled_precisions.append(torch.float16) if not use_explicit_typing else None +elif args.dtype == "bf16": + dtype = torch.bfloat16 + ( + enabled_precisions.extend([torch.bfloat16, torch.float16]) + if not 
use_explicit_typing + else None + ) +elif args.dtype == "fp32": + dtype = torch.float32 +else: + raise ValueError(f"Invalid dtype: {args.dtype}") +print(f"\nUsing {args.dtype} quantization with {args=}") +# %% +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=dtype, +) + +total_params = sum(p.numel() for p in pipe.transformer.parameters()) +print(f"\n Total number of parameters: {total_params/1000/1000/1000}B") +if dtype in (torch.float16, torch.bfloat16): + total_size = total_params * 2 / 1024 / 1024 / 1024 + print(f"\n Total size: {total_size}GB") +elif dtype == torch.float32: + total_size = total_params * 4 / 1024 / 1024 / 1024 + print(f"\n Total size: {total_size}GB") + +# if args.debug: +# pipe.transformer = FluxTransformer2DModel( +# num_layers=1, num_single_layers=1, guidance_embeds=True +# ) + +pipe.to(DEVICE).to(dtype) +# Store the config and transformer backbone +config = pipe.transformer.config +# global backbone +backbone = pipe.transformer +backbone.eval() + + +def filter_func(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" + ) + return pattern.match(name) is not None + + +def generate_image(pipe, prompt, image_name): + seed = 42 + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(seed), + ).images[0] + image.save(f"{image_name}.png") + print(f"Image generated using {image_name} model saved as {image_name}.png") + + +def benchmark(prompt, inference_step, batch_size=1, iterations=1): + from time import time + + print(f"Benchmark TRT Module Latency started with {batch_size=} {iterations=}") + start = time() + for i in range(iterations): + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end = time() + print(f"Batch Size: {batch_size}") + print("Time Elapse for", iterations, "iterations:", end - start) + print( + "Average Latency Per Step:", + (end - start) / inference_step / iterations / batch_size, + ) + return image + + +# %% +# Quantization + + +def do_calibrate( + pipe, + prompt: str, +) -> None: + """ + Run calibration steps on the pipeline using the given prompts. + """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + +def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="test", + ) + + +if ptq_config is not None: + backbone = mtq.quantize(backbone, ptq_config, forward_loop) + mtq.disable_quantizer(backbone, filter_func) +else: + print("No quantization config provided, skipping quantization") + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=8) +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. 
+# To see this recommendation, you can try exporting using min=1, max=4096 +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {0: SEQ_LEN}, + "img_ids": {0: IMG_ID}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} +# The guidance factor is of type torch.float32 +dummy_inputs = { + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=dtype).to(DEVICE), + "encoder_hidden_states": torch.randn((batch_size, 512, 4096), dtype=dtype).to( + DEVICE + ), + "pooled_projections": torch.randn((batch_size, 768), dtype=dtype).to(DEVICE), + "timestep": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE), + "txt_ids": torch.randn((512, 3), dtype=dtype).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=dtype).to(DEVICE), + "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE), + "joint_attention_kwargs": {}, + "return_dict": False, +} + + +torch.cuda.empty_cache() +torch.cuda.reset_peak_memory_stats() +gc.collect() +# This will create an exported program which is going to be compiled with Torch-TensorRT +with export_torch_mode(): + ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + +peak_memory = torch.cuda.max_memory_allocated() / (1024**3) +peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) +print(f"Peak memory allocated during torch-export: {peak_memory=}GB {peak_reserved=}GB") + +torch.cuda.empty_cache() +torch.cuda.reset_peak_memory_stats() +gc.collect() + +with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions=enabled_precisions, + use_explicit_typing=use_explicit_typing, + truncate_double=True, + min_block_size=1, + debug=args.debug, + immutable_weights=True, + offload_module_to_cpu=True, + ) + +peak_memory = torch.cuda.max_memory_allocated() / (1024**3) +peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) +print( + f"Peak memory allocated during torch dynamo compilation: {peak_memory=}GB {peak_reserved=}GB" +) + +del ep +pipe.transformer = trt_gm +pipe.transformer.config = config + +torch.cuda.empty_cache() +torch.cuda.reset_peak_memory_stats() +gc.collect() + +# %% + +trt_gm.device = torch.device(DEVICE) +# Function which generates images from the flux pipeline +generate_image(pipe, ["Beach and Kids"], "beach_and_kids") + +peak_memory = torch.cuda.max_memory_allocated() / (1024**3) +peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) +print(f"Peak memory allocated during inference: {peak_memory=}GB {peak_reserved=}GB") + +# if not args.debug: +# print(f"Benchmark TRT Module Latency at ({args.dtype}) started") +# for batch_size in range(1, 9): +# benchmark(["Test"], 20, batch_size=batch_size, iterations=3) +# print(f"Benchmark TRT Module Latency at ({args.dtype}) ended") + +# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index a4e6aa130b..57c614da50 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -35,7 +35,7 @@ def main(args): pipe, backbone, trt_gm = compile_model(args) # warmup seed = 42 - warmup_prompt = "Beach and Kids" + warmup_prompt = ["Beach and Kids"] start = time() images = pipe( 
warmup_prompt, @@ -83,6 +83,12 @@ def main(args): action="store_true", help="Use dynamic shapes", ) + parser.add_argument( + "--use_dynamo", + action="store_true", + help="Use dynamo compile", + default=False, + ) parser.add_argument("--max_batch_size", type=int, default=1) args = parser.parse_args() main(args) From a28aafbb6bd33035eb072625f88f4cf79d01a522 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 1 Jul 2025 16:00:27 -0700 Subject: [PATCH 04/13] test --- examples/apps/flux_demo.py | 18 ++++++++---- .../runtime/_MutableTorchTensorRTModule.py | 28 +++++++++++++------ tools/perf/Flux/flux_perf.py | 7 ++++- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index 40f5fdc098..e012750c16 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -13,10 +13,6 @@ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR -# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py -sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo")) -from register_sdpa import * - DEVICE = "cuda:0" @@ -26,6 +22,11 @@ def compile_model( FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule ]: use_explicit_typing = False + if args.use_sdpa: + # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py + sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo")) + import register_sdpa + if args.dtype == "fp4": use_explicit_typing = True enabled_precisions = {torch.float4_e2m1fn_x2} @@ -124,7 +125,7 @@ def forward_loop(mod): "enabled_precisions": enabled_precisions, "truncate_double": True, "min_block_size": 1, - "use_python_runtime": True, + "use_python_runtime": False, "immutable_weights": False, "offload_module_to_cpu": args.low_vram_mode, "use_explicit_typing": use_explicit_typing, @@ -323,7 +324,12 @@ def main(args): parser = argparse.ArgumentParser( description="Run Flux quantization with different dtypes" ) - + parser.add_argument( + "--use_sdpa", + action="store_true", + help="Use sdpa", + default=False, + ) parser.add_argument( "--dtype", choices=["fp4", "fp8", "int8", "fp16"], diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index b0e41f7aeb..47fb630223 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -359,15 +359,25 @@ def compile(self) -> None: # Export the module self.original_model.to(to_torch_device(self.trt_device)) self.exp_program = self.get_exported_program() - self.gm = dynamo_compile( - self.exp_program, - arg_inputs=self.arg_inputs, - kwarg_inputs=self.kwarg_inputs, - immutable_weights=False, - use_python_runtime=self.use_python_runtime, - enabled_precisions=self.enabled_precisions, - **self.additional_settings, - ) + from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=DEBUG_LOGGING_DIR, + capture_fx_graph_after=["remove_num_users_is_0_nodes"], + save_engine_profile=True, + profile_format="trex", + engine_builder_monitor=True, + ): + self.gm = dynamo_compile( + self.exp_program, + arg_inputs=self.arg_inputs, + kwarg_inputs=self.kwarg_inputs, + immutable_weights=False, + use_python_runtime=self.use_python_runtime, 
+ enabled_precisions=self.enabled_precisions, + **self.additional_settings, + ) deallocate_module(self.original_model, delete_module=False) if self.enable_weight_streaming: self.set_weight_streaming_ctx(self.weight_streaming_budget) diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 57c614da50..0b382f0547 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -55,7 +55,12 @@ def main(args): parser = argparse.ArgumentParser( description="Run Flux quantization with different dtypes" ) - + parser.add_argument( + "--use_sdpa", + action="store_true", + help="Use sdpa", + default=False, + ) parser.add_argument( "--dtype", choices=["fp4", "fp8", "int8", "fp16"], From 92ca4ebd19dd858695ac9c18ff3517a1cc739d8a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 2 Jul 2025 13:46:05 -0700 Subject: [PATCH 05/13] remove sdpa --- examples/apps/flux_demo.py | 2 +- .../dynamo/runtime/_MutableTorchTensorRTModule.py | 3 ++- tools/perf/Flux/flux_perf.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index e012750c16..944f40df89 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -171,7 +171,7 @@ def forward_loop(mod): with torch_tensorrt.dynamo.Debugger( "graphs", logging_dir=DEBUG_LOGGING_DIR, - capture_fx_graph_after=["remove_num_users_is_0_nodes"], + # capture_fx_graph_after=["remove_num_users_is_0_nodes"], save_engine_profile=True, profile_format="trex", engine_builder_monitor=True, diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 47fb630223..092986fc1b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -364,7 +364,8 @@ def compile(self) -> None: with torch_tensorrt.dynamo.Debugger( "graphs", logging_dir=DEBUG_LOGGING_DIR, - capture_fx_graph_after=["remove_num_users_is_0_nodes"], + # whenever try to draw svg it hangs + # capture_fx_graph_after=["remove_num_users_is_0_nodes"], save_engine_profile=True, profile_format="trex", engine_builder_monitor=True, diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 0b382f0547..50f70500e3 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -46,9 +46,9 @@ def main(args): print(f"Warmup done in {time() - start} seconds, generated {len(images)} images") images[0].save("warmup2.png") - # if not args.debug: - # for batch_size in range(1, args.max_batch_size + 1): - # benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3) + if not args.debug: + for batch_size in range(1, args.max_batch_size + 1): + benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3) if __name__ == "__main__": From 0421c580e1ff5ebb2dc14936d28e05aacace929b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 2 Jul 2025 16:46:01 -0700 Subject: [PATCH 06/13] test --- examples/dynamo/register_sdpa.py | 215 ++++++++++++------ .../dynamo/conversion/_TRTInterpreter.py | 4 +- .../runtime/_MutableTorchTensorRTModule.py | 31 +-- tools/perf/Flux/flux_perf.py | 45 ++-- 4 files changed, 180 insertions(+), 115 deletions(-) diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py index 7436f31939..9afbbda7d4 100644 --- a/examples/dynamo/register_sdpa.py +++ b/examples/dynamo/register_sdpa.py @@ -32,88 +32,153 @@ @_aten_lowering_pass -def 
replace_variants_of_sdpa( +def lower_scaled_dot_product_attention( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: - """Replace scaled_dot_product_attention with an equivalent - implementation which can be accurately converted to TRT + """Replace specific versions of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT """ - attn_mask = None - is_causal = True - for node in gm.graph.nodes: - if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + original_fns, replacement = scaled_dot_product_attention_replacement() + replaced_nodes = [] + # For each original function, search for it in the graph and replace + for original in original_fns: + replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( + gm, + original, + replacement, + ignore_literals=True, + ) + + if replaced_nodes: + # Repair instances which use the kwargs field (specifically the "scale" kwarg) + # Also repair instances which specified the is_causal or attn_bias fields + for match in replaced_nodes: + attention_node_replaced = None + # Seek the attention operator being replaced + for node in match.nodes_map: + if node.target in REPLACEABLE_ATEN_OPS: + attention_node_replaced = match.nodes_map[node] + break + + assert attention_node_replaced is not None + assert len(match.replacements) == 1 + + new_attention_node = match.replacements[0] + + assert ( + new_attention_node.target + == torch.nn.functional.scaled_dot_product_attention + ) + + # Copy the metadata of the replaced attention node to the new node + # TODO: Investigate why there are multiple FakeTensors in the metadata. + # We only use the first one as it contains the output shape information for this node. + if "val" in attention_node_replaced.meta: + new_attention_node.meta["val"] = copy.copy( + attention_node_replaced.meta["val"][0] + ) + + # If the attention operator had keyword-args, copy them to the new node + if attention_node_replaced.kwargs: + new_attention_node.kwargs = {**attention_node_replaced.kwargs} + + # Set default args in new node: + # Tensor? 
attn_mask=None, float dropout_p=0.0, bool is_causal=False + new_attention_node.args = new_attention_node.args + (None, 0.0, False) + + # The `is_causal` argument was specified if ( - node.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ): - if len(node.args) == 7: - ( - query, - key, - value, - attn_bias, - compute_log_sumexp, - dropout_p, - is_causal, - ) = node.args - elif len(node.args) == 5: - query, key, value, attn_mask, is_causal = node.args - dropout_p = 0.0 - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) - elif ( - node.target - == torch.ops.aten._scaled_dot_product_flash_attention.default + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 4, False) + ) or ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 6, False) ): - if len(node.args) == 6: - query, key, value, dropout_p, is_causal, return_debug_mask = ( - node.args - ) - elif len(node.args) == 3: - query, key, value = node.args - dropout_p = 0.0 - is_causal = True - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) - if attn_mask is not None: - logger.warning( - f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration." + new_attention_node.args = ( + new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] ) - modified_input_args = (query, key, value, None, dropout_p, is_causal) - - # Create a new node with torch.nn.functional.scaled_dot_product_attention - # The input args is (query, key, value, is_causal). kwargs has scale - with gm.graph.inserting_after(node): - new_node = gm.graph.call_function( - torch.nn.functional.scaled_dot_product_attention, - args=modified_input_args, - kwargs={"scale": node.kwargs.get("scale", None)}, + # The `attn_bias` argument was specified + if ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) and args_bounds_check(attention_node_replaced.args, 3) is not None: + new_attention_node.args = ( + new_attention_node.args[:3] + + attention_node_replaced.args[3] + + new_attention_node.args[4:] ) - # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. 
- new_node.meta = copy.copy(node.meta) - # Check if there's a getitem node following this attention node - for user in list(node.users): - if user.op == "call_function" and user.target == operator.getitem: - # If the getitem is extracting the first element (the output tensor) - if user.args[1] == 0: - # Replace all uses of the getitem with the new attention node - user.replace_all_uses_with(new_node) - new_node.meta["val"] = new_node.meta["val"][0] - # Replace all uses of the original node with the new node - node.replace_all_uses_with(new_node) - - gm.graph.erase_node(node) - - # Clean up the graph - clean_up_graph_after_modifications(gm) - - logger.info( - "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" - ) + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") + return gm + + +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for efficient attention""" + + # Efficient Attention original graph + def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention original graph + def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + ) + out = operator.getitem(outputs, 0) + return out + + # Efficient Attention w/Scale original graph + def efficient_scale( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention w/Scale original graph + def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Replacement graph consists of the functional version of scaled_dot_product_attention + def replacement( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(query, key, value) + + return (efficient, flash, efficient_scale, flash_scale), replacement diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index b134b3d5f5..f16271cebc 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -220,8 +220,8 @@ def _populate_trt_builder_config( if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( trt.ProfilingVerbosity.DETAILED - if self._debugger_config and self._debugger_config.save_engine_profile - else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + # if self._debugger_config and self._debugger_config.save_engine_profile + # else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) if version.parse(trt.__version__) >= version.parse("8.6"): diff --git 
a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 092986fc1b..ba83c89dcd 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -334,7 +334,7 @@ def export_fn() -> torch.export.ExportedProgram: # Check if any quantization precision is enabled if self.enabled_precisions and any( precision in self.enabled_precisions - for precision in (torch.float8_e4m3fn, torch.int8, torch.float4_e2m1fn_x2) + for precision in (torch.float8_e4m3fn, torch.int8) ): try: from modelopt.torch.quantization.utils import export_torch_mode @@ -359,26 +359,15 @@ def compile(self) -> None: # Export the module self.original_model.to(to_torch_device(self.trt_device)) self.exp_program = self.get_exported_program() - from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR - - with torch_tensorrt.dynamo.Debugger( - "graphs", - logging_dir=DEBUG_LOGGING_DIR, - # whenever try to draw svg it hangs - # capture_fx_graph_after=["remove_num_users_is_0_nodes"], - save_engine_profile=True, - profile_format="trex", - engine_builder_monitor=True, - ): - self.gm = dynamo_compile( - self.exp_program, - arg_inputs=self.arg_inputs, - kwarg_inputs=self.kwarg_inputs, - immutable_weights=False, - use_python_runtime=self.use_python_runtime, - enabled_precisions=self.enabled_precisions, - **self.additional_settings, - ) + self.gm = dynamo_compile( + self.exp_program, + arg_inputs=self.arg_inputs, + kwarg_inputs=self.kwarg_inputs, + immutable_weights=False, + use_python_runtime=self.use_python_runtime, + enabled_precisions=self.enabled_precisions, + **self.additional_settings, + ) deallocate_module(self.original_model, delete_module=False) if self.enable_weight_streaming: self.set_weight_streaming_ctx(self.weight_streaming_budget) diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 50f70500e3..8a74392ea9 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -10,7 +10,20 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): + # warmup + for i in range(10): + start = time() + images = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + print( + f"Warmup {i} done in {time() - start} seconds, with {batch_size=} {inference_step=}, generated {len(images)} images" + ) + # actual benchmark start = time() for i in range(iterations): image = pipe( @@ -20,35 +33,33 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): num_images_per_prompt=batch_size, ).images end = time() - print(f"Batch Size: {batch_size}") print("Time Elapse for", iterations, "iterations:", end - start) print( "Average Latency Per Step:", (end - start) / inference_step / iterations / batch_size, ) - return image + + # run the perf tool + from cuda import cudart + + cudart.cudaProfilerStart() + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + cudart.cudaProfilerStop() + return def main(args): print(f"Running flux_perfwith args: {args}") pipe, backbone, trt_gm = compile_model(args) - # warmup - seed = 42 - warmup_prompt = ["Beach and Kids"] - start = time() - images = pipe( - warmup_prompt, - output_type="pil", - num_inference_steps=30, - generator=torch.Generator("cuda").manual_seed(seed), - ).images - print(f"Warmup done in {time() - start} seconds, 
generated {len(images)} images") - images[0].save("warmup2.png") - if not args.debug: - for batch_size in range(1, args.max_batch_size + 1): - benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3) + for batch_size in range(1, args.max_batch_size + 1): + benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3) if __name__ == "__main__": From 857720b1fdad25fffa648f907c278d1c1125bdaf Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 3 Jul 2025 08:37:21 -0700 Subject: [PATCH 07/13] add profiler --- tools/perf/Flux/flux_perf.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 8a74392ea9..449e1ff126 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -10,6 +10,7 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): + print(f"Running warmup with {batch_size=} {inference_step=} iterations=10") # warmup for i in range(10): start = time() @@ -24,6 +25,7 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): ) # actual benchmark + print(f"Running benchmark with {batch_size=} {inference_step=} {iterations=}") start = time() for i in range(iterations): image = pipe( @@ -41,8 +43,10 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): ) # run the perf tool + print(f"Running cudart perf tool with {inference_step=} {batch_size=}") from cuda import cudart + cudart.cudaInit(0) cudart.cudaProfilerStart() image = pipe( prompt, @@ -51,6 +55,22 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): num_images_per_prompt=batch_size, ).images cudart.cudaProfilerStop() + + # print(f"Running torch profiler with {inference_step=} {batch_size=}") + # with torch.profiler.profile( + # activities=[torch.profiler.ProfilerActivity.CUDA], + # record_shapes=True, + # profile_memory=True, + # with_stack=True, + # ) as prof: + # with torch.profiler.record_function("model_inference"): + # pipe( + # prompt, + # output_type="pil", + # num_inference_steps=inference_step, + # num_images_per_prompt=batch_size, + # ).images + # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)) return From 6dc056da3be72bbfa4e00df807ce32f664dba778 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 3 Jul 2025 10:57:17 -0700 Subject: [PATCH 08/13] test --- .../dynamo/runtime/_MutableTorchTensorRTModule.py | 2 +- tools/perf/Flux/flux_perf.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index ba83c89dcd..b0e41f7aeb 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -334,7 +334,7 @@ def export_fn() -> torch.export.ExportedProgram: # Check if any quantization precision is enabled if self.enabled_precisions and any( precision in self.enabled_precisions - for precision in (torch.float8_e4m3fn, torch.int8) + for precision in (torch.float8_e4m3fn, torch.int8, torch.float4_e2m1fn_x2) ): try: from modelopt.torch.quantization.utils import export_torch_mode diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 449e1ff126..fd19e5c0eb 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -46,7 +46,6 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): print(f"Running cudart perf tool 
with {inference_step=} {batch_size=}") from cuda import cudart - cudart.cudaInit(0) cudart.cudaProfilerStart() image = pipe( prompt, @@ -78,8 +77,7 @@ def main(args): print(f"Running flux_perfwith args: {args}") pipe, backbone, trt_gm = compile_model(args) - for batch_size in range(1, args.max_batch_size + 1): - benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3) + benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3) if __name__ == "__main__": From 3b618794aefa0c7cfbfeaebe60a1e145913c102e Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 3 Jul 2025 16:48:15 -0700 Subject: [PATCH 09/13] remove profile, debug flag --- examples/apps/README.md | 5 + examples/apps/flux_demo.py | 92 +-------- examples/apps/flux_quantization.py | 309 ----------------------------- tools/perf/Flux/flux_perf.py | 37 ---- 4 files changed, 12 insertions(+), 431 deletions(-) delete mode 100644 examples/apps/flux_quantization.py diff --git a/examples/apps/README.md b/examples/apps/README.md index ac63500d29..b6b77e17f1 100644 --- a/examples/apps/README.md +++ b/examples/apps/README.md @@ -23,6 +23,11 @@ python flux_demo.py ### Using Different Precision Modes +- FP4 mode: +```bash +python flux_demo.py --dtype fp4 +``` + - FP8 mode: ```bash python flux_demo.py --dtype fp8 diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index 944f40df89..fbc44e0f6f 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -10,8 +10,6 @@ import torch_tensorrt from accelerate.hooks import remove_hook_from_module from diffusers import FluxPipeline -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR DEVICE = "cuda:0" @@ -23,6 +21,7 @@ def compile_model( ]: use_explicit_typing = False if args.use_sdpa: + # currently use sdpa is not working correctly with flux model, so we don't use it # Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo")) import register_sdpa @@ -55,13 +54,6 @@ def compile_model( torch_dtype=torch.float16, ).to(torch.float16) - # # Use a small transformer for debugging - # if args.debug: - # pipe.transformer = FluxTransformer2DModel( - # num_layers=1, num_single_layers=1, guidance_embeds=True - # ) - # pipe.to(torch.float16) - if args.low_vram_mode: pipe.enable_model_cpu_offload() else: @@ -135,70 +127,11 @@ def forward_loop(mod): pipe.enable_sequential_cpu_offload() remove_hook_from_module(pipe.transformer, recurse=True) pipe.transformer.to(DEVICE) - if args.use_dynamo: - dummy_inputs = { - "hidden_states": torch.randn( - (batch_size, 4096, 64), dtype=torch.float16 - ).to(DEVICE), - "encoder_hidden_states": torch.randn( - (batch_size, 512, 4096), dtype=torch.float16 - ).to(DEVICE), - "pooled_projections": torch.randn( - (batch_size, 768), dtype=torch.float16 - ).to(DEVICE), - "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to( - DEVICE - ), - "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE), - "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE), - "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to( - DEVICE - ), - "joint_attention_kwargs": {}, - "return_dict": False, - } - from modelopt.torch.quantization.utils import export_torch_mode - - with export_torch_mode(): - ep = torch.export.export( - backbone, - args=(), - kwargs=dummy_inputs, - dynamic_shapes=dynamic_shapes, - strict=False, - ) - if args.debug: - with torch_tensorrt.dynamo.Debugger( - "graphs", - logging_dir=DEBUG_LOGGING_DIR, - # capture_fx_graph_after=["remove_num_users_is_0_nodes"], - save_engine_profile=True, - profile_format="trex", - engine_builder_monitor=True, - ): - trt_gm = torch_tensorrt.dynamo.compile( - ep, inputs=dummy_inputs, **settings - ) - else: - trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=dummy_inputs, **settings) - pipe.transformer = trt_gm - pipe.transformer.config = backbone.config - else: - if args.debug: - with torch_tensorrt.dynamo.Debugger( - "graphs", - logging_dir=DEBUG_LOGGING_DIR, - capture_fx_graph_after=["remove_num_users_is_0_nodes"], - save_engine_profile=True, - profile_format="trex", - engine_builder_monitor=True, - ): - trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) - else: - trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) - if dynamic_shapes: - trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) - pipe.transformer = trt_gm + + trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) + if dynamic_shapes: + trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) + pipe.transformer = trt_gm seed = 42 image = pipe( ["Beach and Kids"], @@ -208,7 +141,7 @@ def forward_loop(mod): generator=torch.Generator("cuda").manual_seed(seed), ).images print(f"generated {len(image)} images") - image[0].save("warmup1.png") + image[0].save("beach_kids.png") torch.cuda.empty_cache() @@ -336,22 +269,11 @@ def main(args): default="fp16", help="Select the data type to use (fp4 or fp8 or int8 or fp16)", ) - parser.add_argument( - "--use_dynamo", - action="store_true", - help="Use dynamo compile", - default=False, - ) parser.add_argument( "--fp4_mha", action="store_true", help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_FP8_MHA_CONFIG", ) - parser.add_argument( - "--debug", - action="store_true", - help="Use debug mode", - ) parser.add_argument( 
"--low_vram_mode", action="store_true", diff --git a/examples/apps/flux_quantization.py b/examples/apps/flux_quantization.py deleted file mode 100644 index 87c75cf918..0000000000 --- a/examples/apps/flux_quantization.py +++ /dev/null @@ -1,309 +0,0 @@ -# %% -# Import the following libraries -# ----------------------------- -# Load the ModelOpt-modified model architecture and weights using Huggingface APIs -# Add argument parsing for dtype selection -import argparse -import gc -import os -import re -import sys - -import modelopt.torch.opt as mto -import modelopt.torch.quantization as mtq -import torch -import torch_tensorrt -from diffusers import FluxPipeline -from diffusers.models.attention_processor import Attention -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG -from modelopt.torch.quantization.utils import export_torch_mode -from torch.export._trace import _export -from transformers import AutoModelForCausalLM - -parser = argparse.ArgumentParser( - description="Run Flux quantization with different dtypes" -) -parser.add_argument( - "--debug", - action="store_true", - default=False, - help="debug mode", -) -parser.add_argument( - "--mha", - action="store_true", - default=False, - help="NVFP4_FP8_MHA_CONFIG mode", -) -parser.add_argument( - "--dtype", - choices=["fp8", "int8", "fp4", "fp16", "bf16", "fp32"], - default="fp8", - help="Quantization data type to use (fp8 or int8 or fp4 or fp16 or bf16 or fp32)", -) - -parser.add_argument( - "--sdpa", - action="store_true", - default=False, - help="Register SDPA operator", -) - -parser.add_argument( - "--strong-typing", - action="store_true", - help="string type flag", -) - -args = parser.parse_args() -if args.sdpa: - sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo")) - from register_sdpa import * - - -dtype = torch.float16 -ptq_config = None -use_explicit_typing = args.strong_typing -enabled_precisions = [ - torch.float32, -] - -# Update enabled precisions based on dtype argument -if args.dtype == "fp8": - ( - enabled_precisions.extend([torch.float8_e4m3fn, torch.float16]) - if not use_explicit_typing - else None - ) - ptq_config = mtq.FP8_DEFAULT_CFG -elif args.dtype == "int8": # int8 - ( - enabled_precisions.extend([torch.int8, torch.float16]) - if not use_explicit_typing - else None - ) - ptq_config = mtq.INT8_DEFAULT_CFG -elif args.dtype == "fp4": - if args.mha: - ptq_config = NVFP4_FP8_MHA_CONFIG - else: - ptq_config = mtq.NVFP4_DEFAULT_CFG # mtq.NVFP4_DEFAULT_CFG - use_explicit_typing = True -elif args.dtype == "fp16": - enabled_precisions.append(torch.float16) if not use_explicit_typing else None -elif args.dtype == "bf16": - dtype = torch.bfloat16 - ( - enabled_precisions.extend([torch.bfloat16, torch.float16]) - if not use_explicit_typing - else None - ) -elif args.dtype == "fp32": - dtype = torch.float32 -else: - raise ValueError(f"Invalid dtype: {args.dtype}") -print(f"\nUsing {args.dtype} quantization with {args=}") -# %% -DEVICE = "cuda:0" -pipe = FluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", - torch_dtype=dtype, -) - -total_params = sum(p.numel() for p in pipe.transformer.parameters()) -print(f"\n Total number of parameters: {total_params/1000/1000/1000}B") -if dtype in (torch.float16, torch.bfloat16): - total_size = total_params * 2 / 1024 / 1024 / 1024 - print(f"\n Total size: {total_size}GB") -elif dtype == torch.float32: - total_size = total_params * 4 / 1024 / 1024 / 1024 - 
print(f"\n Total size: {total_size}GB") - -# if args.debug: -# pipe.transformer = FluxTransformer2DModel( -# num_layers=1, num_single_layers=1, guidance_embeds=True -# ) - -pipe.to(DEVICE).to(dtype) -# Store the config and transformer backbone -config = pipe.transformer.config -# global backbone -backbone = pipe.transformer -backbone.eval() - - -def filter_func(name): - pattern = re.compile( - r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" - ) - return pattern.match(name) is not None - - -def generate_image(pipe, prompt, image_name): - seed = 42 - image = pipe( - prompt, - output_type="pil", - num_inference_steps=20, - generator=torch.Generator("cuda").manual_seed(seed), - ).images[0] - image.save(f"{image_name}.png") - print(f"Image generated using {image_name} model saved as {image_name}.png") - - -def benchmark(prompt, inference_step, batch_size=1, iterations=1): - from time import time - - print(f"Benchmark TRT Module Latency started with {batch_size=} {iterations=}") - start = time() - for i in range(iterations): - image = pipe( - prompt, - output_type="pil", - num_inference_steps=inference_step, - num_images_per_prompt=batch_size, - ).images - end = time() - print(f"Batch Size: {batch_size}") - print("Time Elapse for", iterations, "iterations:", end - start) - print( - "Average Latency Per Step:", - (end - start) / inference_step / iterations / batch_size, - ) - return image - - -# %% -# Quantization - - -def do_calibrate( - pipe, - prompt: str, -) -> None: - """ - Run calibration steps on the pipeline using the given prompts. - """ - image = pipe( - prompt, - output_type="pil", - num_inference_steps=20, - generator=torch.Generator("cuda").manual_seed(0), - ).images[0] - - -def forward_loop(mod): - # Switch the pipeline's backbone, run calibration - pipe.transformer = mod - do_calibrate( - pipe=pipe, - prompt="test", - ) - - -if ptq_config is not None: - backbone = mtq.quantize(backbone, ptq_config, forward_loop) - mtq.disable_quantizer(backbone, filter_func) -else: - print("No quantization config provided, skipping quantization") - -batch_size = 2 -BATCH = torch.export.Dim("batch", min=1, max=8) -SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) -# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. 
-# To see this recommendation, you can try exporting using min=1, max=4096 -IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) -dynamic_shapes = { - "hidden_states": {0: BATCH}, - "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, - "pooled_projections": {0: BATCH}, - "timestep": {0: BATCH}, - "txt_ids": {0: SEQ_LEN}, - "img_ids": {0: IMG_ID}, - "guidance": {0: BATCH}, - "joint_attention_kwargs": {}, - "return_dict": None, -} -# The guidance factor is of type torch.float32 -dummy_inputs = { - "hidden_states": torch.randn((batch_size, 4096, 64), dtype=dtype).to(DEVICE), - "encoder_hidden_states": torch.randn((batch_size, 512, 4096), dtype=dtype).to( - DEVICE - ), - "pooled_projections": torch.randn((batch_size, 768), dtype=dtype).to(DEVICE), - "timestep": torch.tensor([1.0] * batch_size, dtype=dtype).to(DEVICE), - "txt_ids": torch.randn((512, 3), dtype=dtype).to(DEVICE), - "img_ids": torch.randn((4096, 3), dtype=dtype).to(DEVICE), - "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE), - "joint_attention_kwargs": {}, - "return_dict": False, -} - - -torch.cuda.empty_cache() -torch.cuda.reset_peak_memory_stats() -gc.collect() -# This will create an exported program which is going to be compiled with Torch-TensorRT -with export_torch_mode(): - ep = _export( - backbone, - args=(), - kwargs=dummy_inputs, - dynamic_shapes=dynamic_shapes, - strict=False, - allow_complex_guards_as_runtime_asserts=True, - ) - -peak_memory = torch.cuda.max_memory_allocated() / (1024**3) -peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) -print(f"Peak memory allocated during torch-export: {peak_memory=}GB {peak_reserved=}GB") - -torch.cuda.empty_cache() -torch.cuda.reset_peak_memory_stats() -gc.collect() - -with torch_tensorrt.logging.debug(): - trt_gm = torch_tensorrt.dynamo.compile( - ep, - inputs=dummy_inputs, - enabled_precisions=enabled_precisions, - use_explicit_typing=use_explicit_typing, - truncate_double=True, - min_block_size=1, - debug=args.debug, - immutable_weights=True, - offload_module_to_cpu=True, - ) - -peak_memory = torch.cuda.max_memory_allocated() / (1024**3) -peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) -print( - f"Peak memory allocated during torch dynamo compilation: {peak_memory=}GB {peak_reserved=}GB" -) - -del ep -pipe.transformer = trt_gm -pipe.transformer.config = config - -torch.cuda.empty_cache() -torch.cuda.reset_peak_memory_stats() -gc.collect() - -# %% - -trt_gm.device = torch.device(DEVICE) -# Function which generates images from the flux pipeline -generate_image(pipe, ["Beach and Kids"], "beach_and_kids") - -peak_memory = torch.cuda.max_memory_allocated() / (1024**3) -peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) -print(f"Peak memory allocated during inference: {peak_memory=}GB {peak_reserved=}GB") - -# if not args.debug: -# print(f"Benchmark TRT Module Latency at ({args.dtype}) started") -# for batch_size in range(1, 9): -# benchmark(["Test"], 20, batch_size=batch_size, iterations=3) -# print(f"Benchmark TRT Module Latency at ({args.dtype}) ended") - -# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index fd19e5c0eb..c654b06d9b 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -44,32 +44,6 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): # run the perf tool print(f"Running cudart perf tool with {inference_step=} {batch_size=}") - from cuda import 
cudart - - cudart.cudaProfilerStart() - image = pipe( - prompt, - output_type="pil", - num_inference_steps=inference_step, - num_images_per_prompt=batch_size, - ).images - cudart.cudaProfilerStop() - - # print(f"Running torch profiler with {inference_step=} {batch_size=}") - # with torch.profiler.profile( - # activities=[torch.profiler.ProfilerActivity.CUDA], - # record_shapes=True, - # profile_memory=True, - # with_stack=True, - # ) as prof: - # with torch.profiler.record_function("model_inference"): - # pipe( - # prompt, - # output_type="pil", - # num_inference_steps=inference_step, - # num_images_per_prompt=batch_size, - # ).images - # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)) return @@ -101,11 +75,6 @@ def main(args): action="store_true", help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_FP8_MHA_CONFIG", ) - parser.add_argument( - "--debug", - action="store_true", - help="Use debug mode", - ) parser.add_argument( "--low_vram_mode", action="store_true", @@ -117,12 +86,6 @@ def main(args): action="store_true", help="Use dynamic shapes", ) - parser.add_argument( - "--use_dynamo", - action="store_true", - help="Use dynamo compile", - default=False, - ) parser.add_argument("--max_batch_size", type=int, default=1) args = parser.parse_args() main(args) From 8a3afdeb09e8aae74a102c3a1707aef862c829f4 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 3 Jul 2025 16:55:47 -0700 Subject: [PATCH 10/13] test --- examples/apps/flux_demo.py | 2 +- .../dynamo/conversion/_TRTInterpreter.py | 4 +-- tools/llm/torchtrt_ext/register_sdpa.py | 31 ++++++++++++------- tools/perf/Flux/flux_perf.py | 2 +- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index fbc44e0f6f..42c314c560 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -272,7 +272,7 @@ def main(args): parser.add_argument( "--fp4_mha", action="store_true", - help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_FP8_MHA_CONFIG", + help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG", ) parser.add_argument( "--low_vram_mode", diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 50b83c1f42..8d7a914836 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -220,8 +220,8 @@ def _populate_trt_builder_config( if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( trt.ProfilingVerbosity.DETAILED - # if self._debugger_config and self._debugger_config.save_engine_profile - # else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + if self._debugger_config and self._debugger_config.save_engine_profile + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) if version.parse(trt.__version__) >= version.parse("8.6"): diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 7436f31939..90a00a5798 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -4,7 +4,6 @@ from typing import Callable, Sequence, Tuple import torch -from sdpa_converter import * from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS @@ -15,15 +14,19 @@ clean_up_graph_after_modifications, ) +from .sdpa_converter import * 
+ logger = logging.getLogger(__name__) # Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention # This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it. -TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default) +TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None) +TORCH_TRT_DECOMPOSITIONS.pop( + torch.ops.aten._scaled_dot_product_efficient_attention.default, None +) TORCH_TRT_DECOMPOSITIONS.pop( - torch.ops.aten._scaled_dot_product_efficient_attention.default + torch.ops.aten._scaled_dot_product_flash_attention.default, None ) -TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default) REPLACEABLE_ATEN_OPS = { torch.ops.aten._scaled_dot_product_efficient_attention.default, @@ -59,6 +62,7 @@ def replace_variants_of_sdpa( elif len(node.args) == 5: query, key, value, attn_mask, is_causal = node.args dropout_p = 0.0 + else: raise ValueError( f"Unexpected number of arguments for {node.target} in the graph" @@ -71,6 +75,8 @@ def replace_variants_of_sdpa( query, key, value, dropout_p, is_causal, return_debug_mask = ( node.args ) + if len(node.args) == 5: + query, key, value, dropout_p, is_causal = node.args elif len(node.args) == 3: query, key, value = node.args dropout_p = 0.0 @@ -79,20 +85,21 @@ def replace_variants_of_sdpa( raise ValueError( f"Unexpected number of arguments for {node.target} in the graph" ) - if attn_mask is not None: - logger.warning( - f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration." - ) - - modified_input_args = (query, key, value, None, dropout_p, is_causal) + logger.warning( + f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." + ) + modified_input_args = (query, key, value, None, dropout_p, True) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): new_node = gm.graph.call_function( torch.nn.functional.scaled_dot_product_attention, args=modified_input_args, - kwargs={"scale": node.kwargs.get("scale", None)}, + kwargs={ + "scale": node.kwargs.get("scale", None), + "use_fp32_acc": settings.use_fp32_acc, + }, ) # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. 
@@ -113,7 +120,7 @@ def replace_variants_of_sdpa( # Clean up the graph clean_up_graph_after_modifications(gm) - logger.info( + logger.debug( "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" ) return gm diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index c654b06d9b..37da4773c6 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -73,7 +73,7 @@ def main(args): parser.add_argument( "--fp4_mha", action="store_true", - help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_FP8_MHA_CONFIG", + help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG", ) parser.add_argument( "--low_vram_mode", From 2f4afae992d948e28d0872863ee7e7c83dfcc233 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 3 Jul 2025 17:15:07 -0700 Subject: [PATCH 11/13] test --- examples/apps/flux_demo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index 42c314c560..d3bbe6a38d 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -10,6 +10,7 @@ import torch_tensorrt from accelerate.hooks import remove_hook_from_module from diffusers import FluxPipeline +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel DEVICE = "cuda:0" From aa8ea5d533ee0a893d30a997682a6cfad295c61f Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 3 Jul 2025 20:25:17 -0700 Subject: [PATCH 12/13] test --- examples/apps/flux_demo.py | 8 ++++--- .../lowering/passes/constant_folding.py | 4 ++++ tools/perf/Flux/flux_perf.py | 22 ++++++++++++++++--- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index d3bbe6a38d..761b1bafa8 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -135,14 +135,16 @@ def forward_loop(mod): pipe.transformer = trt_gm seed = 42 image = pipe( - ["Beach and Kids"], + [ + "enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow" + ], output_type="pil", - num_inference_steps=20, + num_inference_steps=30, num_images_per_prompt=batch_size, generator=torch.Generator("cuda").manual_seed(seed), ).images print(f"generated {len(image)} images") - image[0].save("beach_kids.png") + image[0].save("forest.png") torch.cuda.empty_cache() diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 928b7284fe..bd3baeebb7 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -106,7 +106,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: import modelopt.torch.quantization as mtq assert torch.ops.tensorrt.quantize_op.default + assert torch.ops.tensorrt.dynamic_block_quantize_op.default self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default) + self.quantization_ops.add( + torch.ops.tensorrt.dynamic_block_quantize_op.default + ) except Exception as e: pass diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 37da4773c6..73632ff82d 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -9,6 +9,24 @@ from flux_demo import compile_model +def profile(pipe, prompt, inference_step, batch_size=1): + print(f"Running torch profiler with {inference_step=} {batch_size=}") + with torch.profiler.profile( + 
activities=[torch.profiler.ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as prof: + with torch.profiler.record_function("model_inference"): + pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)) + + def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): print(f"Running warmup with {batch_size=} {inference_step=} iterations=10") # warmup @@ -41,9 +59,6 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): "Average Latency Per Step:", (end - start) / inference_step / iterations / batch_size, ) - - # run the perf tool - print(f"Running cudart perf tool with {inference_step=} {batch_size=}") return @@ -52,6 +67,7 @@ def main(args): pipe, backbone, trt_gm = compile_model(args) benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3) + # profile(pipe, ["enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow"], 20, batch_size=args.max_batch_size) if __name__ == "__main__": From aff2d9dace812890e058f25dc046d88fb26e2eb7 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 10 Jul 2025 09:26:09 -0700 Subject: [PATCH 13/13] remove profile --- tools/perf/Flux/flux_perf.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 73632ff82d..1d3b2acbbc 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -9,24 +9,6 @@ from flux_demo import compile_model -def profile(pipe, prompt, inference_step, batch_size=1): - print(f"Running torch profiler with {inference_step=} {batch_size=}") - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA], - record_shapes=True, - profile_memory=True, - with_stack=True, - ) as prof: - with torch.profiler.record_function("model_inference"): - pipe( - prompt, - output_type="pil", - num_inference_steps=inference_step, - num_images_per_prompt=batch_size, - ).images - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)) - - def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): print(f"Running warmup with {batch_size=} {inference_step=} iterations=10") # warmup @@ -67,7 +49,6 @@ def main(args): pipe, backbone, trt_gm = compile_model(args) benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3) - # profile(pipe, ["enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow"], 20, batch_size=args.max_batch_size) if __name__ == "__main__":
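
A minimal usage sketch (not part of the patch series itself): one way the updated perf script might be invoked once this series lands. The flag names mirror the argparse options added in these patches (--dtype gains a "fp4" choice, --fp4_mha switches to the NVFP4 FP8-MHA quantization config, --low_vram_mode enables the CPU-offload path, --max_batch_size sets the benchmark batch size); the working directory and entry point are assumptions, not something the patches document.

import subprocess

# Hypothetical invocation of the updated benchmark script; flags are taken from
# the argparse definitions in the diffs above, everything else is assumed.
subprocess.run(
    [
        "python",
        "tools/perf/Flux/flux_perf.py",
        "--dtype", "fp4",        # new dtype choice added alongside fp8/int8/fp16
        "--fp4_mha",             # use the NVFP4 FP8-MHA config instead of the default NVFP4 config
        "--low_vram_mode",       # offload modules to CPU during compilation
        "--max_batch_size", "1",
    ],
    check=True,
)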