-from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypedDict, cast
+from typing import Any, Callable, TypedDict
+from weakref import WeakKeyDictionary

 import torch
-from accelerate import Accelerator  # type: ignore
 from jaxtyping import Float
 from torch import Tensor, nn
+from torch.utils.hooks import RemovableHandle

 from diffulab.networks.denoisers.mmdit import MMDiT
 from diffulab.networks.repa.common import REPA
 from diffulab.networks.repa.dinov2 import DinoV2
 from diffulab.networks.repa.perceiver_resampler import PerceiverResampler
 from diffulab.training.losses.common import LossFunction

-if TYPE_CHECKING:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
-    from torch.nn.parallel import DistributedDataParallel
+try:
+    from torch._dynamo import disable as _dynamo_disable  # type: ignore
+except Exception:
+
+    def _dynamo_disable(fn: Any) -> Any:
+        return fn


 class ResamplerParams(TypedDict):
@@ -27,7 +30,53 @@ class ResamplerParams(TypedDict):


 class RepaLoss(LossFunction):
+    """Representation Alignment (REPA) loss.
+
+    Aligns intermediate features from a denoiser (MMDiT) to features from an
+    external vision encoder (e.g., DINOv2) using a projection MLP and, optionally,
+    a Perceiver resampler. Denoiser features are captured via a forward hook on a
+    specified transformer block and compared to encoder features using cosine
+    similarity. The loss is averaged over the sequence dimension and scaled by
+    ``coeff``.
+
+    Typical usage:
+        loss_fn = RepaLoss(...)
+        loss_fn.set_model(denoiser)
+        # Run a forward pass through the denoiser to populate captured features
+        loss = loss_fn(x0=batch_images)  # or pass dst_features=...
+
+    Args:
+        repa_encoder: Key of the encoder to instantiate. Supported values are
+            keys of ``encoder_registry``, e.g. "dinov2".
+        encoder_args: Keyword arguments forwarded to the encoder constructor.
+        alignment_layer: 1-based index of the MMDiT layer from which to capture
+            features.
+        denoiser_dimension: Feature dimensionality of the denoiser at the
+            alignment layer.
+        hidden_dim: Hidden size of the projection MLP.
+        load_dino: Whether to instantiate and load the encoder. Set to ``False``
+            when precomputed ``dst_features`` will be supplied at call time.
+        embedding_dim: Target embedding dimensionality when the encoder is not
+            instantiated (i.e., when ``load_dino=False``).
+        use_resampler: Whether to apply a :class:`PerceiverResampler` after the
+            projection MLP.
+        resampler_params: Configuration for the :class:`PerceiverResampler`.
+            Required if ``use_resampler=True``.
+        coeff: Multiplicative weight applied to the returned loss value.
+
+    Attributes:
+        repa_encoder: The instantiated encoder or ``None`` if
+            ``load_dino=False``.
+        proj: Projection MLP mapping denoiser features to the encoder embedding
+            space.
+        resampler: Optional :class:`PerceiverResampler` applied after the
+            projection.
+        alignment_layer: 1-based index of the hooked MMDiT layer.
+        coeff: Multiplicative weight applied to the returned loss.
+    """
+
     encoder_registry: dict[str, type[REPA]] = {"dinov2": DinoV2}
+    name: str = "RepaLoss"

     def __init__(
         self,
@@ -69,79 +118,77 @@ def __init__(
                 **resampler_params,
             )
         self.alignment_layer = alignment_layer
-        self._hook_handle = None
-        self.src_features: Tensor | None = None
+        self._handles: "WeakKeyDictionary[nn.Module, RemovableHandle]" = WeakKeyDictionary()
+        self._features: "WeakKeyDictionary[nn.Module, torch.Tensor]" = WeakKeyDictionary()
+        self._active_model: nn.Module | None = None
+        self._hook_layer_idx = self.alignment_layer - 1  # as before
         self.coeff = coeff

-    def _register_hook(self, model: MMDiT) -> None:
-        """Register the forward hook on the specified layer of the model."""
-        self._unregister_hook()  # Ensure no previous hook is registered
-        self._hook_handle = model.layers[self.alignment_layer - 1].register_forward_hook(self._forward_hook)
-
-    def _unregister_hook(self) -> None:
-        """Remove the forward hook."""
-        if self._hook_handle is not None:
-            self._hook_handle.remove()
-            self._hook_handle = None
+    @_dynamo_disable
+    def _make_forward_hook(self, key_model: MMDiT) -> Callable[[nn.Module, tuple[Any, ...], torch.Tensor], None]:
+        def _hook(_mod: nn.Module, _inp: tuple[Any, ...], out: torch.Tensor):
+            self._features[key_model] = out

-    def set_model(self, model: MMDiT) -> None:  # type: ignore
-        """Switch the hook to a different model (e.g., EMA model)."""
-        self._register_hook(model)
+        return _hook

-    def _forward_hook(self, net: nn.Module, input: tuple[Any, ...], output: Tensor) -> None:
-        """
-        Hook to capture the output of the specified layer during the forward pass.
-        """
-        self.src_features = output
-
-    def save(self, path: str | Path, accelerator: Accelerator) -> None:
-        """
-        Save state dict containing projection (and resampler if present).
-
-        Args:
-            path (str | Path): Path to save the loss function.
-            accelerator (Accelerator | None): Accelerator instance for distributed training. Uses
-                accelerator.save if provided.
-        """
-        file_path = Path(path) / "RepaLoss.pt"
+    def _attach_once(self, model: MMDiT) -> None:
+        if model in self._handles:
+            return
+        layer = model.layers[self._hook_layer_idx]
+        handle = layer.register_forward_hook(self._make_forward_hook(model))  # type: ignore
+        self._handles[model] = handle

-        unwrapped_proj = cast(nn.Module, accelerator.unwrap_model(self.proj))  # type: ignore
-        merged_state = {}
-        for k, v in unwrapped_proj.state_dict().items():
-            merged_state[f"proj.{k}"] = v
-        if self.resampler is not None:
-            unwrapped_resampler = cast(nn.Module, accelerator.unwrap_model(self.resampler))  # type: ignore
-            for k, v in unwrapped_resampler.state_dict().items():
-                merged_state[f"resampler.{k}"] = v
+    def set_model(self, model: MMDiT) -> None:  # type: ignore
+        """Register the model to capture features from a specific layer.

-        accelerator.save(merged_state, file_path)  # type: ignore
-
-    def accelerate_prepare(
-        self, accelerator: Accelerator
-    ) -> "list[nn.Module | DistributedDataParallel | FullyShardedDataParallel]":
-        """
-        Prepare the loss function for distributed training.
+        This attaches a forward hook to the specified ``alignment_layer`` of the
+        provided model (only once). A forward pass on ``model`` must be executed
+        after calling this method so that features are captured before computing
+        the loss.

         Args:
-            accelerator (Accelerator): Accelerator instance for distributed training.
+            model (MMDiT): The model whose intermediate features will be
+                aligned to the encoder features.
         """
-        trainable_modules: "list[nn.Module | DistributedDataParallel | FullyShardedDataParallel]" = []
-        self.proj = accelerator.prepare_model(self.proj)  # type: ignore
-        trainable_modules.append(self.proj)  # type: ignore
-        if self.resampler is not None:
-            self.resampler = accelerator.prepare_model(self.resampler)  # type: ignore
-            trainable_modules.append(self.resampler)  # type: ignore
-        if self.repa_encoder is not None:
-            self.repa_encoder = accelerator.prepare_model(self.repa_encoder)  # type: ignore
+        self._attach_once(model)
+        self._active_model = model

-        return trainable_modules
+    def _unregister_all(self) -> None:
+        for h in list(self._handles.values()):
+            h.remove()
+        self._handles.clear()
+        self._features.clear()
+        self._active_model = None

     def forward(
         self,
         x0: Float[Tensor, "batch 3 H W"] | None = None,
         dst_features: Float[Tensor, "batch seq_len n_dim"] | None = None,
     ) -> Tensor:
-        assert self.src_features is not None, "Source features are not computed. Ensure the forward hook is registered."
+        """Compute the REPA cosine-similarity loss.
+
+        Either provide input images via ``x0`` to compute destination features
+        with the encoder, or pass precomputed ``dst_features`` directly.
+
+        Args:
+            x0 (Tensor): Input images of shape ``[B, 3, H, W]`` used to compute encoder
+                features when an encoder is available.
+            dst_features (Tensor): Precomputed encoder features of shape ``[B, S, D]``.
+                If provided, ``x0`` is ignored.
+
+        Returns:
+            Tensor: A scalar tensor containing the REPA loss.
+
+        Raises:
+            RuntimeError: If no captured features are available for the active
+                model. Ensure ``set_model(...)`` was called and a forward pass
+                on the model was executed first.
+            AssertionError: If neither ``x0`` nor ``dst_features`` is provided.
+        """
+        if self._active_model is None or self._active_model not in self._features:
+            raise RuntimeError(
+                "REPA: no captured features for the active model. Did you call set_model(...) and run a forward pass?"
+            )
         assert x0 is not None or dst_features is not None, "Either x0 or dst_features must be provided."
         if dst_features is None:
             assert self.repa_encoder is not None, "REPA encoder must be initialized to compute features."
@@ -151,7 +198,8 @@ def forward(
             )  # batch size seqlen embedding_dim  # SEE HOW TO HANDLE THE PRE COMPUTING OF FEATURES
         assert dst_features is not None, "Destination features must be provided or computed."

-        projected_src_features: Tensor = self.proj(self.src_features)  # type: ignore
+        src_features = self._features[self._active_model]
+        projected_src_features: Tensor = self.proj(src_features)

         if self.resampler is not None:
             projected_src_features = self.resampler(projected_src_features)
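
The diff replaces the single `_hook_handle`/`src_features` pair with per-model bookkeeping keyed by weak references, so hooking an EMA copy neither keeps it alive nor overwrites features captured from the training model. Below is a minimal, self-contained sketch of that pattern; `TinyBlocks` and `FeatureTap` are invented stand-ins for illustration, not diffulab code.

from weakref import WeakKeyDictionary

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle


class TinyBlocks(nn.Module):
    """Made-up stand-in for MMDiT: a few blocks exposed as .layers."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


class FeatureTap:
    """Per-model hook bookkeeping, mirroring the WeakKeyDictionary pattern in the diff."""

    def __init__(self, layer_idx: int) -> None:
        self.layer_idx = layer_idx
        self._handles: "WeakKeyDictionary[nn.Module, RemovableHandle]" = WeakKeyDictionary()
        self._features: "WeakKeyDictionary[nn.Module, torch.Tensor]" = WeakKeyDictionary()

    def attach_once(self, model: TinyBlocks) -> None:
        if model in self._handles:
            return  # already hooked; attaching twice would capture duplicates

        def _hook(_mod: nn.Module, _inp: tuple, out: torch.Tensor) -> None:
            self._features[model] = out  # features keyed by the hooked model

        self._handles[model] = model.layers[self.layer_idx].register_forward_hook(_hook)


model = TinyBlocks()
tap = FeatureTap(layer_idx=1)
tap.attach_once(model)
_ = model(torch.randn(2, 8))       # forward pass fires the hook
print(tap._features[model].shape)  # torch.Size([2, 8])

Because both dictionaries hold only weak references to the model, dropping the last strong reference to an EMA copy lets its hook handle and cached features be garbage-collected without an explicit unregister call.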
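For reference, the objective described in the new class docstring (cosine similarity between projected denoiser features and encoder features, averaged over the sequence dimension and scaled by ``coeff``) corresponds to something like the sketch below. The exact reduction and the ``coeff`` default are assumptions, since the tail of ``forward`` falls outside this diff.

import torch
import torch.nn.functional as F


def repa_alignment_loss(
    projected_src: torch.Tensor,  # [B, S, D] projected denoiser features
    dst_features: torch.Tensor,   # [B, S, D] encoder features (e.g. DINOv2)
    coeff: float = 0.5,           # illustrative default; the real value comes from the constructor
) -> torch.Tensor:
    # Negative cosine similarity per token, averaged over batch and sequence.
    cos = F.cosine_similarity(projected_src, dst_features, dim=-1)  # [B, S]
    return coeff * (-cos.mean())


loss = repa_alignment_loss(torch.randn(2, 16, 768), torch.randn(2, 16, 768))
print(loss.item())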