Skip to content

Commit c4f3845

Browse files
authored
Merge pull request #1205 from weisy11/develop
add MixDataset, MixSampler and PKSampler
2 parents 9f13876 + da25931 commit c4f3845

File tree

11 files changed

+322
-82
lines changed

11 files changed

+322
-82
lines changed

ppcls/configs/Logo/ResNet50_ReID.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ Optimizer:
5454
momentum: 0.9
5555
lr:
5656
name: Cosine
57-
learning_rate: 0.01
57+
learning_rate: 0.04
5858
regularizer:
5959
name: 'L2'
6060
coeff: 0.0001
@@ -84,10 +84,10 @@ DataLoader:
8484
- RandomErasing:
8585
EPSILON: 0.5
8686
sampler:
87-
name: DistributedRandomIdentitySampler
87+
name: PKSampler
8888
batch_size: 128
89-
num_instances: 2
90-
drop_last: False
89+
sample_per_id: 2
90+
drop_last: True
9191

9292
loader:
9393
num_workers: 6
@@ -97,7 +97,7 @@ DataLoader:
9797
dataset:
9898
name: LogoDataset
9999
image_root: "dataset/LogoDet-3K-crop/val/"
100-
cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+query.txt"
100+
cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+val.txt"
101101
transform_ops:
102102
- DecodeImage:
103103
to_rgb: True
@@ -122,7 +122,7 @@ DataLoader:
122122
dataset:
123123
name: LogoDataset
124124
image_root: "dataset/LogoDet-3K-crop/train/"
125-
cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+gallery.txt"
125+
cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+train.txt"
126126
transform_ops:
127127
- DecodeImage:
128128
to_rgb: True

ppcls/configs/Products/ResNet50_vd_Inshop.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ Optimizer:
5454
momentum: 0.9
5555
lr:
5656
name: MultiStepDecay
57-
learning_rate: 0.01
57+
learning_rate: 0.04
5858
milestones: [30, 60, 70, 80, 90, 100]
5959
gamma: 0.5
6060
verbose: False
@@ -90,10 +90,10 @@ DataLoader:
9090
r1: 0.3
9191
mean: [0., 0., 0.]
9292
sampler:
93-
name: DistributedRandomIdentitySampler
93+
name: PKSampler
9494
batch_size: 64
95-
num_instances: 2
96-
drop_last: False
95+
sample_per_id: 2
96+
drop_last: True
9797
shuffle: True
9898
loader:
9999
num_workers: 4

ppcls/configs/Vehicle/ResNet50_ReID.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Optimizer:
5353
momentum: 0.9
5454
lr:
5555
name: Cosine
56-
learning_rate: 0.01
56+
learning_rate: 0.04
5757
regularizer:
5858
name: 'L2'
5959
coeff: 0.0005
@@ -88,10 +88,10 @@ DataLoader:
8888
mean: [0., 0., 0.]
8989

9090
sampler:
91-
name: DistributedRandomIdentitySampler
91+
name: PKSampler
9292
batch_size: 128
93-
num_instances: 2
94-
drop_last: False
93+
sample_per_id: 2
94+
drop_last: True
9595
shuffle: True
9696
loader:
9797
num_workers: 6

ppcls/data/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@
2626
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
2727
from ppcls.data.dataloader.logo_dataset import LogoDataset
2828
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
29+
from ppcls.data.dataloader.mix_dataset import MixDataset
2930

3031
# sampler
3132
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
33+
from ppcls.data.dataloader.pk_sampler import PKSampler
34+
from ppcls.data.dataloader.mix_sampler import MixSampler
3235
from ppcls.data import preprocess
3336
from ppcls.data.preprocess import transform
3437

ppcls/data/dataloader/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
2+
from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
3+
from ppcls.data.dataloader.common_dataset import create_operators
4+
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
5+
from ppcls.data.dataloader.logo_dataset import LogoDataset
6+
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
7+
from ppcls.data.dataloader.mix_dataset import MixDataset
8+
from ppcls.data.dataloader.mix_sampler import MixSampler
9+
from ppcls.data.dataloader.pk_sampler import PKSampler

ppcls/data/dataloader/mix_dataset.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import numpy as np
18+
import os
19+
20+
from paddle.io import Dataset
21+
from .. import dataloader
22+
23+
24+
class MixDataset(Dataset):
    """Expose several datasets through one contiguous index space.

    Sub-datasets are laid out back to back in the order they appear in
    ``datasets_config``; global index ``i`` is served by the sub-dataset whose
    ``[start_idx, end_idx)`` range contains it.

    Args:
        datasets_config (list[dict]): one config per sub-dataset. Each config
            must contain a ``name`` key naming a dataset class exposed by the
            ``dataloader`` package; every other key is forwarded to that
            class as a keyword argument.
    """

    def __init__(self, datasets_config):
        super().__init__()
        # Each entry is [end_idx, start_idx, dataset]; the sub-dataset owns
        # the global indices in the half-open range [start_idx, end_idx).
        self.dataset_list = []
        start_idx = 0
        end_idx = 0
        for config_i in datasets_config:
            # Work on a shallow copy so popping "name" does not alter the
            # caller's config (which may be reused to build another dataset).
            config_i = dict(config_i)
            dataset_name = config_i.pop('name')
            dataset = getattr(dataloader, dataset_name)(**config_i)
            end_idx += len(dataset)
            self.dataset_list.append([end_idx, start_idx, dataset])
            start_idx = end_idx

        self.length = end_idx

    def __getitem__(self, idx):
        """Return the sample at global index ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        original_idx = idx
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(
                "MixDataset index {} out of range for length {}".format(
                    original_idx, self.length))
        for end_idx, start_idx, dataset in self.dataset_list:
            if end_idx > idx:
                # Translate the global index into the sub-dataset's own range.
                return dataset[idx - start_idx]
        # Only reachable if dataset_list was modified after construction.
        raise IndexError(
            "MixDataset index {} not covered by any sub-dataset".format(
                original_idx))

    def __len__(self):
        """Return the total number of samples over all sub-datasets."""
        return self.length

    def get_dataset_list(self):
        """Return the list of ``[end_idx, start_idx, dataset]`` entries."""
        return self.dataset_list

ppcls/data/dataloader/mix_sampler.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
18+
from paddle.io import DistributedBatchSampler, Sampler
19+
20+
from ppcls.utils import logger
21+
from ppcls.data.dataloader.mix_dataset import MixDataset
22+
from ppcls.data import dataloader
23+
24+
25+
class MixSampler(DistributedBatchSampler):
    """Batch sampler that builds every batch from several sub-samplers.

    Each sub-dataset of a ``MixDataset`` gets its own sampler that contributes
    a fixed share of every batch. Indices produced by a sub-sampler are
    shifted by the sub-dataset's start offset so that they address the
    ``MixDataset`` index space.

    Args:
        dataset (MixDataset): the mixed dataset to sample from.
        batch_size (int): total number of indices per yielded batch.
        sample_configs (list[dict]): one config per sub-dataset, in the same
            order as the datasets inside ``dataset``. Each config must hold
            ``name`` (sampler class name) and ``ratio`` (share of the batch);
            remaining keys are forwarded to the sampler constructor.
        iter_per_epoch (int): base value of the reported length.
    """

    def __init__(self, dataset, batch_size, sample_configs, iter_per_epoch):
        import math

        super().__init__(dataset, batch_size)
        assert isinstance(dataset,
                          MixDataset), "MixSampler only support MixDataset"
        self.sampler_list = []
        self.batch_size = batch_size
        self.start_list = []
        self.length = iter_per_epoch
        dataset_list = dataset.get_dataset_list()
        batch_size_left = self.batch_size
        self.iter_list = []
        for i, config_i in enumerate(sample_configs):
            # Copy first: the keys popped/added below must not leak into the
            # caller's config, otherwise it cannot be reused.
            config_i = dict(config_i)
            # dataset_list entries are [end_idx, start_idx, dataset].
            self.start_list.append(dataset_list[i][1])
            sample_method = config_i.pop("name")
            ratio_i = config_i.pop("ratio")
            if i < len(sample_configs) - 1:
                batch_size_i = int(self.batch_size * ratio_i)
                batch_size_left -= batch_size_i
            else:
                # The last sampler absorbs the rounding remainder so the
                # per-sampler sizes always add up to batch_size.
                batch_size_i = batch_size_left
            assert batch_size_i <= len(dataset_list[i][2])
            config_i["batch_size"] = batch_size_i
            if sample_method == "DistributedBatchSampler":
                sampler_i = DistributedBatchSampler(dataset_list[i][2],
                                                    **config_i)
            else:
                sampler_i = getattr(dataloader, sample_method)(
                    dataset_list[i][2], **config_i)
            self.sampler_list.append(sampler_i)
            self.iter_list.append(iter(sampler_i))
            # NOTE(review): the length starts at iter_per_epoch and is then
            # increased per dataset; confirm this is the intended formula.
            self.length += len(dataset_list[i][2]) * ratio_i
        # A fractional ratio makes the accumulated length a float, which
        # __len__ is not allowed to return. Rounding up keeps the number of
        # batches produced by __iter__ unchanged.
        self.length = int(math.ceil(self.length))
        self.iter_counter = 0

    def __iter__(self):
        """Yield ``len(self)`` batches of ``MixDataset`` indices."""
        while self.iter_counter < self.length:
            batch = []
            for i, iter_i in enumerate(self.iter_list):
                batch_i = next(iter_i, None)
                if batch_i is None:
                    # Sub-sampler exhausted: restart it and keep going.
                    iter_i = iter(self.sampler_list[i])
                    self.iter_list[i] = iter_i
                    batch_i = next(iter_i, None)
                    assert batch_i is not None, "dataset {} return None".format(
                        i)
                # Shift sub-dataset indices into the MixDataset index space.
                batch += [idx + self.start_list[i] for idx in batch_i]
            if len(batch) == self.batch_size:
                self.iter_counter += 1
                yield batch
            else:
                # A sub-sampler returned a short batch; skip this round
                # without counting it.
                logger.info("Some dataset reaches end")
        self.iter_counter = 0

    def __len__(self):
        """Return the number of batches yielded per pass."""
        return self.length

ppcls/data/dataloader/pk_sampler.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from collections import defaultdict
18+
import numpy as np
19+
import random
20+
from paddle.io import DistributedBatchSampler
21+
22+
from ppcls.utils import logger
23+
24+
25+
class PKSampler(DistributedBatchSampler):
    """Sample P identities per batch and K instances for each of them.

    First, randomly sample P identities.
    Then for each identity randomly sample K instances.
    Therefore batch size is P*K, and the sampler called PKSampler.

    Args:
        dataset (paddle.io.Dataset): dataset exposing a ``labels`` attribute
            whose i-th entry is the label of sample i.
        batch_size (int): number of examples in a batch.
        sample_per_id (int): number of instances per identity in a batch;
            must divide ``batch_size``.
        shuffle (bool): whether to shuffle the label order before generating
            batch indices. Default True.
        drop_last (bool): forwarded to the base sampler; when True only
            batches of exactly ``batch_size`` indices are yielded.
            Default True.
        sample_method (str): ``"id_avg_prob"`` draws every identity with equal
            probability; ``"sample_avg_prob"`` draws an identity with
            probability proportional to its number of samples.

    Raises:
        ValueError: if ``sample_method`` is not one of the supported values.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 sample_per_id,
                 shuffle=True,
                 drop_last=True,
                 sample_method="sample_avg_prob"):
        super().__init__(
            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
        assert batch_size % sample_per_id == 0, \
            "PKSampler configs error, Sample_per_id must be a divisor of batch_size."
        assert hasattr(self.dataset,
                       "labels"), "Dataset must have labels attribute."
        self.sample_per_label = sample_per_id
        self.label_dict = defaultdict(list)
        self.sample_method = sample_method
        # Group sample indices by label.
        for idx, label in enumerate(self.dataset.labels):
            self.label_dict[label].append(idx)
        self.label_list = list(self.label_dict)
        # Labels are drawn without replacement, so there must be more labels
        # available than are needed for a single batch.
        assert len(self.label_list) * self.sample_per_label > self.batch_size, \
            "batch size should be smaller than the number of ids multiplied by sample_per_id"
        # prob_list[i] is the probability of drawing label_list[i]; the two
        # sequences must stay index-aligned.
        if self.sample_method == "id_avg_prob":
            self.prob_list = np.array([1 / len(self.label_list)] *
                                      len(self.label_list))
        elif self.sample_method == "sample_avg_prob":
            counter = [
                len(self.label_dict[label_i]) for label_i in self.label_list
            ]
            self.prob_list = np.array(counter) / sum(counter)
        else:
            message = (
                "PKSampler only support id_avg_prob and sample_avg_prob sample method, "
                "but receive {}.".format(self.sample_method))
            logger.error(message)
            # Without a probability list the sampler cannot work; fail here
            # instead of with an AttributeError further down.
            raise ValueError(message)
        # np.random.choice requires the probabilities to sum to 1; absorb
        # floating-point drift into the last entry.
        if abs(sum(self.prob_list) - 1) > 0.00000001:
            self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
            if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
                logger.error("PKSampler prob list error")
            else:
                logger.info(
                    "PKSampler: sum of prob list not equal to 1, change the last prob"
                )

    def __iter__(self):
        """Yield ``len(self)`` batches of dataset indices."""
        label_per_batch = self.batch_size // self.sample_per_label
        if self.shuffle:
            # Apply one permutation to both sequences so every label keeps
            # its own sampling probability.
            order = np.random.RandomState(self.epoch).permutation(
                len(self.label_list))
            self.label_list = [self.label_list[i] for i in order]
            self.prob_list = self.prob_list[order]
        for _ in range(len(self)):
            batch_index = []
            batch_label_list = np.random.choice(
                self.label_list,
                size=label_per_batch,
                replace=False,
                p=self.prob_list)
            for label_i in batch_label_list:
                label_i_indexes = self.label_dict[label_i]
                # Sample with replacement only when the identity has fewer
                # than K instances.
                need_replace = len(label_i_indexes) < self.sample_per_label
                batch_index.extend(
                    np.random.choice(
                        label_i_indexes,
                        size=self.sample_per_label,
                        replace=need_replace))
            if not self.drop_last or len(batch_index) == self.batch_size:
                yield batch_index

0 commit comments

Comments
 (0)