Skip to content

Commit 352c363

Browse files
committed
ViT-B/16 finetune
Top1 Acc: 0.7805, while ViT paper was 0.7791 There are diff: * lr from 0.03 to 0.003 * add 0.0001 weight_decay * global gradient clip from 1.0 to 0.5
1 parent 58c9d17 commit 352c363

6 files changed

+134
-6
lines changed

plsc/configs/VisionTransformer/ViT_base_patch16_224_in1k_4n32c_dp_fp16o1.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ Model:
2929
name: ViT_base_patch16_224
3030
class_num: 1000
3131
drop_rate: 0.1
32-
representation_size: 768
3332

3433
# loss function config for traing/eval process
3534
Loss:

plsc/configs/VisionTransformer/ViT_base_patch16_224_in1k_4n32c_dp_fp16o2.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ Model:
2929
name: ViT_base_patch16_224
3030
class_num: 1000
3131
drop_rate: 0.1
32-
representation_size: 768
3332

3433
# loss function config for traing/eval process
3534
Loss:

plsc/configs/VisionTransformer/ViT_base_patch16_224_in1k_4n32c_so2_fp16o2.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ Model:
2727
name: ViT_base_patch16_224
2828
class_num: 1000
2929
drop_rate: 0.1
30-
representation_size: 768
3130

3231
# loss function config for traing/eval process
3332
Loss:
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# global configs
2+
Global:
3+
checkpoint: null
4+
finetune: True
5+
pretrained_model: ./pretrained/ViT_base_patch16_224/best_model
6+
output_dir: ./output/
7+
device: gpu
8+
save_interval: 1
9+
max_num_latest_checkpoint: 0
10+
eval_during_train: True
11+
eval_interval: 1
12+
eval_unit: "epoch"
13+
accum_steps: 1
14+
epochs: 8
15+
print_batch_step: 10
16+
use_visualdl: False
17+
seed: 2021
18+
19+
# FP16 setting
20+
FP16:
21+
level: O2
22+
GradScaler:
23+
init_loss_scaling: 65536.0
24+
25+
DistributedStrategy:
26+
data_parallel: True
27+
28+
# model architecture
29+
Model:
30+
name: ViT_base_patch16_384
31+
class_num: 1000
32+
drop_rate: 0.1
33+
34+
# loss function config for traing/eval process
35+
Loss:
36+
Train:
37+
- ViTCELoss:
38+
type: softmax
39+
weight: 1.0
40+
Eval:
41+
- CELoss:
42+
weight: 1.0
43+
44+
LRScheduler:
45+
name: ViTLRScheduler
46+
learning_rate: 0.003 # Note(GuoxiaWang): Affect the key role of accuracy.
47+
decay_type: cosine
48+
warmup_steps: 500
49+
50+
Optimizer:
51+
name: Momentum
52+
weight_decay: 0.0001 # Note(GuoxiaWang): Affect the key role of accuracy.
53+
momentum: 0.9
54+
grad_clip:
55+
name: ClipGradByGlobalNorm
56+
clip_norm: 0.5 # Note(GuoxiaWang): Affect the key role of accuracy.
57+
58+
# data loader for train and eval
59+
DataLoader:
60+
Train:
61+
dataset:
62+
name: ImageNetDataset
63+
image_root: ./dataset/ILSVRC2012/
64+
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
65+
transform_ops:
66+
- DecodeImage:
67+
to_rgb: True
68+
channel_first: False
69+
- RandCropImage:
70+
size: 384
71+
scale: [0.05, 1.0]
72+
interpolation: bilinear
73+
backend: pil
74+
- RandFlipImage:
75+
flip_code: 1
76+
- NormalizeImage:
77+
scale: 1.0/255.0
78+
mean: [0.5, 0.5, 0.5]
79+
std: [0.5, 0.5, 0.5]
80+
order: ''
81+
- ToCHWImage:
82+
83+
sampler:
84+
name: DistributedBatchSampler
85+
batch_size: 16 # total batchsize 512
86+
drop_last: True
87+
shuffle: True
88+
loader:
89+
num_workers: 8
90+
use_shared_memory: True
91+
92+
Eval:
93+
dataset:
94+
name: ImageNetDataset
95+
image_root: ./dataset/ILSVRC2012/
96+
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
97+
transform_ops:
98+
- DecodeImage:
99+
to_rgb: True
100+
channel_first: False
101+
- ResizeImage:
102+
size: 384
103+
interpolation: bilinear
104+
backend: pil
105+
- NormalizeImage:
106+
scale: 1.0/255.0
107+
mean: [0.5, 0.5, 0.5]
108+
std: [0.5, 0.5, 0.5]
109+
order: ''
110+
- ToCHWImage:
111+
112+
sampler:
113+
name: DistributedBatchSampler
114+
batch_size: 256
115+
drop_last: False
116+
shuffle: False
117+
loader:
118+
num_workers: 8
119+
use_shared_memory: True
120+
121+
Metric:
122+
Train:
123+
- TopkAcc:
124+
topk: [1, 5]
125+
Eval:
126+
- TopkAcc:
127+
topk: [1, 5]
128+
129+
Export:
130+
export_type: paddle
131+
input_shape: [None, 3, 384, 384]

plsc/configs/VisionTransformer/ViT_large_patch16_384_in1k_ft_4n32c_dp_fp16o2.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Global:
33
checkpoint: null
44
finetune: True
5-
pretrained_model: ./pretrained/vit_jax/imagenet21k-ViT-L_16
5+
pretrained_model: ./pretrained/ViT_large_patch16_224/best_model
66
output_dir: ./output/
77
device: gpu
88
save_interval: 1
@@ -126,4 +126,4 @@ Metric:
126126

127127
Export:
128128
export_type: paddle
129-
input_shape: [None, 3, 224, 224]
129+
input_shape: [None, 3, 384, 384]

plsc/data/dataset/common_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __getitem__(self, idx):
5858
cls_idx = [int(e) for e in self.labels[idx].split(',')]
5959
for idx in cls_idx:
6060
one_hot[idx] = 1.0
61-
return (img, onehot)
61+
return (img, one_hot)
6262
else:
6363
return (img, np.int32(self.labels[idx]))
6464

0 commit comments

Comments
 (0)