dFlow: Dispersive Flow

dFlow is a PyTorch library for diffusion models with dispersive loss regularization. This repo contains the official PyTorch implementation of Dispersive Loss.

Diffuse and Disperse: Image Generation with Representation Regularization
Runqian Wang, Kaiming He
MIT

We propose Dispersive Loss, a simple plug-and-play regularizer that effectively improves diffusion-based generative models. Our loss function encourages internal representations to disperse in the hidden space, analogous to contrastive self-supervised learning, with the key distinction that it requires no positive sample pairs and therefore does not interfere with the sampling process used for regression.

🚀 Quick Start

Installation

Install dFlow directly from the repository:

git clone https://github.com/raywang4/DispLoss.git
cd DispLoss
pip install -e .

Or install with development dependencies:

pip install -e ".[dev]"

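Important: make sure to run the following before launching any of the scripts below. It disables NCCL peer-to-peer communication between GPUs: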

export NCCL_P2P_DISABLE=1

Basic Usage

import dflow
from dflow.models import SiT
from dflow.configs import (
    get_model_config,
    get_training_config,
    get_dispersive_loss_config,
)
from dflow.sampling import sample_ode

# Create a model
config = get_model_config('SiT-XL/2')
model = SiT(**config)

# Load training and dispersive-loss configs
# (the training loop itself is shown under Programmatic Training)
train_config = get_training_config('imagenet_256')
loss_config = get_dispersive_loss_config('default')

# Sample images
images = sample_ode(model, shape=(4, 3, 256, 256), num_steps=100)

📚 Library Structure

dFlow/
├── dflow/                    # Main library package
│   ├── __init__.py           # Library exports
│   ├── models.py             # SiT model implementations
│   ├── transport.py          # Transport-based diffusion methods
│   ├── training.py           # Training utilities and dispersive loss
│   ├── sampling.py           # ODE/SDE sampling methods
│   └── configs.py            # Configuration management
├── configs/                  # Configuration files
│   └── models/               # Model configurations
│       └── sit_configs.py    # SiT model variants
├── transport/                # Transport module (original)
├── setup.py                  # Installation script
├── pyproject.toml            # Modern Python packaging
├── requirements.txt          # Dependencies
└── README.md                 # This file

🔧 Core Implementation

The core implementation of Dispersive Loss is highlighted below:

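import torch as th

def disp_loss(self, z):  # Dispersive Loss implementation (InfoNCE-L2 variant)
    z = z.reshape((z.shape[0], -1))  # flatten each sample to a vector
    # squared L2 distance for every unique pair, normalized by feature dimension
    diff = th.nn.functional.pdist(z).pow(2) / z.shape[1]
    # duplicate the B(B-1)/2 pairs and append B zeros (the diagonal) to match the JAX implementation of the full BxB matrix;
    # new_zeros keeps the device and dtype of z instead of hard-coding .cuda()
    diff = th.concat((diff, diff, diff.new_zeros(z.shape[0])))
    return th.log(th.exp(-diff).mean())  # log of the mean of exp(-distance)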
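The duplicate-and-pad step in the snippet above can be checked without a GPU or PyTorch. The following dependency-free sketch computes the same quantity twice: once from the B(B-1)/2 unique pairs (duplicated, plus B zeros for the diagonal) and once from the full BxB matrix. The function names are illustrative and are not part of dFlow.

```python
import math
from itertools import combinations


def sq_dist(a, b):
    """Squared L2 distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def disp_loss_condensed(z):
    """Mirror of the snippet above: unique pairs, duplicated, plus B diagonal zeros."""
    dim = len(z[0])
    upper = [sq_dist(a, b) / dim for a, b in combinations(z, 2)]
    entries = upper + upper + [0.0] * len(z)
    return math.log(sum(math.exp(-d) for d in entries) / len(entries))


def disp_loss_full(z):
    """Same loss computed over every ordered pair (i, j), including i == j."""
    dim = len(z[0])
    entries = [sq_dist(a, b) / dim for a in z for b in z]
    return math.log(sum(math.exp(-d) for d in entries) / len(entries))


# Three 2-D "representations"
z = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
print(disp_loss_condensed(z), disp_loss_full(z))  # the two values agree
```

Because every distance is non-negative, the loss is at most 0 (reached when all representations coincide), and it decreases as the representations spread apart.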

πŸƒβ€β™‚οΈ Training

Command Line Training

To train with Dispersive Loss, simply add the --disp argument to the training script:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --disp

Programmatic Training

from dflow.models import SiT
from dflow.training import train_with_dispersive_loss, DispersiveLoss
from dflow.configs import get_model_config, get_training_config

# Get configurations
model_config = get_model_config('SiT-XL/2')
train_config = get_training_config('imagenet_256')

# Create model and loss
model = SiT(**model_config)
dispersive_loss = DispersiveLoss(lambda_disp=0.25, tau=1.0)

# Training loop
# `dataloader`, `optimizer`, and `device` are assumed to be set up by the caller
for epoch in range(train_config['num_epochs']):
    avg_loss = train_with_dispersive_loss(model, dataloader, optimizer, device, train_config)
    print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")
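The `lambda_disp` argument above suggests that the regularizer is added to the base diffusion objective with a scalar weight. A minimal sketch of that combination, assuming the total loss has the form base + lambda_disp * dispersive (the helper name `combine_losses` is hypothetical and not a dFlow function):

```python
def combine_losses(base_loss, disp_loss, lambda_disp=0.25):
    """Weighted sum of the base objective and the dispersive regularizer (assumed form)."""
    return base_loss + lambda_disp * disp_loss


# Example with made-up scalar values. The dispersive term is <= 0,
# so it lowers the total as representations spread out.
total = combine_losses(base_loss=0.80, disp_loss=-0.40, lambda_disp=0.25)
print(f"{total:.2f}")
```

The helper only uses `+` and `*`, so the same expression would also work on framework tensors.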

Logging. To enable wandb, first set WANDB_KEY, ENTITY, and PROJECT as environment variables:

export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"

Then add the --wandb flag to the training command:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --disp --wandb
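A script that relies on these variables can fail early with a clear message if one of them is missing. The sketch below shows one way to collect and validate them using only the standard library; the helper name `read_wandb_env` is hypothetical, and this is not necessarily how train.py reads them.

```python
import os

REQUIRED_VARS = ("WANDB_KEY", "ENTITY", "PROJECT")


def read_wandb_env(environ=None):
    """Return the three logging settings, or raise if any is missing or empty."""
    environ = os.environ if environ is None else environ
    missing = [name for name in REQUIRED_VARS if not environ.get(name)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
    return {name: environ[name] for name in REQUIRED_VARS}


# Example with an explicit mapping instead of the real environment
settings = read_wandb_env(
    {"WANDB_KEY": "key", "ENTITY": "entity name", "PROJECT": "project name"}
)
print(settings["PROJECT"])
```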

Resume training. To resume training from a custom checkpoint, pass its path with --ckpt:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --disp --ckpt /path/to/model.pt

🎨 Sampling

Command Line Sampling

Pre-trained checkpoints. We provide a SiT-B/2 checkpoint and a SiT-XL/2 checkpoint both trained with Dispersive Loss for 80 epochs on ImageNet 256x256.

Sampling from checkpoint. To sample from the EMA weights of a 256x256 SiT-XL/2 model checkpoint with ODE sampler, run:

python sample.py ODE --model SiT-XL/2 --image-size 256 --ckpt /path/to/model.pt

Programmatic Sampling

from dflow.sampling import sample_ode, sample_sde

# ODE sampling
images = sample_ode(
    model, 
    shape=(4, 3, 256, 256), 
    num_steps=100, 
    cfg_scale=1.5
)

# SDE sampling
images = sample_sde(
    model,
    shape=(4, 3, 256, 256),
    num_steps=100,
    cfg_scale=1.5
)
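To give a sense of what `num_steps` and `cfg_scale` control, here is a toy, dependency-free sketch. It assumes a fixed-step Euler integrator and the standard classifier-free guidance combination; the function names and the toy velocity field are illustrative, and dFlow's actual samplers may use a different solver.

```python
def guided_velocity(v_cond, v_uncond, cfg_scale):
    """Classifier-free guidance: extrapolate from the unconditional
    prediction toward the conditional one by a factor of cfg_scale."""
    return [u + cfg_scale * (c - u) for c, u in zip(v_cond, v_uncond)]


def euler_ode_sample(x0, velocity_fn, num_steps):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with fixed-step Euler."""
    x = list(x0)
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        v = velocity_fn(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Toy example: a constant velocity field that carries x0 to `target` in unit time.
x0 = [0.0, 0.0]
target = [1.0, -2.0]


def toy_velocity(x, t):
    v_cond = [tg - s for tg, s in zip(target, x0)]  # stand-in "conditional" prediction
    v_uncond = [0.0 for _ in x0]                    # stand-in "unconditional" prediction
    return guided_velocity(v_cond, v_uncond, cfg_scale=1.0)


result = euler_ode_sample(x0, toy_velocity, num_steps=100)
print(result)  # close to `target`, up to floating-point rounding
```

With `cfg_scale=1.0` the guided velocity equals the conditional prediction; values above 1 push further away from the unconditional prediction.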

More sampling options. For more sampling options such as SDE sampling, please refer to train_utils.py.

📊 Evaluation

The sample_ddp.py script samples a large number of images from a pre-trained model in parallel. This script generates a folder of samples as well as a .npz file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score and other metrics. To sample 50K images from a pre-trained SiT-XL/2 model over N GPUs under default ODE sampler settings, run:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000 --ckpt /path/to/model.pt
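When sampling is split across GPUs, the requested sample count generally has to be rounded up so that every process runs the same number of full batches. The sketch below illustrates that arithmetic; the helper name and the per-GPU batch size of 32 are assumptions for illustration, not values read from sample_ddp.py.

```python
import math


def plan_ddp_sampling(num_fid_samples, world_size, per_gpu_batch_size):
    """Round the sample count up to a multiple of the global batch size
    and split the work evenly across processes."""
    global_batch = world_size * per_gpu_batch_size
    total = math.ceil(num_fid_samples / global_batch) * global_batch
    per_gpu = total // world_size
    return {
        "total_samples": total,
        "samples_per_gpu": per_gpu,
        "iterations": per_gpu // per_gpu_batch_size,
    }


plan = plan_ddp_sampling(num_fid_samples=50000, world_size=8, per_gpu_batch_size=32)
print(plan)
```

In this example the plan generates 50,176 samples, 176 more than requested; a script would typically drop the surplus before writing the .npz file.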

βš™οΈ Configuration

dFlow provides flexible configuration management:

from dflow.configs import get_model_config, get_training_config, get_dispersive_loss_config

# Model configurations
model_config = get_model_config('SiT-XL/2')  # Available: SiT-S/2, SiT-B/2, SiT-L/2, SiT-XL/2, etc.

# Training configurations  
train_config = get_training_config('imagenet_256')

# Dispersive loss configurations
loss_config = get_dispersive_loss_config('default')  # or 'jax_reproduction'
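One common way to back named-configuration helpers like these is a plain dictionary lookup. The sketch below is hypothetical: the registry, the function name, and the association of preset names with values are assumptions for illustration (the numbers are the PyTorch and JAX settings from the reproduction table in this README), not dFlow's actual `configs.py`.

```python
# Hypothetical preset registry; the name-to-value mapping is an assumption.
_DISPERSIVE_LOSS_PRESETS = {
    "default": {"lambda_disp": 0.25, "tau": 1.0},
    "jax_reproduction": {"lambda_disp": 0.5, "tau": 0.5},
}


def lookup_dispersive_loss_preset(name):
    """Return a copy of the named preset so callers can modify it safely."""
    try:
        return dict(_DISPERSIVE_LOSS_PRESETS[name])
    except KeyError:
        available = ", ".join(sorted(_DISPERSIVE_LOSS_PRESETS))
        raise KeyError(f"Unknown preset {name!r}; available: {available}") from None


preset = lookup_dispersive_loss_preset("default")
preset["tau"] = 0.5  # local override; the registry itself is unchanged
print(preset, lookup_dispersive_loss_preset("default"))
```

Returning a copy keeps per-run overrides from leaking into later lookups.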

🔬 Differences from JAX

Our original implementation is in JAX, and this repo contains our re-implementation in PyTorch. Therefore, results from running this repo may have minor numerical differences with those reported in our paper. In our JAX experiments, we used 16 devices with local batch size 16, whereas in PyTorch experiments we used 8 devices with local batch size 32. We have adjusted the hyperparameter choices slightly for best performance. We report our reproduction results below.

| implementation | config | local bz | B/2 80 ep | XL/2 80 ep (cfg=1.5) |
|---|---|---|---|---|
| baseline | - | 16 | 36.49 | 6.02 |
| JAX | $\lambda$=0.5, $\tau$=0.5, depth=num_layers//4 | 16 | 32.35 | 5.09 |
| PyTorch | $\lambda$=0.25, $\tau$=1, depth=num_layers | 32 | 32.64 | 4.74 |
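The setup described above can be checked with simple arithmetic: both configurations use the same global batch size and differ in the local batch size and in the `depth` setting. The sketch below evaluates both; the helper name is illustrative, and the block count of 28 for SiT-XL/2 is an assumption based on the public SiT/DiT model definitions.

```python
def global_batch_size(num_devices, local_batch_size):
    """Total number of samples per optimization step across all devices."""
    return num_devices * local_batch_size


jax_global = global_batch_size(num_devices=16, local_batch_size=16)
pytorch_global = global_batch_size(num_devices=8, local_batch_size=32)
print(jax_global, pytorch_global)

num_layers = 28                # assumed block count for SiT-XL/2
jax_depth = num_layers // 4    # JAX row: depth=num_layers//4
pytorch_depth = num_layers     # PyTorch row: depth=num_layers
print(jax_depth, pytorch_depth)
```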

📦 Development

Setup Development Environment

git clone https://github.com/raywang4/DispLoss.git
cd DispLoss
pip install -e ".[dev]"
pre-commit install

Running Tests

pytest tests/

Code Formatting

black dflow/
flake8 dflow/

📄 License

This project is under the MIT license. See LICENSE for details.

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

📖 Citation

If you use this code in your research, please cite our paper:

@article{wang2025diffuse,
  title={Diffuse and Disperse: Image Generation with Representation Regularization},
  author={Wang, Runqian and He, Kaiming},
  journal={arXiv preprint arXiv:2506.09027},
  year={2025}
}
