Commit 58c9d17

ViT-L/16 finetune code
* Add 0.0001 weight decay for Momentum optimizer
* Top1 Acc 85.03% based on jax checkpoint
1 parent 8c48b20 commit 58c9d17

11 files changed: +596 -116 lines

plsc/configs/VisionTransformer/ViT_large_patch16_224_in22k_4n32c_dp_fp16o2.yaml

Lines changed: 6 additions & 11 deletions
@@ -27,9 +27,8 @@ DistributedStrategy:
 # model architecture
 Model:
   name: ViT_large_patch16_224
-  class_num: 10450
+  class_num: 21841
   drop_rate: 0.1
-  representation_size: 768
 
 # loss function config for traing/eval process
 Loss:
@@ -53,20 +52,16 @@ Optimizer:
   epsilon: 1e-8
   weight_decay: 0.03
   exp_avg_force_fp32: True
-  grad_clip:
-    name: ClipGradByGlobalNorm
-    clip_norm: 1.0
 
 
 # data loader for train and eval
 DataLoader:
   Train:
     dataset:
       name: ImageNetDataset
-      image_root: ./dataset/ImageNet21K/
-      multi_label: True
-      class_num: 10450
-      cls_label_path: ./dataset/ImageNet21K/multi_label_train_list.txt
+      image_root: ./dataset/ImageNet22K/
+      class_num: 21841
+      cls_label_path: ./dataset/ImageNet22K/train_list.txt
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -97,8 +92,8 @@ DataLoader:
   Eval:
     dataset:
       name: ImageNetDataset
-      image_root: ./dataset/ImageNet21K/
-      cls_label_path: ./dataset/ImageNet21K/val_list.txt
+      image_root: ./dataset/ImageNet22K/
+      cls_label_path: ./dataset/ImageNet22K/val_list.txt
       transform_ops:
         - DecodeImage:
             to_rgb: True
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# global configs
Global:
  checkpoint: null
  finetune: True
  pretrained_model: ./pretrained/vit_jax/imagenet21k-ViT-L_16
  output_dir: ./output/
  device: gpu
  save_interval: 1
  max_num_latest_checkpoint: 0
  eval_during_train: True
  eval_interval: 1
  eval_unit: "epoch"
  accum_steps: 8
  epochs: 8
  print_batch_step: 10
  use_visualdl: False
  seed: 2021

# FP16 setting
FP16:
  level: O2
  GradScaler:
    init_loss_scaling: 65536.0

DistributedStrategy:
  data_parallel: True

# model architecture
Model:
  name: ViT_large_patch16_384
  class_num: 1000
  drop_rate: 0.1

# loss function config for traing/eval process
Loss:
  Train:
    - ViTCELoss:
        type: softmax
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0

LRScheduler:
  name: ViTLRScheduler
  learning_rate: 0.03
  decay_type: cosine
  warmup_steps: 500

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 0.0001
  grad_clip:
    name: ClipGradByGlobalNorm
    clip_norm: 1.0


# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      class_num: 1000
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 384
            scale: [0.05, 1.0]
            interpolation: bilinear
            backend: pil
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
            order: ''
        - ToCHWImage:

    sampler:
      name: DistributedBatchSampler
      batch_size: 64 # total batchsize 512
      drop_last: True
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: 384
            interpolation: bilinear
            backend: pil
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
            order: ''
        - ToCHWImage:

    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 8
      use_shared_memory: True

Metric:
  Eval:
    - TopkAcc:
        topk: [1, 5]

Export:
  export_type: paddle
  input_shape: [None, 3, 224, 224]
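
Both new fine-tune configs drive training with ViTLRScheduler (learning_rate: 0.03, decay_type: cosine, warmup_steps: 500). As a rough illustration of that shape, here is a minimal sketch of a linear-warmup-then-cosine-decay schedule with those numbers plugged in; treating this as exactly what ViTLRScheduler computes is an assumption, not something stated in the diff.

    import math

    def warmup_cosine_lr(step, total_steps, base_lr=0.03, warmup_steps=500):
        """Illustrative only: linear warmup to base_lr, then cosine decay to zero."""
        if step < warmup_steps:
            return base_lr * step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

    # peak LR is reached at step 500; LR is back near zero at the last step
    print(warmup_cosine_lr(500, 20000), warmup_cosine_lr(20000, 20000))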
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# global configs
Global:
  checkpoint: null
  finetune: True
  pretrained_model: ./pretrained/vit_jax/imagenet21k-ViT-L_16
  output_dir: ./output/
  device: gpu
  save_interval: 1
  max_num_latest_checkpoint: 0
  eval_during_train: True
  eval_interval: 1
  eval_unit: "epoch"
  accum_steps: 1
  epochs: 8
  print_batch_step: 10
  use_visualdl: False
  seed: 2021

# FP16 setting
FP16:
  level: O2
  GradScaler:
    init_loss_scaling: 65536.0

DistributedStrategy:
  data_parallel: True

# model architecture
Model:
  name: ViT_large_patch16_384
  class_num: 1000
  drop_rate: 0.1

# loss function config for traing/eval process
Loss:
  Train:
    - ViTCELoss:
        type: softmax
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0

LRScheduler:
  name: ViTLRScheduler
  learning_rate: 0.03
  decay_type: cosine
  warmup_steps: 500

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 0.0001
  grad_clip:
    name: ClipGradByGlobalNorm
    clip_norm: 1.0

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      class_num: 1000
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 384
            scale: [0.05, 1.0]
            interpolation: bilinear
            backend: pil
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
            order: ''
        - ToCHWImage:

    sampler:
      name: DistributedBatchSampler
      batch_size: 16 # total batchsize 512
      drop_last: True
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: 384
            interpolation: bilinear
            backend: pil
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
            order: ''
        - ToCHWImage:

    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 8
      use_shared_memory: True

Metric:
  Eval:
    - TopkAcc:
        topk: [1, 5]

Export:
  export_type: paddle
  input_shape: [None, 3, 224, 224]
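
The two new configs are the same recipe at two hardware scales: the first reaches the sampler comment's 512 total batch through gradient accumulation (batch_size: 64 with accum_steps: 8), while the second does it with more cards and no accumulation (batch_size: 16 with accum_steps: 1). A minimal sketch of that arithmetic follows; the card counts (1 and 32, the latter matching the existing 4n32c naming) are inferred, not stated in the configs.

    def effective_batch(per_card, cards, accum_steps):
        # global batch per optimizer update = per-card batch * cards * accumulation steps
        return per_card * cards * accum_steps

    assert effective_batch(64, 1, 8) == 512   # first config, assumed single card
    assert effective_batch(16, 32, 1) == 512  # second config, assumed 4 nodes x 8 cards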

plsc/core/grad_clip.py

Lines changed: 54 additions & 0 deletions
@@ -15,6 +15,7 @@
 import warnings
 import paddle
 from paddle import _legacy_C_ops as _C_ops
+from plsc.utils import logger


 def _squared_l2_norm(x):
@@ -94,3 +95,56 @@ def __call__(self, params):
                         'Y': clip_coef},
                 outputs={'Out': param.grad},
                 attrs={'axis': -1})
+
+
+@paddle.no_grad()
+def clip_grad_norm_(parameters,
+                    max_norm: float,
+                    norm_type: float=2.0,
+                    error_if_nonfinite: bool=False):
+    r"""Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of the gradients from :attr:`parameters` is ``nan``,
+            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+
+    Returns:
+        Total norm of the parameters (viewed as a single vector).
+    """
+    if isinstance(parameters, paddle.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return paddle.to_tensor([0.])
+
+    total_norm = paddle.norm(
+        paddle.stack([paddle.norm(p.grad, norm_type) for p in parameters]),
+        norm_type)
+    if error_if_nonfinite and paddle.logical_or(total_norm.isnan(),
+                                                total_norm.isinf()):
+        raise RuntimeError(
+            f'The total norm of order {norm_type} for gradients from '
+            '`parameters` is non-finite, so it cannot be clipped. To disable '
+            'this error and scale the gradients by the non-finite norm anyway, '
+            'set `error_if_nonfinite=False`')
+    clip_coef = max_norm / (total_norm + 1e-6)
+    clip_coef_clamped = paddle.clip(clip_coef, max=1.0)
+    for p in parameters:
+        paddle.fluid.framework._dygraph_tracer().trace_op(
+            type="elementwise_mul",
+            inputs={'X': p.grad,
+                    'Y': clip_coef_clamped},
+            outputs={'Out': p.grad},
+            attrs={'axis': -1})
+    return total_norm
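
The new clip_grad_norm_ helper mirrors its PyTorch namesake: it computes one global norm over all gradients and rescales them in place. Below is a minimal usage sketch for a dynamic-graph training step, assuming the helper is meant to be called manually between backward() and step(); the tiny Linear model, random data, and hyperparameters are placeholders, not part of the commit.

    import paddle
    from plsc.core.grad_clip import clip_grad_norm_  # helper added in this commit

    model = paddle.nn.Linear(16, 2)                  # placeholder model
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.03, momentum=0.9, weight_decay=0.0001,
        parameters=model.parameters())

    for _ in range(3):
        loss = model(paddle.randn([8, 16])).mean()   # placeholder batch
        loss.backward()
        # clip the global gradient norm in place (cf. clip_norm: 1.0 in the configs)
        total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.clear_grad()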
