Commit 5d1ab55

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleClas into slim
2 parents 9f03478 + 4103640 commit 5d1ab55

100 files changed: 3314 additions, 2340 deletions

Lines changed: 18 additions & 0 deletions
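@@ -0,0 +1,18 @@
+---
+name: Issue report
+about: PaddleClas issue report
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+Welcome to PaddleClas, and thank you for reporting issues and contributing to the project!
+When opening an issue, please provide the following information so that we can locate and resolve the problem quickly:
+1. PaddleClas and PaddlePaddle versions: please give the version number or branch you are using, e.g. PaddleClas release/2.2 and PaddlePaddle 2.1.0
+2. Versions of other products involved: if you use other products alongside PaddleClas, such as PaddleServing or PaddleInference, please provide their version numbers
+3. Training environment:
+    a. Operating system, e.g. Linux/Windows/MacOS
+    b. Python version, e.g. Python 3.6/3.7/3.8
+    c. CUDA/cuDNN versions, e.g. CUDA 10.2/cuDNN 7.6.5
+4. The complete code (the parts that differ from the code in the repo), the detailed error message, and the relevant logs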
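Part of the information requested by this template can be gathered with a few lines of Python. The following is only a minimal sketch: the helper name `collect_environment` is hypothetical and not part of PaddleClas, and it assumes that PaddlePaddle, when installed, exposes its version as `paddle.__version__`. CUDA/cuDNN versions are not covered and would still need to be added by hand.

```python
import platform
import sys


def collect_environment():
    """Gather OS, Python, and PaddlePaddle versions for an issue report.

    Hypothetical helper for illustration; not part of PaddleClas.
    """
    info = {
        "os": platform.system(),
        "python": platform.python_version(),
    }
    try:
        # Optional: only available where PaddlePaddle is installed.
        import paddle
        info["paddlepaddle"] = paddle.__version__
    except ImportError:
        info["paddlepaddle"] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        sys.stdout.write("{}: {}\n".format(key, value))
```

The printed lines can be pasted into items 1 and 3 of the template.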

README_ch.md

Lines changed: 9 additions & 1 deletion
@@ -8,6 +8,7 @@
 
 **Recent updates**
 
+- 2021.08.11 Added 7 new [FAQ](docs/zh_CN/faq_series/faq_2021_s2.md) entries
 - 2021.06.29 Added the Swin-transformer model series, reaching up to 87.2% Top-1 accuracy on the ImageNet1k dataset; training, inference, evaluation, and whl package deployment are supported, and pretrained models can be downloaded from [here](docs/zh_CN/models/models_intro.md).
 - 2021.06.22,23,24 The official PaddleClas R&D team presented a three-day live course of in-depth technical talks. Course replay: [https://aistudio.baidu.com/aistudio/course/introduce/24519](https://aistudio.baidu.com/aistudio/course/introduce/24519)
 - 2021.06.16 PaddleClas upgraded to v2.2, integrating Metric learning, vector search, and other components. Added 4 image recognition applications: product recognition, anime character recognition, vehicle recognition, and logo recognition. Added 30 pretrained models from the LeViT, Twins, TNT, DLA, HarDNet, and RedNet series.
@@ -50,6 +51,10 @@ The Res2Net200_vd pretrained model reaches a Top-1 accuracy of 85.1%.
 - [Image recognition quick start](./docs/zh_CN/tutorials/quick_start_recognition.md)
 - [Introduction to the image recognition system](#图像识别系统介绍)
 - [Recognition results showcase](#识别效果展示)
+- Image classification quick start
+- [New user edition](./docs/zh_CN/tutorials/quick_start_new_user.md)
+- [Advanced edition](./docs/zh_CN/tutorials/quick_start_professional.md)
+- [Community edition](./docs/zh_CN/tutorials/quick_start_community.md)
 - Algorithm introduction
 - [Backbone networks and pretrained model zoo](./docs/zh_CN/ImageNet_models_cn.md)
 - [Mainbody detection](./docs/zh_CN/application/mainbody_detection.md)
@@ -74,11 +79,14 @@ The Res2Net200_vd pretrained model reaches a Top-1 accuracy of 85.1%.
 - [Knowledge distillation](./docs/zh_CN/advanced_tutorials/distillation/distillation.md)
 - [Model quantization](./docs/zh_CN/extension/paddle_quantization.md)
 - [Data augmentation](./docs/zh_CN/advanced_tutorials/image_augmentation/ImageAugment.md)
-- FAQ (updates suspended)
+- FAQ
+- [Image recognition FAQ](docs/zh_CN/faq_series/faq_2021_s2.md)
 - [Image classification FAQ](docs/zh_CN/faq.md)
 - [License](#许可证书)
 - [Contributing](#贡献代码)
 
+<a name="图像识别系统介绍"></a>
+## Introduction to the image recognition system
 
 <div align="center">
 <img src="./docs/images/structure.png" width = "400" />

deploy/configs/build_cartoon.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 Global:
   rec_inference_model_dir: "./models/cartoon_rec_ResNet50_iCartoon_v1.0_infer/"
-  batch_size: 1
+  batch_size: 32
   use_gpu: True
   enable_mkldnn: True
   cpu_num_threads: 10

deploy/configs/build_logo.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 Global:
   rec_inference_model_dir: "./models/logo_rec_ResNet50_Logo3K_v1.0_infer/"
-  batch_size: 1
+  batch_size: 32
   use_gpu: True
   enable_mkldnn: True
   cpu_num_threads: 10

deploy/configs/build_product.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 Global:
   rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer"
-  batch_size: 1
+  batch_size: 32
   use_gpu: True
   enable_mkldnn: True
   cpu_num_threads: 10

deploy/configs/build_vehicle.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 Global:
   rec_inference_model_dir: "./models/vehicle_cls_ResNet50_CompCars_v1.0_infer/"
-  batch_size: 1
+  batch_size: 32
   use_gpu: True
   enable_mkldnn: True
   cpu_num_threads: 10

deploy/python/predict_cls.py

Lines changed: 67 additions & 6 deletions
@@ -41,6 +41,29 @@ def __init__(self, config):
         if "PostProcess" in config:
             self.postprocess = build_postprocess(config["PostProcess"])
 
+        # for whole_chain project to test each repo of paddle
+        self.benchmark = config["Global"].get("benchmark", False)
+        if self.benchmark:
+            import auto_log
+            import os
+            pid = os.getpid()
+            self.auto_logger = auto_log.AutoLogger(
+                model_name=config["Global"].get("model_name", "cls"),
+                model_precision='fp16'
+                if config["Global"]["use_fp16"] else 'fp32',
+                batch_size=config["Global"].get("batch_size", 1),
+                data_shape=[3, 224, 224],
+                save_path=config["Global"].get("save_log_path",
+                                               "./auto_log.log"),
+                inference_config=self.config,
+                pids=pid,
+                process_name=None,
+                gpu_ids=None,
+                time_keys=[
+                    'preprocess_time', 'inference_time', 'postprocess_time'
+                ],
+                warmup=2)
+
     def predict(self, images):
         input_names = self.paddle_predictor.get_input_names()
         input_tensor = self.paddle_predictor.get_input_handle(input_names[0])
@@ -49,29 +72,67 @@ def predict(self, images):
         output_tensor = self.paddle_predictor.get_output_handle(output_names[
             0])
 
+        if self.benchmark:
+            self.auto_logger.times.start()
         if not isinstance(images, (list, )):
             images = [images]
         for idx in range(len(images)):
             for ops in self.preprocess_ops:
                 images[idx] = ops(images[idx])
         image = np.array(images)
+        if self.benchmark:
+            self.auto_logger.times.stamp()
 
         input_tensor.copy_from_cpu(image)
         self.paddle_predictor.run()
         batch_output = output_tensor.copy_to_cpu()
+        if self.benchmark:
+            self.auto_logger.times.stamp()
+        if self.postprocess is not None:
+            batch_output = self.postprocess(batch_output)
+        if self.benchmark:
+            self.auto_logger.times.end(stamp=True)
         return batch_output
 
 
 def main(config):
     cls_predictor = ClsPredictor(config)
     image_list = get_image_list(config["Global"]["infer_imgs"])
 
-    assert config["Global"]["batch_size"] == 1
-    for idx, image_file in enumerate(image_list):
-        img = cv2.imread(image_file)[:, :, ::-1]
-        output = cls_predictor.predict(img)
-        output = cls_predictor.postprocess(output, [image_file])
-        print(output)
+    batch_imgs = []
+    batch_names = []
+    cnt = 0
+    for idx, img_path in enumerate(image_list):
+        img = cv2.imread(img_path)
+        if img is None:
+            logger.warning(
+                "Image file failed to read and has been skipped. The path: {}".
+                format(img_path))
+        else:
+            img = img[:, :, ::-1]
+            batch_imgs.append(img)
+            img_name = os.path.basename(img_path)
+            batch_names.append(img_name)
+            cnt += 1
+
+        if cnt % config["Global"]["batch_size"] == 0 or (idx + 1
+                                                         ) == len(image_list):
+            if len(batch_imgs) == 0:
+                continue
+
+            batch_results = cls_predictor.predict(batch_imgs)
+            for number, result_dict in enumerate(batch_results):
+                filename = batch_names[number]
+                clas_ids = result_dict["class_ids"]
+                scores_str = "[{}]".format(", ".join("{:.2f}".format(
+                    r) for r in result_dict["scores"]))
+                label_names = result_dict["label_names"]
+                print("{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
+                      format(filename, clas_ids, scores_str, label_names))
+            batch_imgs = []
+            batch_names = []
+    if cls_predictor.benchmark:
+        cls_predictor.auto_logger.report()
     return
 
 
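The new `main` loop in this hunk replaces the old `batch_size == 1` assertion with batching: readable images accumulate until a batch is full, unreadable ones are skipped with a warning, and the last partial batch is flushed at the end of the list. The following is a simplified restatement of that pattern as a generator, not the code from the commit. The names `iter_batches` and `load` are hypothetical, and `load` stands in for `cv2.imread`, returning `None` on failure.

```python
def iter_batches(paths, batch_size, load):
    """Yield (names, items) batches, skipping inputs that fail to load.

    `load` is a stand-in for an image reader that returns None on failure.
    """
    names, items = [], []
    for path in paths:
        item = load(path)
        if item is None:
            # As in the diff, an unreadable input is skipped rather than fatal.
            continue
        names.append(path)
        items.append(item)
        if len(items) == batch_size:
            yield names, items
            names, items = [], []
    if items:
        # Flush the final batch, which may be smaller than batch_size.
        yield names, items


if __name__ == "__main__":
    # Fake "files": b.jpg simulates an image that cannot be read.
    fake_files = {"a.jpg": 1, "b.jpg": None, "c.jpg": 3, "d.jpg": 4}
    for names, items in iter_batches(list(fake_files), 2, fake_files.get):
        print(names, items)
```

Checking the batch length directly, instead of a running counter combined with an end-of-list test, removes the need for the empty-batch guard that the committed loop carries.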

deploy/python/predict_rec.py

Lines changed: 31 additions & 10 deletions
@@ -54,27 +54,48 @@ def predict(self, images, feature_normalize=True):
         input_tensor.copy_from_cpu(image)
         self.paddle_predictor.run()
         batch_output = output_tensor.copy_to_cpu()
-
+
         if feature_normalize:
             feas_norm = np.sqrt(
                 np.sum(np.square(batch_output), axis=1, keepdims=True))
             batch_output = np.divide(batch_output, feas_norm)
-
+
+        if self.postprocess is not None:
+            batch_output = self.postprocess(batch_output)
         return batch_output
 
 
 def main(config):
     rec_predictor = RecPredictor(config)
     image_list = get_image_list(config["Global"]["infer_imgs"])
 
-    assert config["Global"]["batch_size"] == 1
-    for idx, image_file in enumerate(image_list):
-        batch_input = []
-        img = cv2.imread(image_file)[:, :, ::-1]
-        output = rec_predictor.predict(img)
-        if rec_predictor.postprocess is not None:
-            output = rec_predictor.postprocess(output)
-        print(output)
+    batch_imgs = []
+    batch_names = []
+    cnt = 0
+    for idx, img_path in enumerate(image_list):
+        img = cv2.imread(img_path)
+        if img is None:
+            logger.warning(
+                "Image file failed to read and has been skipped. The path: {}".
+                format(img_path))
+        else:
+            img = img[:, :, ::-1]
+            batch_imgs.append(img)
+            img_name = os.path.basename(img_path)
+            batch_names.append(img_name)
+            cnt += 1
+
+        if cnt % config["Global"]["batch_size"] == 0 or (idx + 1) == len(image_list):
+            if len(batch_imgs) == 0:
+                continue
+
+            batch_results = rec_predictor.predict(batch_imgs)
+            for number, result_dict in enumerate(batch_results):
+                filename = batch_names[number]
+                print("{}:\t{}".format(filename, result_dict))
+            batch_imgs = []
+            batch_names = []
+
     return
 
 
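The `feature_normalize` branch kept as context in this hunk divides each feature vector by its L2 norm. A small standalone sketch of the same arithmetic is given below; the function name `l2_normalize` is hypothetical and the inputs are made-up values, chosen only to show that normalized rows have unit length, so their inner products are cosine similarities.

```python
import numpy as np


def l2_normalize(features):
    """Scale each row of a (batch, dim) array to unit L2 norm.

    Same arithmetic as the feature_normalize branch in the diff.
    """
    norms = np.sqrt(np.sum(np.square(features), axis=1, keepdims=True))
    return np.divide(features, norms)


if __name__ == "__main__":
    feats = np.array([[3.0, 4.0], [1.0, 0.0]], dtype=np.float32)
    unit = l2_normalize(feats)
    print(unit)           # each row now has length 1
    print(unit @ unit.T)  # inner products between rows are cosine similarities
```

Note that an all-zero feature row would produce a division by zero here; neither this sketch nor the hunk above guards against that case.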

deploy/utils/predictor.py

Lines changed: 4 additions & 3 deletions
@@ -28,7 +28,7 @@ def __init__(self, args, inference_model_dir=None):
         if args.use_fp16 is True:
             assert args.use_tensorrt is True
         self.args = args
-        self.paddle_predictor = self.create_paddle_predictor(
+        self.paddle_predictor, self.config = self.create_paddle_predictor(
             args, inference_model_dir)
 
     def predict(self, image):
@@ -59,11 +59,12 @@ def create_paddle_predictor(self, args, inference_model_dir=None):
             config.enable_tensorrt_engine(
                 precision_mode=Config.Precision.Half
                 if args.use_fp16 else Config.Precision.Float32,
-                max_batch_size=args.batch_size)
+                max_batch_size=args.batch_size,
+                min_subgraph_size=30)
 
         config.enable_memory_optim()
         # use zero copy
         config.switch_use_feed_fetch_ops(False)
         predictor = create_predictor(config)
 
-        return predictor
+        return predictor, config

deploy/vector_search/README_en.md

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+# Vector search
+
+## 1. Introduction
+
+Some vertical-domain recognition tasks (e.g., vehicles, commodities) involve a large number of categories and often use a retrieval-based approach: predicted categories are obtained by a fast nearest neighbor search between query vectors and gallery vectors. The vector search module provides a basic approximate nearest neighbor search algorithm based on Baidu's self-developed Möbius algorithm, a graph-based approximate nearest neighbor search algorithm for maximum inner product search (MIPS). This module provides a Python interface, supports numpy and tensor type vectors, and supports L2 and Inner Product distance calculation.
+
+Details of the Möbius algorithm can be found in the paper ([Möbius Transformation for Fast Inner Product Search on Graph](http://research.baidu.com/Public/uploads/5e189d36b5cf6.PDF), [Code](https://github.com/sunbelbd/mobius)).
+
+## 2. Installation
+
+### 2.1 Use the provided library files directly
+
+This folder contains the compiled `index.so` (compiled under gcc8.2.0 for Linux) and `index.dll` (compiled under gcc10.3.0 for Windows), which can be used directly, skipping sections 2.2 and 2.3.
+
+If the provided library files cannot be used because of an old gcc version or an incompatible environment, you need to compile them manually on your own platform.
+
+**Note:** Make sure that the C++ compiler supports the C++11 standard.
+
+### 2.2 Compile and generate library files on Linux
+
+Run the following commands to install gcc and g++.
+
+```
+sudo apt-get update
+sudo apt-get upgrade -y
+sudo apt-get install build-essential gcc g++
+```
+
+Check the gcc version with the command `gcc -v`.
+
+You can then run `make` directly. To regenerate `index.so`, first run `make clean` to clear the cache, then run `make` to generate the updated library file.
+
+### 2.3 Compile and generate library files on Windows
+
+You need to install the gcc compiler tools first. We recommend [TDM-GCC](https://jmeubank.github.io/tdm-gcc/articles/2020-03/9.2.0-release); you can choose a suitable version on the official website. We recommend downloading [tdm64-gcc-10.3.0-2.exe](https://github.com/jmeubank/tdm-gcc/releases/download/v10.3.0-tdm64-2/tdm64-gcc-10.3.0-2.exe).
+
+After downloading, follow the default installation steps. There are 3 points to note:
+
+1. The vector search module depends on openmp, so you need to check the `openmp` option at the `choose components` step; otherwise the build reports the error `libgomp.spec: No such file or directory` ([reference link](https://github.com/dmlc/xgboost/issues/1027)).
+2. When asked whether to add gcc to the system environment variables, it is recommended to check this option; otherwise you need to add it to the system environment variables manually later.
+3. The compile command is `make` on Linux and `mingw32-make` on Windows, so take care to use the right one.
+
+After installation, you can open a command line terminal and check the gcc version with the command `gcc -v`.
+
+Run the command `mingw32-make` in the folder (deploy/vector_search) to generate the `index.dll` library file. To regenerate `index.dll`, first run `mingw32-make clean` to clear the cache, then run `mingw32-make` to generate the updated library file.
+
+### 2.4 Compile and generate library files on MacOS
+
+Run the following command to install gcc and g++:
+
+```
+brew install gcc
+```
+
+#### Caution:
+
+1. If prompted with `Error: Running Homebrew as root is extremely dangerous and no longer supported... `, refer to this [link](https://jingyan.baidu.com/article/e52e3615057a2840c60c519c.html).
+2. If prompted with `Error: Failure while executing; tar --extract --no-same-owner --file... `, refer to this [link](https://blog.csdn.net/Dawn510/article/details/117787358).
+
+After installation, the compiled executables are placed under /usr/local/bin. List the gcc executables in this folder:
+
+```
+ls /usr/local/bin/gcc*
+```
+
+If the local gcc version is gcc-11, the compile command is as follows (for gcc-9, the corresponding command would be `CXX=g++-9 make`):
+
+```
+CXX=g++-11 make
+```
+
+## 3. Quick use
+
+```
+import numpy as np
+from interface import Graph_Index
+
+# Random sample generation
+index_vectors = np.random.rand(100000,128).astype(np.float32)
+query_vector = np.random.rand(128).astype(np.float32)
+index_docs = ["ID_"+str(i) for i in range(100000)]
+
+# Initialize index structure
+indexer = Graph_Index(dist_type="IP") #support "IP" and "L2"
+indexer.build(gallery_vectors=index_vectors, gallery_docs=index_docs, pq_size=100, index_path='test')
+
+# Query
+scores, docs = indexer.search(query=query_vector, return_k=10, search_budget=100)
+print(scores)
+print(docs)
+
+# Save and load
+indexer.dump(index_path="test")
+indexer.load(index_path="test")
+```
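For context on what the approximate index in the quick-use example is approximating, exact maximum inner product search can be written in a few lines of numpy. This is only a brute-force baseline sketch, not the Möbius algorithm and not part of the `interface` module; the function name `exact_topk_inner_product` is hypothetical. It scores every gallery vector, so it is practical only for small galleries or for spot-checking the recall of an approximate index.

```python
import numpy as np


def exact_topk_inner_product(gallery, query, k):
    """Return (scores, indices) of the k gallery rows with the largest inner product.

    Brute-force baseline: computes one score per gallery row.
    """
    scores = gallery @ query
    top = np.argsort(-scores)[:k]  # indices sorted by descending score
    return scores[top], top


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gallery = rng.random((1000, 128), dtype=np.float32)
    query = rng.random(128, dtype=np.float32)
    scores, ids = exact_topk_inner_product(gallery, query, k=10)
    print(scores)
    print(ids)
```

Comparing the indices returned here with the documents returned by an approximate search on the same data gives a rough recall estimate for a chosen `search_budget`.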
