Skip to content

Commit 41d903d

Browse files
committed
fix unittest
1 parent f9447e9 commit 41d903d

File tree

4 files changed

+24
-17
lines changed

4 files changed

+24
-17
lines changed

configs/datasets/voc.yml

+12-12
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@ map_type: 11point
33
num_classes: 20
44

55
TrainDataset:
6-
!VOCDataSet
7-
dataset_dir: dataset/voc
8-
anno_path: trainval.txt
9-
label_list: label_list.txt
10-
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
6+
name: VOCDataSet
7+
dataset_dir: dataset/voc
8+
anno_path: trainval.txt
9+
label_list: label_list.txt
10+
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
1111

1212
EvalDataset:
13-
!VOCDataSet
14-
dataset_dir: dataset/voc
15-
anno_path: test.txt
16-
label_list: label_list.txt
17-
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
13+
name: VOCDataSet
14+
dataset_dir: dataset/voc
15+
anno_path: test.txt
16+
label_list: label_list.txt
17+
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
1818

1919
TestDataset:
20-
!ImageFolder
21-
anno_path: dataset/voc/label_list.txt
20+
name: ImageFolder
21+
anno_path: dataset/voc/label_list.txt

configs/runtime.yml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use_gpu: true
22
use_xpu: false
33
use_mlu: false
4+
use_npu: false
45
log_iter: 20
56
save_dir: output
67
snapshot_epoch: 1

ppdet/core/workspace.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ def __getattr__(self, key):
6767
return self[key]
6868
raise AttributeError("object has no attribute '{}'".format(key))
6969

70+
def copy(self):
71+
new_dict = AttrDict()
72+
for k, v in self.items():
73+
new_dict.update({k: v})
74+
return new_dict
75+
7076

7177
global_config = AttrDict()
7278

@@ -280,4 +286,4 @@ def create(cls_or_name, **kwargs):
280286
# prevent modification of global config values of reference types
281287
# (e.g., list, dict) from within the created module instances
282288
#kwargs = copy.deepcopy(kwargs)
283-
return cls(**cls_kwargs)
289+
return cls(**cls_kwargs)

ppdet/engine/trainer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262

6363
class Trainer(object):
6464
def __init__(self, cfg, mode='train'):
65-
self.cfg = cfg
65+
self.cfg = cfg.copy()
6666
assert mode.lower() in ['train', 'eval', 'test'], \
6767
"mode should be 'train', 'eval' or 'test'"
6868
self.mode = mode.lower()
@@ -99,12 +99,12 @@ def __init__(self, cfg, mode='train'):
9999
self.dataset, cfg.worker_num)
100100

101101
if cfg.architecture == 'JDE' and self.mode == 'train':
102-
cfg['JDEEmbeddingHead'][
102+
self.cfg['JDEEmbeddingHead'][
103103
'num_identities'] = self.dataset.num_identities_dict[0]
104104
# JDE only support single class MOT now.
105105

106106
if cfg.architecture == 'FairMOT' and self.mode == 'train':
107-
cfg['FairMOTEmbeddingHead'][
107+
self.cfg['FairMOTEmbeddingHead'][
108108
'num_identities_dict'] = self.dataset.num_identities_dict
109109
# FairMOT support single class and multi-class MOT now.
110110

@@ -149,7 +149,7 @@ def __init__(self, cfg, mode='train'):
149149
reader_name = '{}Reader'.format(self.mode.capitalize())
150150
# If metric is VOC, need to be set collate_batch=False.
151151
if cfg.metric == 'VOC':
152-
cfg[reader_name]['collate_batch'] = False
152+
self.cfg[reader_name]['collate_batch'] = False
153153
self.loader = create(reader_name)(self.dataset, cfg.worker_num,
154154
self._eval_batch_sampler)
155155
# TestDataset build after user set images, skip loader creation here

0 commit comments

Comments
 (0)