From 8aca9ee0a83ccfd0b13557d81b6966077910068b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 4 Jul 2025 07:51:42 +0530 Subject: [PATCH 01/10] feat: enable i2i fine-tuning in Kontext script. --- .../train_dreambooth_lora_flux_kontext.py | 188 ++++++++++++++---- 1 file changed, 149 insertions(+), 39 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 9f97567b06b8..58a48fb19112 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -40,7 +40,7 @@ from torch.utils.data import Dataset from torch.utils.data.sampler import BatchSampler from torchvision import transforms -from torchvision.transforms.functional import crop +from torchvision.transforms import functional as TF from tqdm.auto import tqdm from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast @@ -62,11 +62,7 @@ free_memory, parse_buckets_string, ) -from diffusers.utils import ( - check_min_version, - convert_unet_state_dict_to_peft, - is_wandb_available, -) +from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module @@ -186,6 +182,7 @@ def log_validation( ) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) + pipeline_args_cp = pipeline_args.copy() # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None @@ -193,14 +190,16 @@ def log_validation( # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( - pipeline_args["prompt"], prompt_2=pipeline_args["prompt"] - ) + prompt = pipeline_args_cp.pop("prompt") + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None) images = [] for _ in range(args.num_validation_images): with autocast_ctx: image = pipeline( - prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator + **pipeline_args_cp, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + generator=generator, ).images[0] images.append(image) @@ -310,6 +309,12 @@ def parse_args(input_args=None): "default, the standard Image Dataset maps out 'file_name' " "to 'image'.", ) + parser.add_argument( + "--cond_image_column", + type=str, + default="image", + help="Column in the dataset containing the condition image. 
Must be specified when performing I2I fine-tuning", + ) parser.add_argument( "--caption_column", type=str, @@ -351,6 +356,12 @@ def parse_args(input_args=None): default=None, help="A prompt that is used during validation to verify that the model is learning.", ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.", + ) parser.add_argument( "--num_validation_images", type=int, @@ -399,7 +410,7 @@ def parse_args(input_args=None): parser.add_argument( "--output_dir", type=str, - default="flux-dreambooth-lora", + default="flux-kontext-lora", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") @@ -716,6 +727,8 @@ def parse_args(input_args=None): raise ValueError("You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") + if args.cond_image_column is not None: + raise ValueError("Prior preservation isn't supported with I2I training.") else: # logger is not available yet if args.class_data_dir is not None: @@ -723,6 +736,11 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + if args.cond_image_column is not None: + assert args.image_column is not None + assert args.caption_column is not None + assert args.dataset_name is not None + return args @@ -742,6 +760,7 @@ def __init__( repeats=1, center_crop=False, buckets=None, + args=None, ): self.center_crop = center_crop @@ -774,6 +793,7 @@ def __init__( column_names = dataset["train"].column_names # 6. Get the column names for input/target. + # TODO: add validation for `cond_image_column` if args.image_column is None: image_column = column_names[0] logger.info(f"image column defaulting to {image_column}") @@ -783,7 +803,12 @@ def __init__( raise ValueError( f"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" ) - instance_images = dataset["train"][image_column] + instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))] + cond_images = None + cond_image_column = args.cond_image_column + if cond_image_column is not None: + cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))] + assert len(instance_images) == len(cond_images) if args.caption_column is None: logger.info( @@ -811,14 +836,23 @@ def __init__( self.custom_instance_prompts = None self.instance_images = [] - for img in instance_images: + self.cond_images = [] + for i, img in enumerate(instance_images): self.instance_images.extend(itertools.repeat(img, repeats)) + if cond_images is not None: + self.cond_images.extend(itertools.repeat(cond_images[i], repeats)) self.pixel_values = [] - for image in self.instance_images: + self.cond_pixel_values = [] + for i, image in enumerate(self.instance_images): image = exif_transpose(image) if not image.mode == "RGB": image = image.convert("RGB") + dest_image = None + if self.cond_images: + dest_image = exif_transpose(self.cond_images[i]) + if not dest_image.mode == "RGB": + dest_image = dest_image.convert("RGB") width, height = image.size @@ -828,25 +862,16 @@ def __init__( self.size = (target_height, target_width) # based on the bucket assignment, define the transformations - train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size) - train_flip = transforms.RandomHorizontalFlip(p=1.0) - train_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] + image, dest_image = self.paired_transform( + image, + dest_image=dest_image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, ) - image = train_resize(image) - if args.center_crop: - image = train_crop(image) - else: - y1, x1, h, w = train_crop.get_params(image, self.size) - image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - image = train_flip(image) - image = train_transforms(image) self.pixel_values.append((image, bucket_idx)) + if dest_image: + self.cond_pixel_values.append((dest_image, bucket_idx)) self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images @@ -880,6 +905,9 @@ def __getitem__(self, index): instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] example["instance_images"] = instance_image example["bucket_idx"] = bucket_idx + if self.cond_pixel_values: + dest_image, _ = self.cond_pixel_values[index % self.num_instance_images] + example["cond_images"] = dest_image if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] @@ -902,6 +930,43 @@ def __getitem__(self, index): return example + def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + if dest_image is not None: + dest_image = resize(dest_image) + + # 2. 
Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + if dest_image is not None: + dest_image = crop(dest_image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + if dest_image is not None: + dest_image = TF.crop(dest_image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + if dest_image is not None: + dest_image = TF.hflip(dest_image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + if dest_image is not None: + dest_image = normalize(to_tensor(dest_image)) + + return (image, dest_image) if dest_image is not None else image + def collate_fn(examples, with_prior_preservation=False): pixel_values = [example["instance_images"] for example in examples] @@ -917,6 +982,11 @@ def collate_fn(examples, with_prior_preservation=False): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() batch = {"pixel_values": pixel_values, "prompts": prompts} + if any("cond_images" in example for example in examples): + cond_pixel_values = [example["cond_images"] for example in examples] + cond_pixel_values = torch.stack(cond_pixel_values) + cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float() + batch.update({"cond_pixel_values": cond_pixel_values}) return batch @@ -1534,6 +1604,7 @@ def load_model_hook(models, input_dir): buckets=buckets, repeats=args.repeats, center_crop=args.center_crop, + args=args, ) batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False) train_dataloader = torch.utils.data.DataLoader( @@ -1608,14 +1679,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): vae_config_shift_factor = vae.config.shift_factor vae_config_scaling_factor = vae.config.scaling_factor vae_config_block_out_channels = vae.config.block_out_channels + has_image_input = args.cond_image_column is not None if args.cache_latents: latents_cache = [] + cond_latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( accelerator.device, non_blocking=True, dtype=weight_dtype ) latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if has_image_input: + batch["cond_pixel_values"] = batch["cond_pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist) if args.validation_prompt is None: del vae @@ -1678,7 +1756,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = "dreambooth-flux-dev-lora" + tracker_name = "dreambooth-flux-kontext-lora" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! @@ -1759,6 +1837,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: + # Should find a way to precompute these. 
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( prompts, text_encoders, tokenizers ) @@ -1794,16 +1873,27 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.cache_latents: if args.vae_encode_mode == "sample": model_input = latents_cache[step].sample() + if has_image_input: + cond_model_input = cond_latents_cache[step].sample() else: model_input = latents_cache[step].mode() + if has_image_input: + cond_model_input = cond_latents_cache[step].mode() else: pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + if has_image_input: + cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) if args.vae_encode_mode == "sample": model_input = vae.encode(pixel_values).latent_dist.sample() + cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample() else: model_input = vae.encode(pixel_values).latent_dist.mode() + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) + if has_image_input: + cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor + cond_model_input = cond_model_input.to(dtype=weight_dtype) vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) @@ -1814,6 +1904,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.device, weight_dtype, ) + if has_image_input: + cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids( + cond_model_input.shape[0], + cond_model_input.shape[2] // 2, + cond_model_input.shape[3] // 2, + accelerator.device, + weight_dtype, + ) + cond_latents_ids[..., 0] = 1 + latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0) + # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz = model_input.shape[0] @@ -1834,7 +1935,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - packed_noisy_model_input = FluxKontextPipeline._pack_latents( noisy_model_input, batch_size=model_input.shape[0], @@ -1842,13 +1942,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): height=model_input.shape[2], width=model_input.shape[3], ) + if has_image_input: + packed_cond_input = FluxKontextPipeline._pack_latents( + cond_model_input, + batch_size=cond_model_input.shape[0], + num_channels_latents=cond_model_input.shape[1], + height=cond_model_input.shape[2], + width=cond_model_input.shape[3], + ) + packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1) - # handle guidance - if unwrap_model(transformer).config.guidance_embeds: - guidance = torch.tensor([args.guidance_scale], device=accelerator.device) - guidance = guidance.expand(model_input.shape[0]) - else: - guidance = None + # Kontext always has guidance + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) # Predict the noise residual model_pred = transformer( @@ -1970,6 +2076,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): torch_dtype=weight_dtype, ) pipeline_args = {"prompt": args.validation_prompt} + if has_image_input and args.validation_image: + pipeline_args.update({"image": load_image(args.validation_image)}) images = log_validation( pipeline=pipeline, args=args, @@ -2030,6 +2138,8 @@ 
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): images = [] if args.validation_prompt and args.num_validation_images > 0: pipeline_args = {"prompt": args.validation_prompt} + if has_image_input and args.validation_image: + pipeline_args.update({"image": load_image(args.validation_image)}) images = log_validation( pipeline=pipeline, args=args, From dd893b9413b53b37b17634c2dfbc2e3e024a49b7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 4 Jul 2025 09:48:03 +0530 Subject: [PATCH 02/10] readme --- examples/dreambooth/README_flux.md | 35 +++++++++++++++++++ .../train_dreambooth_lora_flux_kontext.py | 32 +++++++++++++---- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 24c71d5c569d..3bad4ff87d1c 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -294,6 +294,41 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \ Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not perform as expected. +Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets: + +* Condition image +* Target image +* Instruction + +[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training: + +```bash +accelerate launch train_dreambooth_lora_flux_kontext.py \ + --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \ + --output_dir="kontext-i2i" \ + --dataset_name="kontext-community/relighting" \ + --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --optimizer="adamw" \ + --use_8bit_adam \ + --cache_latents \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --seed="0" +``` + +More generally, when performing I2I fine-tuning, we expect you to: + +* Have a dataset `kontext-community/relighting` +* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training + ### Misc notes * By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it. diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 58a48fb19112..9fc3b00fca3f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -335,7 +335,6 @@ def parse_args(input_args=None): "--instance_prompt", type=str, default=None, - required=True, help="The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'", ) parser.add_argument( @@ -740,6 +739,7 @@ def parse_args(input_args=None): assert args.image_column is not None assert args.caption_column is not None assert args.dataset_name is not None + assert not args.train_text_encoder return args @@ -870,7 +870,7 @@ def __init__( random_flip=args.random_flip, ) self.pixel_values.append((image, bucket_idx)) - if dest_image: + if dest_image is not None: self.cond_pixel_values.append((dest_image, bucket_idx)) self.num_instance_images = len(self.instance_images) @@ -965,7 +965,7 @@ def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop= if dest_image is not None: dest_image = normalize(to_tensor(dest_image)) - return (image, dest_image) if dest_image is not None else image + return (image, dest_image) if dest_image is not None else (image, None) def collate_fn(examples, with_prior_preservation=False): @@ -1606,6 +1606,8 @@ def load_model_hook(models, input_dir): center_crop=args.center_crop, args=args, ) + if args.cond_image_column is not None: + logger.info("I2I fine-tuning enabled.") batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False) train_dataloader = torch.utils.data.DataLoader( train_dataset, @@ -1645,6 +1647,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + text_encoder_one.cpu(), text_encoder_two.cpu() del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two free_memory() @@ -1676,6 +1679,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + elif train_dataset.custom_instance_prompts and not args.train_text_encoder: + cached_text_embeddings = [] + for batch in tqdm(train_dataloader, desc="Embedding prompts"): + batch_prompts = batch["prompts"] + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + batch_prompts, text_encoders, tokenizers + ) + cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids)) + + if args.validation_prompt is None: + text_encoder_one.cpu(), text_encoder_two.cpu() + del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two + free_memory() + vae_config_shift_factor = vae.config.shift_factor vae_config_scaling_factor = vae.config.scaling_factor vae_config_block_out_channels = vae.config.block_out_channels @@ -1696,6 +1713,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist) if args.validation_prompt is None: + vae.cpu() del vae free_memory() @@ -1837,10 +1855,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: - # Should find a way to precompute these. 
- prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( - prompts, text_encoders, tokenizers - ) + prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step] else: tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) tokens_two = tokenize_prompt( @@ -1942,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): height=model_input.shape[2], width=model_input.shape[3], ) + orig_inp_shape = packed_noisy_model_input.shape if has_image_input: packed_cond_input = FluxKontextPipeline._pack_latents( cond_model_input, @@ -1968,6 +1984,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=latent_image_ids, return_dict=False, )[0] + if has_image_input: + model_pred = model_pred[:, : orig_inp_shape[1]] model_pred = FluxKontextPipeline._unpack_latents( model_pred, height=model_input.shape[2] * vae_scale_factor, From bce55a9d1f4858c78036156fc461f022613c3b3d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 5 Jul 2025 08:58:46 +0530 Subject: [PATCH 03/10] more checks. --- examples/dreambooth/train_dreambooth_lora_flux_kontext.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 9fc3b00fca3f..81ee51d9f439 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -740,6 +740,8 @@ def parse_args(input_args=None): assert args.caption_column is not None assert args.dataset_name is not None assert not args.train_text_encoder + if args.validation_prompt is not None: + assert args.validation_image is None and os.path.exists(args.validation_image) return args From 28ff59d608df95052f35867f5736f31d28e1681e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 7 Jul 2025 19:15:47 +0530 Subject: [PATCH 04/10] Apply suggestions from code review Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- examples/dreambooth/README_flux.md | 5 +++-- examples/dreambooth/train_dreambooth_lora_flux_kontext.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 3bad4ff87d1c..85ac216bd680 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -319,8 +319,9 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \ --cache_latents \ --learning_rate=1e-4 \ --lr_scheduler="constant" \ - --lr_warmup_steps=0 \ - --max_train_steps=500 \ + --lr_warmup_steps=200 \ + --max_train_steps=1000 \ + --rank=16\ --seed="0" ``` diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 81ee51d9f439..5a6789580ea0 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -795,7 +795,8 @@ def __init__( column_names = dataset["train"].column_names # 6. Get the column names for input/target. - # TODO: add validation for `cond_image_column` +if args.cond_image_column is not None and not args.cond_image_column in column_names: + raise ValueError(f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}") if args.image_column is None: image_column = column_names[0] logger.info(f"image column defaulting to {image_column}") From ba8347dcf50ab426ce0171b1b54bda5f4bb8c0a9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 7 Jul 2025 19:21:10 +0530 Subject: [PATCH 05/10] fixes --- examples/dreambooth/train_dreambooth_lora_flux_kontext.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 5a6789580ea0..b7930684fb03 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -795,8 +795,10 @@ def __init__( column_names = dataset["train"].column_names # 6. Get the column names for input/target. -if args.cond_image_column is not None and not args.cond_image_column in column_names: - raise ValueError(f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}") + if args.cond_image_column is not None and args.cond_image_column not in column_names: + raise ValueError( + f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) if args.image_column is None: image_column = column_names[0] logger.info(f"image column defaulting to {image_column}") From 73e9d054b23d7f2b23ec47a3c779a33e7dd8c39f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 7 Jul 2025 19:53:02 +0530 Subject: [PATCH 06/10] fix --- examples/dreambooth/train_dreambooth_lora_flux_kontext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index b7930684fb03..af9960c4c08b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -312,7 +312,7 @@ def parse_args(input_args=None): parser.add_argument( "--cond_image_column", type=str, - default="image", + default=None, help="Column in the dataset containing the condition image. 
Must be specified when performing I2I fine-tuning", ) parser.add_argument( From 53bf41d2f1bf85815f25f787e8c6307f5cd8211a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Jul 2025 14:14:13 +0530 Subject: [PATCH 07/10] add proj_mlp to the mix --- examples/dreambooth/train_dreambooth_lora_flux_kontext.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index af9960c4c08b..95fe4b7ff5ae 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -1393,6 +1393,7 @@ def main(args): "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2", + "proj_mlp", ] # now we will add new LoRA weights the transformer layers From 9ac1bde87ef34af05b8940de5889f3fcd05956ac Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Tue, 8 Jul 2025 16:19:23 +0300 Subject: [PATCH 08/10] Update README_flux.md add note on installing from commit `05e7a854d0a5661f5b433f6dd5954c224b104f0b` --- examples/dreambooth/README_flux.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 85ac216bd680..18273746c283 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N ## Training Kontext [Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We -provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too. +provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply this script, too. -Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section. +**important** + +> [!NOTE] +> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below. +> To do this, execute the following steps in a new virtual environment: +> ``` +> git clone https://github.com/huggingface/diffusers +> cd diffusers +> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b +> pip install -e . +> ``` Below is an example training command: @@ -343,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a Since Flux Kontext finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 
🤗 ## Other notes -Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file +Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ From d2931e02ce124ff852e1f60613e3f85d41abdd64 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Jul 2025 19:37:20 +0530 Subject: [PATCH 09/10] fix --- examples/dreambooth/train_dreambooth_lora_flux_kontext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 95fe4b7ff5ae..e7b62c70a850 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -844,7 +844,7 @@ def __init__( self.cond_images = [] for i, img in enumerate(instance_images): self.instance_images.extend(itertools.repeat(img, repeats)) - if cond_images is not None: + if args.dataset_name is not None and cond_images is not None: self.cond_images.extend(itertools.repeat(cond_images[i], repeats)) self.pixel_values = [] From 2b4827ff345f4b5683d02e7fed7b7985fc1ca837 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Jul 2025 20:35:44 +0530 Subject: [PATCH 10/10] fix --- .../train_dreambooth_lora_flux_kontext.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index e7b62c70a850..5bd9b8684d42 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -1844,6 +1844,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigma = sigma.unsqueeze(-1) return sigma + has_guidance = unwrap_model(transformer).config.guidance_embeds for epoch in range(first_epoch, args.num_train_epochs): transformer.train() if args.train_text_encoder: @@ -1906,10 +1907,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) if args.vae_encode_mode == "sample": model_input = vae.encode(pixel_values).latent_dist.sample() - cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample() + if has_image_input: + cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample() else: model_input = vae.encode(pixel_values).latent_dist.mode() - cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() + if has_image_input: + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) if has_image_input: @@ -1975,8 +1978,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1) # Kontext always has guidance - guidance = torch.tensor([args.guidance_scale], device=accelerator.device) - guidance = guidance.expand(model_input.shape[0]) + guidance = None + if has_guidance: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) # Predict the noise residual model_pred = transformer(
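
The patch series above adds training support only; for reference, below is a minimal inference sketch (not taken from the patches) showing how a LoRA produced by `train_dreambooth_lora_flux_kontext.py` might be loaded for image-guided editing. The checkpoint path, condition-image path, and prompt are placeholders, and the sampler settings are illustrative rather than prescribed by the script.

```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Load the base Kontext model and attach the fine-tuned LoRA.
pipeline = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_lora_weights("kontext-i2i")  # assumed to match the --output_dir used during training

# Condition image plus an instruction-style prompt, mirroring the training triplets.
cond_image = load_image("path/to/condition.png")  # placeholder path
image = pipeline(
    image=cond_image,
    prompt="relight the scene with warm sunset light",  # illustrative instruction
    guidance_scale=2.5,
    num_inference_steps=28,
).images[0]
image.save("edited.png")
```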