Skip to content

Commit fe6f614

Browse files
committed
Update multilabel
1 parent a90881c commit fe6f614

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ DataLoader:
4646
Train:
4747
dataset:
4848
name: MultiLabelDataset
49-
image_root: ./dataset/NUS-SCENE-dataset/images/
50-
cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_train_list.txt
49+
image_root: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/
50+
cls_label_path: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_train_list.txt
5151
transform_ops:
5252
- DecodeImage:
5353
to_rgb: True
@@ -74,8 +74,8 @@ DataLoader:
7474
Eval:
7575
dataset:
7676
name: MultiLabelDataset
77-
image_root: ./dataset/NUS-SCENE-dataset/images/
78-
cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_test_list.txt
77+
image_root: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/
78+
cls_label_path: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_test_list.txt
7979
transform_ops:
8080
- DecodeImage:
8181
to_rgb: True

ppcls/engine/evaluation/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def classification_eval(engine, epoch_id=0):
5050
time_info["reader_cost"].update(time.time() - tic)
5151
batch_size = batch[0].shape[0]
5252
batch[0] = paddle.to_tensor(batch[0]).astype("float32")
53-
if not evaler.config["Global"].get("use_multilabel", False):
53+
if not engine.config["Global"].get("use_multilabel", False):
5454
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
5555
# image input
5656
out = engine.model(batch[0])

ppcls/engine/train/train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
7676
tic = time.time()
7777

7878

79-
def forward(trainer, batch):
80-
if not trainer.is_rec:
81-
return trainer.model(batch[0])
79+
def forward(engine, batch):
80+
if not engine.is_rec:
81+
return engine.model(batch[0])
8282
else:
83-
return trainer.model(batch[0], batch[1])
83+
return engine.model(batch[0], batch[1])

0 commit comments

Comments
 (0)