Skip to content

Commit 5966423

Browse files
authored
šŸž fix(model): EfficientAD NotImplementedError: ('{} cannot be pickled', '_SingleProcessDataLoaderIter') (#2837)
* fix efficientad example and pickling error * pre-commit hooks Signed-off-by: Alexander Riedel <alex.riedel@googlemail.com> --------- Signed-off-by: Alexander Riedel <alex.riedel@googlemail.com>
1 parent 90e1192 commit 5966423

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

ā€Žexamples/api/03_models/efficient_ad.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,38 @@
77
particularly well-suited for industrial inspection tasks.
88
"""
99

10+
from torchvision.transforms.v2 import Compose, Resize
11+
1012
from anomalib.data import MVTecAD
1113
from anomalib.engine import Engine
1214
from anomalib.models import EfficientAd
15+
from anomalib.pre_processing import PreProcessor
1316

1417
# 1. Basic Usage
1518
# Initialize with default settings
1619
model = EfficientAd()
1720

18-
# 2. Custom Configuration
21+
# OPTIONAL 2. Set up a pre-processing transformation
22+
transform = Compose([
23+
Resize(size=(512, 512)),
24+
])
25+
pre_processor = PreProcessor(transform=transform)
26+
27+
# 3. Custom Configuration
1928
# Configure model parameters
2029
model = EfficientAd(
2130
teacher_out_channels=384, # Number of teacher output channels
22-
model_size="m",
31+
model_size="medium",
2332
lr=1e-4,
33+
pre_processor=pre_processor,
2434
)
2535

26-
# 3. Training Pipeline
36+
# 4. Training Pipeline
2737
# Set up the complete training pipeline
2838
datamodule = MVTecAD(
2939
root="./datasets/MVTecAD",
3040
category="bottle",
31-
train_batch_size=32,
41+
train_batch_size=1,
3242
)
3343

3444
# Initialize training engine with specific settings
@@ -38,8 +48,9 @@
3848
devices=1, # Number of devices to use
3949
)
4050

41-
# Train the model
42-
engine.fit(
43-
model=model,
44-
datamodule=datamodule,
45-
)
51+
if __name__ == "__main__":
52+
# Train the model
53+
engine.fit(
54+
model=model,
55+
datamodule=datamodule,
56+
)

ā€Žsrc/anomalib/models/image/efficient_ad/lightning_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,16 @@ def on_train_start(self) -> None:
380380
channel_mean_std = self.teacher_channel_mean_std(self.trainer.datamodule.train_dataloader())
381381
self.model.mean_std.update(channel_mean_std)
382382

383+
def __getstate__(self) -> dict:
384+
"""Modifies the python objects __getstate__ method.
385+
386+
To ensure that the imagenet iterator instance will not get pickled
387+
when the model is saved, it needs to be removed from the objects dict.
388+
"""
389+
state = self.__dict__.copy()
390+
state.pop("imagenet_iterator", None)
391+
return state
392+
383393
def training_step(self, batch: Batch, *args, **kwargs) -> dict[str, torch.Tensor]:
384394
"""Perform training step.
385395

0 commit comments

Comments
Ā (0)