dFlow is a PyTorch library for diffusion models with dispersive loss regularization. This repo contains the official PyTorch implementation of Dispersive Loss:

**Diffuse and Disperse: Image Generation with Representation Regularization**
Runqian Wang, Kaiming He
MIT
We propose Dispersive Loss, a simple plug-and-play regularizer that effectively improves diffusion-based generative models. Our loss function encourages internal representations to disperse in the hidden space, analogous to contrastive self-supervised learning, with the key distinction that it requires no positive sample pairs and therefore does not interfere with the sampling process used for regression.
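For reference, the InfoNCE-L2 variant implemented in this repo (see the core snippet below) optimizes, by our reading of the code, the following objective over a batch of $B$ flattened intermediate representations $z_1, \dots, z_B$:

```math
\mathcal{L}_{\mathrm{disp}} = \log \frac{1}{B^2} \sum_{i=1}^{B} \sum_{j=1}^{B} \exp\!\left( -\frac{\lVert z_i - z_j \rVert_2^2}{\tau D} \right)
```

where $D$ is the representation dimension (the code normalizes squared distances by `z.shape[1]`) and $\tau$ is a temperature (1.0 by default). Minimizing this pushes representations apart, with no positive pairs involved.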
Install dFlow directly from the repository:

```bash
git clone https://github.com/raywang4/DispLoss.git
cd DispLoss
pip install -e .
```

Or install with development dependencies:

```bash
pip install -e ".[dev]"
```
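A quick smoke test (a minimal check that only confirms the package imports):

```bash
python -c "import dflow; print('ok')"
```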
Before training, make sure you run:

```bash
export NCCL_P2P_DISABLE=1
```
```python
import dflow
from dflow.models import SiT
from dflow.configs import get_model_config, get_training_config, get_dispersive_loss_config
from dflow.training import train_with_dispersive_loss
from dflow.sampling import sample_ode

# Create a model
config = get_model_config('SiT-XL/2')
model = SiT(**config)

# Train with dispersive loss (see the full training example below)
train_config = get_training_config('imagenet_256')
loss_config = get_dispersive_loss_config('default')

# Sample images
images = sample_ode(model, shape=(4, 3, 256, 256), num_steps=100)
```
```
dFlow/
├── dflow/                  # Main library package
│   ├── __init__.py         # Library exports
│   ├── models.py           # SiT model implementations
│   ├── transport.py        # Transport-based diffusion methods
│   ├── training.py         # Training utilities and dispersive loss
│   ├── sampling.py         # ODE/SDE sampling methods
│   └── configs.py          # Configuration management
├── configs/                # Configuration files
│   └── models/             # Model configurations
│       └── sit_configs.py  # SiT model variants
├── transport/              # Transport module (original)
├── setup.py                # Installation script
├── pyproject.toml          # Modern Python packaging
├── requirements.txt        # Dependencies
└── README.md               # This file
```
The core implementation of Dispersive Loss is highlighted below:
```python
def disp_loss(self, z):  # Dispersive Loss implementation (InfoNCE-L2 variant)
    z = z.reshape((z.shape[0], -1))  # flatten to (B, D)
    diff = th.nn.functional.pdist(z).pow(2) / z.shape[1]  # squared pairwise distances, normalized by D
    diff = th.concat((diff, diff, th.zeros(z.shape[0]).cuda()))  # match JAX implementation of full BxB matrix
    return th.log(th.exp(-diff).mean())  # calculate loss
```
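The `th.concat` line is worth a note: `pdist` returns only the B(B-1)/2 upper-triangular distances, so duplicating them and appending B zeros reproduces all B² entries of the full pairwise matrix (both triangles plus the zero diagonal), matching what the JAX reference averages over. A self-contained check (illustrative only; runs on CPU, so the `.cuda()` call is dropped):

```python
import torch as th

z = th.randn(8, 16)  # toy batch: B=8, D=16

# Brute force: full B x B matrix of normalized squared distances.
full = th.cdist(z, z).pow(2) / z.shape[1]
ref = th.log(th.exp(-full).mean())

# pdist trick from the snippet above: both triangles + zero diagonal.
pd = th.nn.functional.pdist(z).pow(2) / z.shape[1]
pd = th.concat((pd, pd, th.zeros(z.shape[0])))
out = th.log(th.exp(-pd).mean())

assert th.allclose(ref, out, atol=1e-5)
```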
To train with Dispersive Loss, simply add the `--disp` argument to the training script:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --disp
```
```python
from dflow.models import SiT
from dflow.training import train_with_dispersive_loss, DispersiveLoss
from dflow.configs import get_model_config, get_training_config

# Get configurations
model_config = get_model_config('SiT-XL/2')
train_config = get_training_config('imagenet_256')

# Create model and loss
model = SiT(**model_config)
dispersive_loss = DispersiveLoss(lambda_disp=0.25, tau=1.0)

# Training loop
for epoch in range(train_config['num_epochs']):
    avg_loss = train_with_dispersive_loss(model, dataloader, optimizer, device, train_config)
    print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")
```
Logging. To enable wandb, first set `WANDB_KEY`, `ENTITY`, and `PROJECT` as environment variables:

```bash
export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"
```
Then add the `--wandb` flag to the training command:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --disp --wandb
```
Resume training. To resume training from a custom checkpoint:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --disp --ckpt /path/to/model.pt
```
Pre-trained checkpoints. We provide a SiT-B/2 checkpoint and a SiT-XL/2 checkpoint both trained with Dispersive Loss for 80 epochs on ImageNet 256x256.
Sampling from checkpoint. To sample from the EMA weights of a 256x256 SiT-XL/2 model checkpoint with the ODE sampler, run:

```bash
python sample.py ODE --model SiT-XL/2 --image-size 256 --ckpt /path/to/model.pt
```
```python
from dflow.sampling import sample_ode, sample_sde

# ODE sampling
images = sample_ode(
    model,
    shape=(4, 3, 256, 256),
    num_steps=100,
    cfg_scale=1.5
)

# SDE sampling
images = sample_sde(
    model,
    shape=(4, 3, 256, 256),
    num_steps=100,
    cfg_scale=1.5
)
```
More sampling options. For more sampling options such as SDE sampling, please refer to `train_utils.py`.
The `sample_ddp.py` script samples a large number of images from a pre-trained model in parallel. It generates a folder of samples as well as a `.npz` file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score, and other metrics. To sample 50K images from a pre-trained SiT-XL/2 model over N GPUs under the default ODE sampler settings, run:

```bash
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000 --ckpt /path/to/model.pt
```
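Before running the evaluation suite, the `.npz` can be sanity-checked directly. The array key `arr_0` follows ADM's evaluator convention; treat the key and the file name below as assumptions about this script's output:

```python
import numpy as np

# Hypothetical path: sample_ddp.py reports the actual file it writes.
samples = np.load("samples.npz")["arr_0"]
print(samples.shape, samples.dtype)  # expected: (50000, 256, 256, 3) uint8 for 256x256 runs
```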
dFlow provides flexible configuration management:

```python
from dflow.configs import get_model_config, get_training_config, get_dispersive_loss_config

# Model configurations
model_config = get_model_config('SiT-XL/2')  # Available: SiT-S/2, SiT-B/2, SiT-L/2, SiT-XL/2, etc.

# Training configurations
train_config = get_training_config('imagenet_256')

# Dispersive loss configurations
loss_config = get_dispersive_loss_config('default')  # or 'jax_reproduction'
```
Our original implementation is in JAX, and this repo contains our re-implementation in PyTorch. Results from running this repo may therefore have minor numerical differences from those reported in our paper. In our JAX experiments, we used 16 devices with a local batch size of 16, whereas in our PyTorch experiments we used 8 devices with a local batch size of 32 (a global batch size of 256 in both cases). We have adjusted the hyperparameter choices slightly for best performance. We report our reproduction results below.
| implementation | config | local batch | B/2 80 ep | XL/2 80 ep (cfg=1.5) |
|---|---|---|---|---|
| baseline | - | 16 | 36.49 | 6.02 |
| JAX | | 16 | 32.35 | 5.09 |
| PyTorch | | 32 | 32.64 | 4.74 |
To set up a development environment:

```bash
git clone https://github.com/raywang4/DispLoss.git
cd DispLoss
pip install -e ".[dev]"  # install with development dependencies
pre-commit install       # set up pre-commit hooks

pytest tests/            # run the test suite
black dflow/             # format code
flake8 dflow/            # lint
```
This project is under the MIT license. See LICENSE for details.
Contributions are welcome! Please feel free to submit a Pull Request.
If you use this code in your research, please cite our paper:
```bibtex
@article{wang2025diffuse,
  title={Diffuse and Disperse: Image Generation with Representation Regularization},
  author={Wang, Runqian and He, Kaiming},
  journal={arXiv preprint arXiv:2506.09027},
  year={2025}
}
```