Skip to content

Commit eed7176

Browse files
committed
refine code
1 parent 261cd9d commit eed7176

File tree

6 files changed

+203
-615
lines changed

6 files changed

+203
-615
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# MiniGPT4 推理加速
2+
3+
本项目提供了基于 MiniGPT4 的推理加速功能,基本的解决思路是将 MiniGPT4 动态图转为静态图,然后基于 PaddleInference 库进行推理加速。
4+
5+
下图展示了 MiniGPT4 的整体模型结构, 可以看到整体上,MiniGPT4的主要部分由 VIT, QFormer 和 Vicuna 模型组成,其中 Vicuna 模型是基于 Llama 训练的,在代码实现中调用的也是Llama代码,为方便描述,忽略不必要的分歧,所以在后续中将语言模型这部分默认描述为Llama。
6+
7+
在本方案中,我们将MiniGPT4 导出为两个子图:VIT 和 QFormer部分导出为一个静态子图, Llama 部分导出为一个子图。后续会结合这两个子图统一做 MiniGPT4 的推理功能。
8+
9+
<center><img src="https://github.com/PaddlePaddle/Paddle/assets/35913314/f0306cb6-4837-4f52-8f57-a0e7e35238f6" /></center>
10+
11+
12+
13+
14+
## 1. 环境准备
15+
### 1.1 基础环境准备:
16+
本项目在以下基础环境进行了验证:
17+
- CUDA: 11.7
18+
- python: 3.11
19+
- paddle: develop版
20+
21+
其中CUDA版本需要>=11.2, 具体Paddle版本可以点击[这里](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/develop/install/pip/linux-pip.html)按需下载。
22+
23+
24+
### 1.2 安装项目库
25+
1. 本项目需要用到 PaddleMIX 和 PaddleNLP 两个库,并且需要下载最新的 develop 版本:
26+
27+
```shell
28+
git clone https://github.com/PaddlePaddle/PaddleNLP.git
29+
git clone https://github.com/PaddlePaddle/PaddleMIX.git
30+
```
31+
32+
2. 安装paddlenlp_ops:
33+
```shell
34+
cd PaddleNLP/csrc
35+
python setup_cuda.py install
36+
```
37+
38+
3. 最后设置相应的环境变量:
39+
```shell
40+
export PYTHONPATH= yourpath/PaddleNLP:yourpath/PaddleMIX
41+
```
42+
43+
### 1.3 特别说明
44+
目前需要修复PaddleNLP和Paddle的部分代码,从而进行MiniGPT4推理加速。这部分功能后续逐步会逐步完善到PaddleNLP和Paddle,但目前如果想使用的话需要手动修改一下。
45+
1. 修改PaddleNLP代码:
46+
参考该[分支代码](https://github.com/1649759610/PaddleNLP/tree/bugfix_minigpt4),依次替换以下文件:
47+
- PaddleNLP/paddlenlp/experimental/transformers/generation_utils.py
48+
- PaddleNLP/paddlenlp/experimental/transformers/llama/modeling.py
49+
- PaddleNLP/llm/export_model.py
50+
51+
2. 修改Paddle代码
52+
进入到Paddle安装目录,打开文件:paddle/static/io.py, 注释第284-287行代码:
53+
```python
54+
if not skip_prune_program:
55+
copy_program = copy_program._prune_with_input(
56+
feeded_var_names=feed_var_names, targets=fetch_vars
57+
)
58+
```
59+
60+
## 2. MiniGPT4 分阶段导出
61+
62+
### 2.1 导出前一部分子图:
63+
请确保在该目录下:PaddleMIX/paddlemix/examples/minigpt4/inference,按照以下命令进行导出:
64+
```
65+
python export_image_encoder.py \
66+
--minigpt4_13b_path "you minigpt4 dir path" \
67+
--save_path "./checkpoints/encode_image/encode_image"
68+
```
69+
70+
**参数说明**:
71+
- minigpt4_13b_path: 存放MiniGPT4的目录名
72+
- save_path: 前一部分模型的导出路径和名称
73+
74+
75+
### 2.2 导出后一部分子图
76+
请进入到目录: PaddleNLP/llm, 按照以下命令进行导出:
77+
```
78+
python export_model.py \
79+
--model_name_or_path "your llama dir path" \
80+
--output_path "your output path" \
81+
--dtype float16 \
82+
--inference_model \
83+
--model_prefix llama \
84+
--model_type llama-img2txt
85+
86+
```
87+
88+
**参数说明**:
89+
- model_name_or_path: 存放Llama模型的目录名
90+
- output_path: 语言模型部分的导出路径和名称
91+
- dtype: 模型权重数据类型
92+
- inference_model: 表示是推理模型
93+
- model_prefix: 指明模型前缀
94+
- model_type: 指明模型类型
95+
96+
**备注**: 当前导出Llama部分需要转移到PaddleNLP下进行手动导出,后续将支持在PaddleMIX下一键转出。
97+
98+
## 3. MiniGPT4 静态图推理
99+
请进入到目录PaddleMIX/paddlemix/examples/minigpt4/inference,执行以下命令:
100+
```python
101+
python run_static_predict.py \
102+
--first_model_path "The dir name of image encoder model" \
103+
--second_model_path "The dir name of language model" \
104+
--minigpt4_path "The minigpt4 dir name of saving tokenizer"
105+
```
106+
107+
**参数说明**:
108+
- first_model_path: 存放前一部分(即vit和qformer)的静态图模型目录名
109+
- second_model_path: 存放后一部分(即语言模型)的静态图模型目录名
110+
- minigpt4_path: 存放 MiniGPT4 tokenizer的目录名
111+
112+
以下展示了针对以下这个图片,MiniGPT4静态图推理的输出:
113+
114+
<center><img src="https://paddlenlp.bj.bcebos.com/data/images/mugs.png" /></center>
115+
116+
```text
117+
Reference: The image shows two black and white cats sitting next to each other on a blue background. The cats have black fur and white fur with black noses, eyes, and paws. They are both looking at the camera with a curious expression. The mugs are also blue with the same design of the cats on them. There is a small white flower on the left side of the mug. The background is a light blue color.
118+
119+
Outputs: ['The image shows two black and white cats sitting next to each other on a blue background. The cats have black fur and white fur with black noses, eyes, and paws. They are both looking at the camera with a curious expression. The mugs are also blue with the same design of the cats on them. There is a small white flower on the left side of the mug. The background is a light blue color.##']
120+
```

paddlemix/examples/minigpt4/inference/export_image_encoder.py renamed to paddlemix/examples/minigpt4/deploy/export_image_encoder.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
22
import os
3-
os.environ["CUDA_VISIBLE_DEVICES"]="7"
3+
os.environ["CUDA_VISIBLE_DEVICES"]="0"
44
os.environ["FLAGS_use_cuda_managed_memory"]="true"
55

66
import paddle
@@ -41,26 +41,3 @@ def export(args):
4141
args = parser.parse_args()
4242

4343
export(args)
44-
45-
46-
47-
48-
49-
50-
51-
52-
# processor = MiniGPT4Processor.from_pretrained(minigpt4_13b_path)
53-
# print("load processor and model done!")
54-
55-
# # prepare model inputs for MiniGPT4
56-
# url = "https://paddlenlp.bj.bcebos.com/data/images/mugs.png"
57-
# image = Image.open(requests.get(url, stream=True).raw)
58-
59-
# inputs = processor.process_images(image)
60-
# model.
61-
62-
63-
# # generate with MiniGPT4
64-
# outputs = model.generate(**inputs, **generate_kwargs)
65-
# msg = processor.batch_decode(outputs[0])
66-
# print(msg)

paddlemix/examples/minigpt4/inference/run_static_predict.py renamed to paddlemix/examples/minigpt4/deploy/run_static_predict.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
import argparse
22
import os
3-
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
3+
import sys
4+
import requests
5+
import numpy as np
6+
import datetime
7+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
48
os.environ["FLAGS_use_cuda_managed_memory"] = "true"
59

6-
710
import paddle
811
from paddle import inference
912
from paddlenlp.transformers import MiniGPT4Processor
1013
from PIL import Image
11-
import requests
1214

13-
import sys
14-
15-
# sys.path.append("/wangqinghui/PaddleNLP/llm")
1615
from utils import load_real_time_tokens
1716

18-
import numpy as np
1917

2018
class Predictor(object):
2119
def __init__(self, args):
@@ -62,7 +60,6 @@ def create_predictor(self, model_path):
6260
# such as initialize the gpu memory, enable tensorrt
6361
config.enable_use_gpu(100, 0)
6462
precision_mode = inference.PrecisionType.Half
65-
# breakpoint()
6663
# 第一个模型是要跑TRT的
6764
if self.args.use_tensorrt:
6865
config.enable_tuned_tensorrt_dynamic_shape(shape_range_file, True)
@@ -74,7 +71,6 @@ def create_predictor(self, model_path):
7471
predictor = paddle.inference.create_predictor(config)
7572
input_handles = [predictor.get_input_handle(name) for name in predictor.get_input_names()]
7673
output_handle = [predictor.get_output_handle(name) for name in predictor.get_output_names()]
77-
# output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
7874

7975
return predictor, input_handles, output_handle
8076

@@ -93,9 +89,6 @@ def generate_with_image_features(self,
9389
first_attention_mask=None,
9490
second_attention_mask=None,
9591
**generate_kwargs, ):
96-
# print("image_attention_mask", image_attention_mask)
97-
# print("first_attention_mask", first_attention_mask)
98-
# print("second_attention_mask", second_attention_mask)
9992
batch, seq,_ = image_features.shape
10093
seq = image_features.shape[1] + first_input_ids.shape[1] + second_input_ids.shape[1]
10194
max_len = 204
@@ -200,32 +193,26 @@ def predict(self, images, text, prompt=None):
200193
predictor = Predictor(args)
201194

202195
url = "https://paddlenlp.bj.bcebos.com/data/images/mugs.png"
203-
#url = "https://paddlenlp.bj.bcebos.com/data/images/female.png"
204196
image = Image.open(requests.get(url, stream=True).raw)
205197

206198
text = "describe this image"
207199
prompt = "Give the following image: <Img>ImageContent</Img>. You will be able to see the image once I provide it to you. Please answer my questions.###Human: <Img><ImageHere></Img> <TextHere>###Assistant:"
208200

209-
# warp up
210-
warm_up_times = 1
211-
repeat_times = 5
201+
# warm up
202+
warm_up_times = 2
203+
repeat_times = 10
212204
for i in range(warm_up_times):
213205
msg = predictor.predict(image, text, prompt)
214206

215-
216207
# 测试50次
217-
import datetime
218208
starttime = datetime.datetime.now()
219-
220209
for i in range(repeat_times):
221210
msg = predictor.predict(image, text, prompt)
222211

223212
endtime = datetime.datetime.now()
224213
duringtime = endtime - starttime
225214
time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
226215

227-
print(
228-
"Reference: The image shows two black and white cats sitting next to each other on a blue background. The cats have black fur and white fur with black noses, eyes, and paws. They are both looking at the camera with a curious expression. The mugs are also blue with the same design of the cats on them. There is a small white flower on the left side of the mug. The background is a light blue color.")
216+
print("Reference: The image shows two black and white cats sitting next to each other on a blue background. The cats have black fur and white fur with black noses, eyes, and paws. They are both looking at the camera with a curious expression. The mugs are also blue with the same design of the cats on them. There is a small white flower on the left side of the mug. The background is a light blue color.")
229217
print("Outputs: ", msg)
230-
print("infer OK")
231-
print("The whoel end to end time : ", time_ms / repeat_times, "ms")
218+
print("The whole time on average: ", time_ms / repeat_times, "ms")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import glob
17+
import os
18+
import struct
19+
import numpy as np
20+
21+
22+
def deserialize_from_file(fp):
23+
x_type = fp.read(1)
24+
x_type_out = struct.unpack("c", x_type)[0]
25+
# data
26+
data_list = []
27+
if x_type_out == b"0":
28+
data = fp.read(4)
29+
data_out = struct.unpack("f", data)[0]
30+
while data:
31+
data_out = struct.unpack("f", data)[0]
32+
data_list.append(data_out)
33+
data = fp.read(4)
34+
elif x_type_out == b"1":
35+
data = fp.read(8)
36+
while data:
37+
data_out = struct.unpack("l", data)[0]
38+
data_list.append(data_out)
39+
data = fp.read(8)
40+
elif x_type_out == b"2":
41+
data = fp.read(4)
42+
while data:
43+
data_out = struct.unpack("i", data)[0]
44+
data_list.append(data_out)
45+
data = fp.read(4)
46+
else:
47+
print("type error")
48+
data_arr = np.array(data_list)
49+
return data_arr
50+
51+
def load_real_time_tokens():
52+
tokens = []
53+
files = glob.glob(os.path.join("./real_time_save.*"))
54+
for j in range(1, len(files) + 1):
55+
filename = "./real_time_save.temp_ids_rank_0_step_{}".format(j)
56+
if not os.path.exists(filename):
57+
break
58+
fp = open(filename, "rb+")
59+
fp.read(1)
60+
data_list = deserialize_from_file(fp)
61+
fp.close()
62+
tokens.append(np.array(data_list).reshape(-1, 1))
63+
os.system("rm -f ./real_time_save.temp_ids_rank_*")
64+
tokens = np.concatenate(tokens, axis=1)
65+
return tokens

paddlemix/examples/minigpt4/inference/README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ python setup_cuda.py install
3737

3838
3. 最后设置相应的环境变量:
3939
```shell
40-
export PYTHONPATH=/wangqinghui/PaddleNLP:/wangqinghui/PaddleMIX
40+
export PYTHONPATH= yourpath/PaddleNLP:yourpath/PaddleMIX
4141
```
4242

4343
### 1.3 特别说明
@@ -51,16 +51,16 @@ export PYTHONPATH=/wangqinghui/PaddleNLP:/wangqinghui/PaddleMIX
5151
2. 修改Paddle代码
5252
进入到Paddle安装目录,打开文件:paddle/static/io.py, 注释第284-287行代码:
5353
```python
54-
if not skip_prune_program:
55-
copy_program = copy_program._prune_with_input(
56-
feeded_var_names=feed_var_names, targets=fetch_vars
57-
)
54+
if not skip_prune_program:
55+
copy_program = copy_program._prune_with_input(
56+
feeded_var_names=feed_var_names, targets=fetch_vars
57+
)
5858
```
5959

6060
## 2. MiniGPT4 分阶段导出
6161

6262
### 2.1 导出前一部分子图:
63-
请确保在该目录下:PaddleMIX/paddlemix/examples/minigpt4/inference,按照以下命令进行导出:
63+
请确保在该目录下:PaddleMIX/paddlemix/examples/minigpt4/deploy,按照以下命令进行导出:
6464
```
6565
python export_image_encoder.py \
6666
--minigpt4_13b_path "you minigpt4 dir path" \
@@ -83,7 +83,7 @@ python export_model.py \
8383
**备注**: 当前导出Llama部分需要转移到PaddleNLP下进行手动导出,后续将支持在PaddleMIX下一键转出。
8484

8585
## 3. MiniGPT4 静态图推理
86-
请进入到目录PaddleMIX/paddlemix/examples/minigpt4/inference,执行以下命令:
86+
请进入到目录PaddleMIX/paddlemix/examples/minigpt4/deploy,执行以下命令:
8787
```python
8888
python run_static_predict.py \
8989
--first_model_path "The dir name of image encoder model" \

0 commit comments

Comments
 (0)