Commit 16f9ca5

add vit-b/16 finetune (#729) (#786)
1 parent f5568c3 commit 16f9ca5

File tree

3 files changed: +154 -2 lines changed

ppfleetx/configs/vis/vit/ViT_base_patch16_384_ft_in1k_2n16c_dp_fp16o2.yaml

+122
@@ -0,0 +1,122 @@
```yaml
_base_: ../base.yaml

Global:
  device: gpu
  seed: 2021

Engine:
  run_mode: 'epoch'
  num_train_epochs: 8
  eval_freq: 1
  accumulate_steps: 1
  logging_freq: 10
  mix_precision:
    use_pure_fp16: True
    scale_loss: 32768.0
    custom_black_list: ["reduce_sum", "elementwise_div"]
    custom_white_list: []
  save_load:
    save_epoch: 1
    output_dir: ./output
    ckpt_dir:

Distributed:
  dp_degree:

Model:
  module: "GeneralClsModule"
  model:
    name: "ViT_base_patch16_384"
    class_num: 1000
    drop_rate: 0.1
    pretrained:
      prefix_path: ./pretrained/vit/imagenet2012-ViT-B_16-224
      finetune: True
  loss:
    train:
      name: 'CELoss'
    eval:
      name: 'CELoss'
  metric:
    train:
      name: 'TopkAcc'
      topk: [1, 5]
    eval:
      name: 'TopkAcc'
      topk: [1, 5]

Optimizer:
  name: Momentum
  weight_decay: 0.0001
  momentum: 0.9
  lr:
    name: ViTLRScheduler
    learning_rate: 0.03
    decay_type: cosine
    warmup_steps: 500
  grad_clip:
    name: "ClipGradByGlobalNorm"
    clip_norm: 1.0

Data:
  Train:
    dataset:
      name: GeneralClsDataset
      image_root: ./dataset/ILSVRC2012/
      class_num: 1000
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 384
            scale: [0.05, 1.0]
            interpolation: bilinear
            backend: pil
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
            order: ''
        - ToCHWImage:

    sampler:
      name: DistributedBatchSampler
      batch_size: 32 # total batchsize 512
      drop_last: True
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: GeneralClsDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: 384
            interpolation: bilinear
            backend: pil
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
            order: ''
        - ToCHWImage:

    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 8
      use_shared_memory: True
```
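The config loads the 224-resolution pre-trained weights from `prefix_path: ./pretrained/vit/imagenet2012-ViT-B_16-224` before fine-tuning at 384. Below is a minimal sketch of staging that checkpoint, assuming the `.pdparams` URL listed in the model table of projects/vit/README.md and assuming the loader resolves `prefix_path` by appending the `.pdparams` suffix:

```shell
# Place the ImageNet2012 ViT-B/16 (224) pre-trained weights where the
# finetune config's prefix_path expects them. The URL comes from the
# model table in projects/vit/README.md; appending .pdparams to
# prefix_path is an assumption about how the checkpoint is resolved.
mkdir -p ./pretrained/vit
wget -O ./pretrained/vit/imagenet2012-ViT-B_16-224.pdparams \
    https://paddlefleetx.bj.bcebos.com/model/vision/vit/imagenet2012-ViT-B_16-224.pdparams
```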

projects/vit/README.md

+16 -2
@@ -3,7 +3,7 @@
This project implements the Vision Transformer proposed by Google in [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).

-## How to pretrain from scratch on imagenet 1k
+## How to pretrain from scratch on imagenet2012

### Go to the main repo directory
All commands are executed in the home directory.
@@ -36,6 +36,20 @@ Note: ViT-B/16 needs to run on 2 nodes with 16 A100 GPUs. If you only have a low-me
The following commands need to be run on each node.

```shell
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c ppfleetx/configs/vis/vit/ViT_base_patch16_224_pt_in1k_2n16c_dp_fp16o2.yaml
```

## How to finetune on imagenet2012
Fine-tuning is similar to pre-training on the ImageNet2012 dataset; we have provided a pre-configured yaml file.

```shell
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c ppfleetx/configs/vis/vit/ViT_base_patch16_384_ft_in1k_2n16c_dp_fp16o2.yaml
```

## Model

| Model | Phase | Size | Dataset | Resolution | GPUs | Img/sec | Top1 Acc | Pre-trained checkpoint | Fine-tuned checkpoint | Log |
|----------|----------|--------|--------------|------------|------------|---------|----------|------------------------|-----------------------|-----|
| ViT-B_16 | pretrain | 174MiB | ImageNet2012 | 224 | A100*N2C16 | 7350 | 74.55% | [download](https://paddlefleetx.bj.bcebos.com/model/vision/vit/imagenet2012-ViT-B_16-224.pdparams) | - | [log](https://paddlefleetx.bj.bcebos.com/model/vision/vit/imagenet2012-ViT-B_16-224.log) |
| ViT-B_16 | finetune | | ImageNet2012 | 384 | A100*N2C16 | 1363 | | [download](https://paddlefleetx.bj.bcebos.com/model/vision/vit/imagenet2012-ViT-B_16-224.pdparams) | | |
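The table reports throughput on 2 nodes with 16 A100 GPUs (A100*N2C16). A minimal sketch of a two-node launch, assuming placeholder node IPs and that the node list is passed to `paddle.distributed.launch` via `--ips`, with the same command executed on each node as noted above:

```shell
# Hypothetical node addresses; replace with the real IPs of the two nodes.
# Run this same command on every node.
python -m paddle.distributed.launch \
    --ips="192.168.0.1,192.168.0.2" \
    --gpus="0,1,2,3,4,5,6,7" \
    tools/train.py -c ppfleetx/configs/vis/vit/ViT_base_patch16_384_ft_in1k_2n16c_dp_fp16o2.yaml
```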

projects/vit/run_finetune.sh

+16
@@ -0,0 +1,16 @@
```shell
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c ppfleetx/configs/vis/vit/ViT_base_patch16_384_ft_in1k_2n16c_dp_fp16o2.yaml
```
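Since the config path in the script is relative, it is meant to be launched from the repository root (the README notes that all commands are executed in the home directory), and on a multi-node setup it is run on each node. A usage sketch:

```shell
# Run from the repository root on each node so the relative config path resolves.
bash projects/vit/run_finetune.sh
```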
