
Commit 73073a0

🔄 refactor(model): memory usage optimisation (#2813)
* initial commit: don't re-init buffer, in-place fill
* implement PatchCore memory savings
* deprecation warning for internal subsample_embedding of PatchCore
* PaDiM GPU memory usage reduction
* refactor memory bank tensor assignment for readability; DFKDE refactor
* unify memory-bank models' trainer arguments; refactor DFM to fit the new memory-bank framework
* bugfix: PaDiM device count set to zero
* ensure buffer is not replaced but instead resized and filled
* give memory bank type
* revert memory bank mixin back
* to memory bank type and device
* update return type, add comment about deprecation, duplicated copyright
* new deprecation wrapper, now handles arg deprecation
* add flexible tests for new deprecation warning
* update warnings to deprecations; resolve comments re fit vs fit_gaussian; test for None replacements for args
* add type hint for DFKDE
* remove dataset arg from DFM model fit
* Update src/anomalib/models/image/dfm/torch_model.py

---------

Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>
Signed-off-by: alfieroddan <51797647+alfieroddan@users.noreply.github.com>
Signed-off-by: Samet Akcay <samet.akcay@intel.com>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
1 parent dcf0820 commit 73073a0

10 files changed: +452 −134 lines changed
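The common thread across the diffs below is that each torch model now owns a memory_bank tensor that starts empty, grows as training batches pass through forward(), and is consumed and then cleared by a model-level fit(). The following is a minimal, self-contained sketch of that pattern; the toy module and dimensions are invented for illustration and are not code from this commit:

    # Minimal sketch of the memory-bank pattern introduced in this commit (toy example, not PR code).
    import torch
    from torch import nn


    class ToyMemoryBankModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.memory_bank = torch.empty(0)  # starts empty; filled during training-mode forward passes

        def forward(self, batch: torch.Tensor) -> torch.Tensor:
            features = batch.flatten(1)  # stand-in for real feature extraction
            if self.training:
                if self.memory_bank.size(0) == 0:
                    self.memory_bank = features
                else:
                    # .to(self.memory_bank) keeps the growing bank on the bank's existing device/dtype
                    self.memory_bank = torch.cat((self.memory_bank, features), dim=0).to(self.memory_bank)
            return features

        def fit(self) -> None:
            if self.memory_bank.size(0) == 0:
                msg = "Memory bank is empty."
                raise ValueError(msg)
            # ... fit a density model / Gaussian / coreset on self.memory_bank here ...
            self.memory_bank = torch.empty(0).to(self.memory_bank)  # clear to release GPU memory


    model = ToyMemoryBankModel().train()
    for _ in range(3):
        _ = model(torch.rand(4, 3, 8, 8))
    print(model.memory_bank.shape)  # torch.Size([12, 192])
    model.fit()
    print(model.memory_bank.shape)  # torch.Size([0])

Because the bank lives on the torch model rather than as a Python list on the Lightning module, the embeddings stay in a single tensor and can be dropped as soon as the model is fitted.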

src/anomalib/models/image/dfkde/lightning_model.py

Lines changed: 9 additions & 10 deletions

@@ -104,7 +104,7 @@ def __init__(
             visualizer=visualizer,
         )
 
-        self.model = DfkdeModel(
+        self.model: DfkdeModel = DfkdeModel(
             layers=layers,
             backbone=backbone,
             pre_trained=pre_trained,
@@ -113,8 +113,6 @@ def __init__(
             max_training_points=max_training_points,
         )
 
-        self.embeddings: list[torch.Tensor] = []
-
     @staticmethod
     def configure_optimizers() -> None:  # pylint: disable=arguments-differ
         """DFKDE doesn't require optimization, therefore returns no optimizers."""
@@ -133,18 +131,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         """
         del args, kwargs  # These variables are not used.
 
-        embedding = self.model(batch.image)
-        self.embeddings.append(embedding)
+        _ = self.model(batch.image)
 
         # Return a dummy loss tensor
         return torch.tensor(0.0, requires_grad=True, device=self.device)
 
     def fit(self) -> None:
         """Fit KDE model to collected embeddings from the training set."""
-        embeddings = torch.vstack(self.embeddings)
-
         logger.info("Fitting a KDE model to the embedding collected from the training set.")
-        self.model.classifier.fit(embeddings)
+        self.model.fit()
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Perform validation by computing anomaly scores.
@@ -167,9 +162,13 @@ def trainer_arguments(self) -> dict[str, Any]:
         """Get DFKDE-specific trainer arguments.
 
         Returns:
-            dict[str, Any]: Dictionary of trainer arguments.
+            dict[str, Any]: Trainer arguments
+                - ``gradient_clip_val``: ``0`` (no gradient clipping needed)
+                - ``max_epochs``: ``1`` (single pass through training data)
+                - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
+                - ``devices``: ``1`` (only single gpu supported)
         """
-        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
+        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}
 
     @property
     def learning_type(self) -> LearningType:
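From the user's side the training workflow is unchanged; only the internals moved. A hedged usage sketch, not part of the diff (the Dfkde, Engine, and MVTecAD names are assumed from recent anomalib releases):

    # Usage sketch: training Dfkde end to end with the Engine.
    from anomalib.data import MVTecAD
    from anomalib.engine import Engine
    from anomalib.models import Dfkde

    datamodule = MVTecAD(category="bottle")
    model = Dfkde()
    engine = Engine()

    # training_step streams features into the torch model's memory bank;
    # Dfkde.fit() (invoked by anomalib's memory-bank handling before validation)
    # then delegates to DfkdeModel.fit().
    engine.fit(model=model, datamodule=datamodule)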

src/anomalib/models/image/dfkde/torch_model.py

Lines changed: 27 additions & 0 deletions

@@ -89,6 +89,7 @@ def __init__(
             feature_scaling_method=feature_scaling_method,
             max_training_points=max_training_points,
         )
+        self.memory_bank = torch.empty(0)
 
     def get_features(self, batch: torch.Tensor) -> torch.Tensor:
         """Extract features from the pre-trained backbone network.
@@ -141,8 +142,34 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
         # 1. apply feature extraction
         features = self.get_features(batch)
         if self.training:
+            if self.memory_bank.size(0) == 0:
+                self.memory_bank = features
+            else:
+                new_bank = torch.cat((self.memory_bank, features), dim=0).to(self.memory_bank)
+                self.memory_bank = new_bank
             return features
 
         # 2. apply density estimation
         scores = self.classifier(features)
         return InferenceBatch(pred_score=scores)
+
+    def fit(self) -> None:
+        """Fits the classifier using the current contents of the memory bank.
+
+        This method is typically called after the memory bank has been populated
+        during training.
+
+        After fitting, the memory bank is cleared to reduce GPU memory usage.
+
+        Raises:
+            ValueError: If the memory bank is empty.
+        """
+        if self.memory_bank.size(0) == 0:
+            msg = "Memory bank is empty. Cannot perform coreset selection."
+            raise ValueError(msg)
+
+        # fit gaussian
+        self.classifier.fit(self.memory_bank)
+
+        # clear memory bank, redcues gpu size
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
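The new fit() refuses to run on an empty bank. A small sketch of that guard, not from the PR (constructor keywords taken from the lightning diff above; pre_trained=False only to keep the illustration light):

    # Sketch: DfkdeModel.fit() raises if nothing has been collected yet.
    from anomalib.models.image.dfkde.torch_model import DfkdeModel

    model = DfkdeModel(layers=["layer4"], backbone="resnet18", pre_trained=False)
    try:
        model.fit()  # memory bank is still torch.empty(0)
    except ValueError as err:
        print(err)   # "Memory bank is empty. Cannot perform coreset selection."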

src/anomalib/models/image/dfm/lightning_model.py

Lines changed: 8 additions & 12 deletions

@@ -112,7 +112,6 @@ def __init__(
             n_comps=pca_level,
             score_type=score_type,
         )
-        self.embeddings: list[torch.Tensor] = []
         self.score_type = score_type
 
     @staticmethod
@@ -137,8 +136,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         """
         del args, kwargs  # These variables are not used.
 
-        embedding = self.model.get_features(batch.image).squeeze()
-        self.embeddings.append(embedding)
+        _ = self.model(batch.image)
 
         # Return a dummy loss tensor
         return torch.tensor(0.0, requires_grad=True, device=self.device)
@@ -149,11 +147,8 @@ def fit(self) -> None:
         The method aggregates embeddings collected during training and fits
         both the PCA transformation and Gaussian model used for scoring.
         """
-        logger.info("Aggregating the embedding extracted from the training set.")
-        embeddings = torch.vstack(self.embeddings)
-
         logger.info("Fitting a PCA and a Gaussian model to dataset.")
-        self.model.fit(embeddings)
+        self.model.fit()
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Compute predictions for the input batch during validation.
@@ -176,12 +171,13 @@ def trainer_arguments(self) -> dict[str, Any]:
         """Get DFM-specific trainer arguments.
 
         Returns:
-            dict[str, Any]: Dictionary of trainer arguments:
-                - ``gradient_clip_val`` (int): Disable gradient clipping
-                - ``max_epochs`` (int): Train for one epoch only
-                - ``num_sanity_val_steps`` (int): Skip validation sanity checks
+            dict[str, Any]: Trainer arguments
+                - ``gradient_clip_val``: ``0`` (no gradient clipping needed)
+                - ``max_epochs``: ``1`` (single pass through training data)
+                - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
+                - ``devices``: ``1`` (only single gpu supported)
         """
-        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
+        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}
 
     @property
     def learning_type(self) -> LearningType:
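All of the memory-bank models in this commit now return the same trainer arguments, including devices: 1. Roughly speaking, these are the keyword arguments anomalib's Engine forwards to the Lightning Trainer; the sketch below shows that downstream effect and rests on an assumption about Engine internals, which are not part of this diff:

    # Sketch: the practical effect of trainer_arguments, shown as a plain Lightning Trainer.
    from lightning.pytorch import Trainer

    trainer_kwargs = {
        "gradient_clip_val": 0,     # no gradients to clip for these one-pass models
        "max_epochs": 1,            # a single pass over the training data fills the memory bank
        "num_sanity_val_steps": 0,  # validation needs a fitted model, so skip sanity checks
        "devices": 1,               # memory-bank accumulation assumes a single device
    }
    trainer = Trainer(**trainer_kwargs)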

src/anomalib/models/image/dfm/torch_model.py

Lines changed: 27 additions & 19 deletions

@@ -153,18 +153,18 @@ def __init__(
             layers=[layer],
         ).eval()
 
-    def fit(self, dataset: torch.Tensor) -> None:
-        """Fit PCA and Gaussian model to dataset.
+        self.memory_bank = torch.empty(0)
 
-        Args:
-            dataset (torch.Tensor): Input dataset with shape
-                ``(n_samples, n_features)``.
-        """
-        self.pca_model.fit(dataset)
+    def fit(self) -> None:
+        """Fit PCA and Gaussian model to dataset."""
+        self.pca_model.fit(self.memory_bank)
         if self.score_type == "nll":
-            features_reduced = self.pca_model.transform(dataset)
+            features_reduced = self.pca_model.transform(self.memory_bank)
             self.gaussian_model.fit(features_reduced.T)
 
+        # clear memory bank, reduces GPU size
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
+
     def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor:
         """Compute anomaly scores.
@@ -194,25 +194,24 @@ def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor:
 
         return (score, None) if self.score_type == "nll" else (score, score_map)
 
-    def get_features(self, batch: torch.Tensor) -> torch.Tensor:
+    def get_features(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Size]:
         """Extract features from the pretrained network.
 
         Args:
             batch (torch.Tensor): Input images with shape
                 ``(batch_size, channels, height, width)``.
 
         Returns:
-            Union[torch.Tensor, Tuple[torch.Tensor, torch.Size]]: Features during
-                training, or tuple of (features, feature_shapes) during inference.
+            tuple of (features, feature_shapes).
         """
-        self.feature_extractor.eval()
-        features = self.feature_extractor(batch)[self.layer]
-        batch_size = len(features)
-        if self.pooling_kernel_size > 1:
-            features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size)
-        feature_shapes = features.shape
-        features = features.view(batch_size, -1).detach()
-        return features if self.training else (features, feature_shapes)
+        with torch.no_grad():
+            features = self.feature_extractor(batch)[self.layer]
+            batch_size = len(features)
+            if self.pooling_kernel_size > 1:
+                features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size)
+            feature_shapes = features.shape
+            features = features.view(batch_size, -1)
+        return features, feature_shapes
 
     def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Compute anomaly predictions from input images.
@@ -227,6 +226,15 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
             ``InferenceBatch`` with prediction scores and anomaly maps.
         """
         feature_vector, feature_shapes = self.get_features(batch)
+
+        if self.training:
+            if self.memory_bank.size(0) == 0:
+                self.memory_bank = feature_vector
+            else:
+                new_bank = torch.cat((self.memory_bank, feature_vector), dim=0).to(self.memory_bank)
+                self.memory_bank = new_bank
+            return feature_vector
+
         pred_score, anomaly_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes)
         if anomaly_map is not None:
             anomaly_map = F.interpolate(anomaly_map, size=batch.shape[-2:], mode="bilinear", align_corners=False)
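One behavioural change worth calling out: get_features previously returned a bare tensor in training mode and a tuple in eval mode; it now always returns (features, feature_shapes) and runs under torch.no_grad(). A sketch for direct callers, not from the PR (constructor keywords taken from the Dfm lightning module; pre_trained=False only to keep the illustration light):

    # Sketch: get_features now has one return shape in both training and eval mode.
    import torch
    from anomalib.models.image.dfm.torch_model import DFMModel

    model = DFMModel(backbone="resnet50", layer="layer3", pre_trained=False)
    model.train()
    features, feature_shapes = model.get_features(torch.rand(2, 3, 256, 256))
    print(features.shape)   # (batch_size, flattened_features); produced under torch.no_grad()
    print(feature_shapes)   # spatial shape before flattening, used later for the anomaly map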

src/anomalib/models/image/padim/lightning_model.py

Lines changed: 9 additions & 16 deletions

@@ -131,9 +131,6 @@ def __init__(
             n_features=n_features,
         )
 
-        self.stats: list[torch.Tensor] = []
-        self.embeddings: list[torch.Tensor] = []
-
     @staticmethod
     def configure_optimizers() -> None:
         """PADIM doesn't require optimization, therefore returns no optimizers."""
@@ -154,19 +151,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         """
         del args, kwargs  # These variables are not used.
 
-        embedding = self.model(batch.image)
-        self.embeddings.append(embedding)
+        _ = self.model(batch.image)
 
         # Return a dummy loss tensor
        return torch.tensor(0.0, requires_grad=True, device=self.device)
 
     def fit(self) -> None:
         """Fit a Gaussian to the embedding collected from the training set."""
-        logger.info("Aggregating the embedding extracted from the training set.")
-        embeddings = torch.vstack(self.embeddings)
-
         logger.info("Fitting a Gaussian to the embedding collected from the training set.")
-        self.stats = self.model.gaussian.fit(embeddings)
+        self.model.fit()
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Perform a validation step of PADIM.
@@ -190,16 +183,16 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
 
     @property
     def trainer_arguments(self) -> dict[str, int | float]:
-        """Return PADIM trainer arguments.
-
-        Since the model does not require training, we limit the max_epochs to 1.
-        Since we need to run training epoch before validation, we also set the
-        sanity steps to 0.
+        """Get default trainer arguments for Padim.
 
         Returns:
-            dict[str, int | float]: Dictionary of trainer arguments
+            dict[str, Any]: Trainer arguments
+                - ``max_epochs``: ``1`` (single pass through training data)
+                - ``val_check_interval``: ``1.0`` (check validation every 1 step)
+                - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
+                - ``devices``: ``1`` (only single gpu supported)
         """
-        return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0}
+        return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0, "devices": 1}
 
     @property
     def learning_type(self) -> LearningType:
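The removed self.stats list is not replaced by anything on the Lightning module: after fitting, the Gaussian parameters live on the torch model itself. A hedged sketch, not from the PR (the MVTecAD/Padim/Engine names and the mean/inv_covariance buffer names are assumptions from recent anomalib releases, not shown in this diff):

    # Sketch: reading the fitted Gaussian statistics from the torch model after training.
    from anomalib.data import MVTecAD
    from anomalib.engine import Engine
    from anomalib.models import Padim

    datamodule = MVTecAD(category="bottle")
    model = Padim()
    Engine().fit(model=model, datamodule=datamodule)

    # Previously mirrored in Padim.stats; now read from the MultiVariateGaussian module directly.
    print(model.model.gaussian.mean.shape)
    print(model.model.gaussian.inv_covariance.shape)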

src/anomalib/models/image/padim/torch_model.py

Lines changed: 26 additions & 0 deletions

@@ -147,6 +147,7 @@ def __init__(
         self.anomaly_map_generator = AnomalyMapGenerator()
 
         self.gaussian = MultiVariateGaussian()
+        self.memory_bank = torch.empty(0)
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Forward-pass image-batch (N, C, H, W) into model to extract features.
@@ -182,6 +183,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
             embeddings = self.tiler.untile(embeddings)
 
         if self.training:
+            if self.memory_bank.size(0) == 0:
+                self.memory_bank = embeddings
+            else:
+                new_bank = torch.cat((self.memory_bank, embeddings), dim=0).to(self.memory_bank)
+                self.memory_bank = new_bank
             return embeddings
 
         anomaly_map = self.anomaly_map_generator(
@@ -217,3 +223,23 @@ def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor:
         # subsample embeddings
         idx = self.idx.to(embeddings.device)
         return torch.index_select(embeddings, 1, idx)
+
+    def fit(self) -> None:
+        """Fits a Gaussian model to the current contents of the memory bank.
+
+        This method is typically called after the memory bank has been filled during training.
+
+        After fitting, the memory bank is cleared to free GPU memory before validation or testing.
+
+        Raises:
+            ValueError: If the memory bank is empty.
+        """
+        if self.memory_bank.size(0) == 0:
+            msg = "Memory bank is empty. Cannot perform coreset selection."
+            raise ValueError(msg)
+
+        # fit gaussian
+        self.gaussian.fit(self.memory_bank)
+
+        # clear memory bank, redcues gpu usage
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
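At the torch-model level, a training-mode forward pass now has a side effect: it grows memory_bank until fit() consumes and clears it. A sketch, not from the PR (constructor keywords as used by the Padim lightning module; pre_trained=False only to keep the illustration light):

    # Sketch: PadimModel accumulates embeddings during training-mode forwards, then fit() clears them.
    import torch
    from anomalib.models.image.padim.torch_model import PadimModel

    model = PadimModel(backbone="resnet18", layers=["layer1", "layer2", "layer3"], pre_trained=False)
    model.train()
    with torch.no_grad():
        for _ in range(2):
            _ = model(torch.rand(2, 3, 256, 256))

    print(model.memory_bank.shape[0])  # 4: embeddings from both batches, concatenated along dim 0
    model.fit()                        # fits the multivariate Gaussian, then empties the bank
    print(model.memory_bank.shape[0])  # 0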

src/anomalib/models/image/patchcore/lightning_model.py

Lines changed: 5 additions & 11 deletions

@@ -159,7 +159,6 @@ def __init__(
             num_neighbors=num_neighbors,
         )
         self.coreset_sampling_ratio = coreset_sampling_ratio
-        self.embeddings: list[torch.Tensor] = []
 
     @classmethod
     def configure_pre_processor(
@@ -235,24 +234,18 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
             ``fit()``.
         """
         del args, kwargs  # These variables are not used.
-
-        embedding = self.model(batch.image)
-        self.embeddings.append(embedding)
+        _ = self.model(batch.image)
         # Return a dummy loss tensor
         return torch.tensor(0.0, requires_grad=True, device=self.device)
 
     def fit(self) -> None:
         """Apply subsampling to the embedding collected from the training set.
 
         This method:
-        1. Aggregates embeddings from all training batches
-        2. Applies coreset subsampling to reduce memory requirements
+        1. Applies coreset subsampling to reduce memory requirements
         """
-        logger.info("Aggregating the embedding extracted from the training set.")
-        embeddings = torch.vstack(self.embeddings)
-
         logger.info("Applying core-set subsampling to get the embedding.")
-        self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio)
+        self.model.subsample_embedding(self.coreset_sampling_ratio)
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Generate predictions for a batch of images.
@@ -286,8 +279,9 @@ def trainer_arguments(self) -> dict[str, Any]:
             - ``gradient_clip_val``: ``0`` (no gradient clipping needed)
             - ``max_epochs``: ``1`` (single pass through training data)
             - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
+            - ``devices``: ``1`` (only single gpu supported)
         """
-        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
+        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}
 
     @property
     def learning_type(self) -> LearningType:
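The PatchCore torch-model side of this change is not included in this excerpt; per the commit message, the old subsample_embedding(embeddings, ratio) form is kept behind a deprecation wrapper. A sketch of how a direct caller migrates, based on the new call site shown above (constructor keywords taken from the lightning module defaults; pre_trained=False only to keep the illustration light):

    # Sketch: PatchcoreModel now subsamples its own memory bank; callers pass only the ratio.
    import torch
    from anomalib.models.image.patchcore.torch_model import PatchcoreModel

    model = PatchcoreModel(layers=["layer2", "layer3"], backbone="wide_resnet50_2", pre_trained=False)
    model.train()
    with torch.no_grad():
        _ = model(torch.rand(2, 3, 256, 256))   # training-mode forward fills model.memory_bank

    model.subsample_embedding(0.1)               # new call site, as in Patchcore.fit() above
    # model.subsample_embedding(embeddings, 0.1) # old form; now deprecated per the commit message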
