
Commit ce9add3

fix/optimization (#8)
* Fix the optimization problems caused by hooking and unhooking in a more durable way
* Add docstring to the REPA loss
* Use the loss name when logging on wandb
1 parent 0b68956 commit ce9add3

5 files changed: +146 -120 lines

src/diffulab/diffuse/modelizations/flow.py

Lines changed: 1 addition & 1 deletion

@@ -247,7 +247,7 @@ def compute_loss(
         # Compute extra losses if any
         for extra_loss in extra_losses:
             e_loss = cast(Tensor, extra_loss(**extra_args))
-            loss_dict[extra_loss.__class__.__name__] = e_loss
+            loss_dict[extra_loss.name] = e_loss
         return loss_dict
 
     def add_noise(self, x: Tensor, timesteps: Tensor, noise: Tensor | None = None) -> tuple[Tensor, Tensor]:

src/diffulab/diffuse/modelizations/gaussian_diffusion.py

Lines changed: 1 addition & 1 deletion

@@ -672,7 +672,7 @@ def compute_loss(
         loss_dict = {"loss": loss}
         for extra_loss in extra_losses:
            e_loss = cast(Tensor, extra_loss(**extra_args))
-            loss_dict[extra_loss.__class__.__name__] = e_loss
+            loss_dict[extra_loss.name] = e_loss
         return loss_dict
 
     def add_noise(self, x: Tensor, timesteps: Tensor, noise: Tensor | None = None) -> tuple[Tensor, Tensor]:
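Both backends now key extra losses by the loss's declared `name` rather than its Python class name. A minimal sketch of why this gives stable wandb keys; the `WrappedRepaLoss` subclass and the final print are illustrative assumptions, not repository code:

import torch
from torch import Tensor, nn


class LossFunction(nn.Module):
    # New base-class contract: every extra loss carries a stable logging key.
    name: str = "extra_loss"


class RepaLoss(LossFunction):
    name: str = "RepaLoss"

    def forward(self) -> Tensor:
        return torch.tensor(0.2)


class WrappedRepaLoss(RepaLoss):
    # Hypothetical wrapper: __class__.__name__ changes, `name` does not.
    pass


loss_dict: dict[str, Tensor] = {"loss": torch.tensor(1.0)}
for extra_loss in (WrappedRepaLoss(),):
    # Old key: extra_loss.__class__.__name__ -> "WrappedRepaLoss"
    # New key: extra_loss.name               -> "RepaLoss"
    loss_dict[extra_loss.name] = extra_loss()

# The trainer would then log something like:
# wandb.log({k: v.item() for k, v in loss_dict.items()})
print(sorted(loss_dict))  # ['RepaLoss', 'loss']: the wandb panel names stay stable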

src/diffulab/training/losses/common.py

Lines changed: 3 additions & 30 deletions

@@ -1,45 +1,18 @@
 from abc import ABC
-from pathlib import Path
 from typing import TYPE_CHECKING
 
 import torch.nn as nn
 
 if TYPE_CHECKING:
-    from accelerate import Accelerator  # type: ignore
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
-    from torch.nn.parallel import DistributedDataParallel
-
     from diffulab.networks.denoisers import Denoiser
 
 
-class LossFunction(ABC, nn.Module):  # to be completed
+class LossFunction(ABC, nn.Module):
+    name: str = "extra_loss"
+
     def __init__(self) -> None:
         super().__init__()  # type: ignore
 
-    def save(self, path: str | Path, accelerator: "Accelerator") -> None:
-        """
-        Save eventual learnable parameters of the loss function.
-
-        Args:
-            path (str | Path): Path to save the loss function.
-            accelerator (Accelerator | None): Accelerator instance for distributed training. Uses
-                accelerator.save if provided
-        """
-        # By default, doesn't save anything.
-        pass
-
-    def accelerate_prepare(
-        self, accelerator: "Accelerator"
-    ) -> "list[nn.Module | DistributedDataParallel | FullyShardedDataParallel]":
-        """
-        Prepare the loss function for distributed training.
-
-        Args:
-            accelerator (Accelerator): Accelerator instance for distributed training.
-        """
-        # By default, doesn't prepare anything.
-        return []
-
     def set_model(self, model: "Denoiser") -> None:
         """
         Set the model for the loss function.
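With `save` and `accelerate_prepare` dropped from the base class, an extra loss now only declares a `name`, owns its learnable modules, and implements `forward`. A hypothetical subclass against this slimmer contract; the `FeatureMatchingLoss` class and its internals are invented for illustration:

import torch
from torch import Tensor, nn

from diffulab.training.losses.common import LossFunction


class FeatureMatchingLoss(LossFunction):
    # Hypothetical extra loss: its key in loss_dict (and on wandb) is "FeatureMatching".
    name: str = "FeatureMatching"

    def __init__(self, dim: int, coeff: float = 0.5) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # learnable parameters live on the loss itself
        self.coeff = coeff

    def forward(self, src: Tensor, dst: Tensor) -> Tensor:
        # Simple L2 match between projected source features and target features.
        return self.coeff * torch.mean((self.proj(src) - dst) ** 2)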

src/diffulab/training/losses/repa.py

Lines changed: 112 additions & 64 deletions

@@ -1,20 +1,23 @@
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypedDict, cast
+from typing import Any, Callable, TypedDict
+from weakref import WeakKeyDictionary
 
 import torch
-from accelerate import Accelerator  # type: ignore
 from jaxtyping import Float
 from torch import Tensor, nn
+from torch.utils.hooks import RemovableHandle
 
 from diffulab.networks.denoisers.mmdit import MMDiT
 from diffulab.networks.repa.common import REPA
 from diffulab.networks.repa.dinov2 import DinoV2
 from diffulab.networks.repa.perceiver_resampler import PerceiverResampler
 from diffulab.training.losses.common import LossFunction
 
-if TYPE_CHECKING:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
-    from torch.nn.parallel import DistributedDataParallel
+try:
+    from torch._dynamo import disable as _dynamo_disable  # type: ignore
+except Exception:
+
+    def _dynamo_disable(fn: Any) -> Any:
+        return fn
 
 
 class ResamplerParams(TypedDict):
@@ -27,7 +30,53 @@ class ResamplerParams(TypedDict):
 
 
 class RepaLoss(LossFunction):
+    """Representation Alignment (REPA) loss.
+
+    Aligns intermediate features from a denoiser (MMDiT) to features from an
+    external vision encoder (e.g., DINOv2) using a projection MLP and, optionally,
+    a Perceiver resampler. Denoiser features are captured via a forward hook on a
+    specified transformer block and compared to encoder features using cosine
+    similarity. The loss is averaged over the sequence dimension and scaled by
+    ``coeff``.
+
+    Typical usage:
+        loss_fn = RepaLoss(...)
+        loss_fn.set_model(denoiser)
+        # Run a forward pass through the denoiser to populate captured features
+        loss = loss_fn(x0=batch_images)  # or pass dst_features=...
+
+    Args:
+        repa_encoder: Key of the encoder to instantiate. Supported values are
+            keys of ``encoder_registry``, e.g. "dinov2".
+        encoder_args: Keyword arguments forwarded to the encoder constructor.
+        alignment_layer: 1-based index of the MMDiT layer from which to capture
+            features.
+        denoiser_dimension: Feature dimensionality of the denoiser at the
+            alignment layer.
+        hidden_dim: Hidden size of the projection MLP.
+        load_dino: Whether to instantiate and load the encoder. Set to ``False``
+            when precomputed ``dst_features`` will be supplied at call time.
+        embedding_dim: Target embedding dimensionality when the encoder is not
+            instantiated (i.e., when ``load_dino=False``).
+        use_resampler: Whether to apply a :class:`PerceiverResampler` after the
+            projection MLP.
+        resampler_params: Configuration for the :class:`PerceiverResampler`.
+            Required if ``use_resampler=True``.
+        coeff: Multiplicative weight applied to the returned loss value.
+
+    Attributes:
+        repa_encoder: The instantiated encoder or ``None`` if
+            ``load_dino=False``.
+        proj: Projection MLP mapping denoiser features to the encoder embedding
+            space.
+        resampler: Optional :class:`PerceiverResampler` applied after the
+            projection.
+        alignment_layer: 1-based index of the hooked MMDiT layer.
+        coeff: Multiplicative weight applied to the returned loss.
+    """
+
     encoder_registry: dict[str, type[REPA]] = {"dinov2": DinoV2}
+    name: str = "RepaLoss"
 
     def __init__(
         self,
@@ -69,79 +118,77 @@ def __init__(
                 **resampler_params,
             )
         self.alignment_layer = alignment_layer
-        self._hook_handle = None
-        self.src_features: Tensor | None = None
+        self._handles: "WeakKeyDictionary[nn.Module, RemovableHandle]" = WeakKeyDictionary()
+        self._features: "WeakKeyDictionary[nn.Module, torch.Tensor]" = WeakKeyDictionary()
+        self._active_model: nn.Module | None = None
+        self._hook_layer_idx = self.alignment_layer - 1  # as before
         self.coeff = coeff
 
-    def _register_hook(self, model: MMDiT) -> None:
-        """Register the forward hook on the specified layer of the model."""
-        self._unregister_hook()  # Ensure no previous hook is registered
-        self._hook_handle = model.layers[self.alignment_layer - 1].register_forward_hook(self._forward_hook)
-
-    def _unregister_hook(self) -> None:
-        """Remove the forward hook."""
-        if self._hook_handle is not None:
-            self._hook_handle.remove()
-            self._hook_handle = None
+    @_dynamo_disable
+    def _make_forward_hook(self, key_model: MMDiT) -> Callable[[nn.Module, tuple[Any, ...], torch.Tensor], None]:
+        def _hook(_mod: nn.Module, _inp: tuple[Any, ...], out: torch.Tensor):
+            self._features[key_model] = out
 
-    def set_model(self, model: MMDiT) -> None:  # type: ignore
-        """Switch the hook to a different model (e.g., EMA model)."""
-        self._register_hook(model)
+        return _hook
 
-    def _forward_hook(self, net: nn.Module, input: tuple[Any, ...], output: Tensor) -> None:
-        """
-        Hook to capture the output of the specified layer during the forward pass.
-        """
-        self.src_features = output
-
-    def save(self, path: str | Path, accelerator: Accelerator) -> None:
-        """
-        Save state dict containing projection (and resampler if present).
-
-        Args:
-            path (str | Path): Path to save the loss function.
-            accelerator (Accelerator | None): Accelerator instance for distributed training. Uses
-                accelerator.save if provided.
-        """
-        file_path = Path(path) / "RepaLoss.pt"
+    def _attach_once(self, model: MMDiT) -> None:
+        if model in self._handles:
+            return
+        layer = model.layers[self._hook_layer_idx]
+        handle = layer.register_forward_hook(self._make_forward_hook(model))  # type: ignore
+        self._handles[model] = handle
 
-        unwrapped_proj = cast(nn.Module, accelerator.unwrap_model(self.proj))  # type: ignore
-        merged_state = {}
-        for k, v in unwrapped_proj.state_dict().items():
-            merged_state[f"proj.{k}"] = v
-        if self.resampler is not None:
-            unwrapped_resampler = cast(nn.Module, accelerator.unwrap_model(self.resampler))  # type: ignore
-            for k, v in unwrapped_resampler.state_dict().items():
-                merged_state[f"resampler.{k}"] = v
+    def set_model(self, model: MMDiT) -> None:  # type: ignore
+        """Register the model to capture features from a specific layer.
 
-        accelerator.save(merged_state, file_path)  # type: ignore
-
-    def accelerate_prepare(
-        self, accelerator: Accelerator
-    ) -> "list[nn.Module | DistributedDataParallel | FullyShardedDataParallel]":
-        """
-        Prepare the loss function for distributed training.
+        This attaches a forward hook to the specified ``alignment_layer`` of the
+        provided model (only once). A forward pass on ``model`` must be executed
+        after calling this method so that features are captured before computing
+        the loss.
 
         Args:
-            accelerator (Accelerator): Accelerator instance for distributed training.
+            model (MMDiT): The model whose intermediate features will be
+                aligned to the encoder features.
         """
-        trainable_modules: "list[nn.Module | DistributedDataParallel | FullyShardedDataParallel]" = []
-        self.proj = accelerator.prepare_model(self.proj)  # type: ignore
-        trainable_modules.append(self.proj)  # type: ignore
-        if self.resampler is not None:
-            self.resampler = accelerator.prepare_model(self.resampler)  # type: ignore
-            trainable_modules.append(self.resampler)  # type: ignore
-        if self.repa_encoder is not None:
-            self.repa_encoder = accelerator.prepare_model(self.repa_encoder)  # type: ignore
+        self._attach_once(model)
+        self._active_model = model
 
-        return trainable_modules
+    def _unregister_all(self) -> None:
+        for h in list(self._handles.values()):
+            h.remove()
+        self._handles.clear()
+        self._features.clear()
+        self._active_model = None
 
     def forward(
         self,
         x0: Float[Tensor, "batch 3 H W"] | None = None,
         dst_features: Float[Tensor, "batch seq_len n_dim"] | None = None,
     ) -> Tensor:
-        assert self.src_features is not None, "Source features are not computed. Ensure the forward hook is registered."
+        """Compute the REPA cosine-similarity loss.
+
+        Either provide input images via ``x0`` to compute destination features
+        with the encoder, or pass precomputed ``dst_features`` directly.
+
+        Args:
+            x0 (Tensor): Input images of shape ``[B, 3, H, W]`` used to compute encoder
+                features when an encoder is available.
+            dst_features (Tensor): Precomputed encoder features of shape ``[B, S, D]``.
+                If provided, ``x0`` is ignored.
+
+        Returns:
+            Tensor: A scalar tensor containing the REPA loss.
+
+        Raises:
+            RuntimeError: If no captured features are available for the active
+                model. Ensure ``set_model(...)`` was called and a forward pass
+                on the model was executed first.
+            AssertionError: If neither ``x0`` nor ``dst_features`` is provided.
+        """
+        if self._active_model is None or self._active_model not in self._features:
+            raise RuntimeError(
+                "REPA: no captured features for the active model. Did you call set_model(...) and run a forward pass?"
+            )
         assert x0 is not None or dst_features is not None, "Either x0 or dst_features must be provided."
         if dst_features is None:
            assert self.repa_encoder is not None, "REPA encoder must be initialized to compute features."
@@ -151,7 +198,8 @@ def forward(
             )  # batch size seqlen embedding_dim # SEE HOW TO HANDLE THE PRE COMPUTING OF FEATURES
         assert dst_features is not None, "Destination features must be provided or computed."
 
-        projected_src_features: Tensor = self.proj(self.src_features)  # type: ignore
+        src_features = self._features[self._active_model]
+        projected_src_features: Tensor = self.proj(src_features)
 
         if self.resampler is not None:
             projected_src_features = self.resampler(projected_src_features)