Skip to content

Commit 369f581

Browse files
zhaop-llyuwenyu
authored andcommitted
add points_qwen2_5
1 parent 92f21be commit 369f581

File tree

10 files changed

+1708
-0
lines changed

10 files changed

+1708
-0
lines changed
272 KB
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# MiniCPM-V-2_6
2+
3+
## 1. 模型介绍
4+
5+
[POINTS-Qwen](https://huggingface.co/WePOINTS/POINTS-Qwen-2-5-7B-Chat) 融合了视觉语言模型的最新研究进展,并采用了微信AI团队提出的前沿创新技术。
6+
7+
- **强大的基线**:将视觉-语言模型领域的最新进展,即CapFusion、双视觉编码器和动态高分辨率技术,整合到POINTS中
8+
9+
- **预训练数据集过滤**:提出使用困惑度(perplexity)作为指标来过滤预训练数据集。通过这种过滤策略,可以显著减少预训练数据集的规模,同时提升模型的性能。
10+
11+
- **模型融合(Model Soup)**:提出对使用不同视觉指令微调数据集进行微调的模型应用模型融合技术,这可以进一步显著提升模型的性能。
12+
13+
**本仓库支持的模型权重:**
14+
15+
| Model |
16+
|--------------------|
17+
| WePOINTS/POINTS-Qwen-2-5-7B-Chat |
18+
19+
20+
## 2 环境准备
21+
1)[安装PaddlePaddle](https://github.com/PaddlePaddle/PaddleMIX?tab=readme-ov-file#3-%EF%B8%8F%E5%AE%89%E8%A3%85paddlepaddle)
22+
- **python >= 3.10**
23+
- **paddlepaddle-gpu 要求是3.0.0b2或develop版本**
24+
```bash
25+
# 提供三种 PaddlePaddle 安装命令示例,也可参考PaddleMIX主页的安装教程进行安装
26+
27+
# 3.0.0b2版本安装示例 (CUDA 11.8)
28+
python -m pip install paddlepaddle-gpu==3.0.0b2 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
29+
30+
# Develop 版本安装示例
31+
python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
32+
33+
# sh 脚本快速安装
34+
sh build_paddle_env.sh
35+
```
36+
37+
2)[安装PaddleMIX环境依赖包](https://github.com/PaddlePaddle/PaddleMIX?tab=readme-ov-file#3-%EF%B8%8F%E5%AE%89%E8%A3%85paddlepaddle)
38+
- **paddlenlp >= 3.0.0b3**
39+
40+
```bash
41+
# 提供两种 PaddleMIX 依赖安装命令示例
42+
43+
# pip 安装示例,安装paddlemix、ppdiffusers、项目依赖、paddlenlp
44+
python -m pip install -e . --user
45+
python -m pip install -e ppdiffusers --user
46+
python -m pip install -r requirements.txt --user
47+
python -m pip install paddlenlp==3.0.0b3 --user
48+
49+
# sh 脚本快速安装
50+
sh build_env.sh
51+
```
52+
53+
> 注:
54+
* 请确保安装了以上依赖,否则无法运行。同时,需要安装 paddlemix/external_ops 下的自定义OP, `python setup.py install`。如果安装后仍然找不到算子,需要额外设置PYTHONPATH
55+
* (默认开启flash_attn)使用flash_attn 要求A100/A800显卡或者H20显卡。V100请用float16推理。
56+
57+
## 3 快速开始
58+
59+
### 推理
60+
```bash
61+
# 单图推理
62+
python paddlemix/examples/points_qwen2_5/image_infer.py
63+
```
64+
65+
### 参考文献
66+
```BibTeX
67+
@article{liu2024points,
68+
title={POINTS: Improving Your Vision-language Model with Affordable Strategies},
69+
author={Liu, Yuan and Zhao, Zhongyin and Zhuang, Ziyuan and Tian, Le and Zhou, Xiao and Zhou, Jie},
70+
journal={arXiv preprint arXiv:2409.04828},
71+
year={2024}
72+
}
73+
74+
@article{liu2024rethinking,
75+
title={Rethinking Overlooked Aspects in Vision-Language Models},
76+
author={Liu, Yuan and Tian, Le and Zhou, Xiao and Zhou, Jie},
77+
journal={arXiv preprint arXiv:2405.11850},
78+
year={2024}
79+
}
80+
81+
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# @Time : 2025/4/19 下午8:37
18+
# @Author : zhaop-l(zhaopuzxjc@126.com)
19+
20+
from paddlenlp.transformers import CLIPImageProcessor, Qwen2Tokenizer
21+
from PIL import Image
22+
23+
from paddlemix.models.points_qwen2_5 import POINTSChatModel
24+
25+
model_path = "WePOINTS/POINTS-Qwen-2-5-7B-Chat"
26+
27+
model = POINTSChatModel.from_pretrained(model_path)
28+
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
29+
image_processor = CLIPImageProcessor.from_pretrained(model_path)
30+
31+
image_path = "paddlemix/demo_images/minicpm_demo.jpeg"
32+
pil_image = Image.open(image_path)
33+
prompt = "please describe the image in detail"
34+
35+
generation_config = {
36+
"max_new_tokens": 1024,
37+
"temperature": 0.0,
38+
"top_p": 0.0,
39+
"num_beams": 1,
40+
}
41+
res = model.chat(pil_image, prompt, tokenizer, image_processor, True, generation_config)
42+
43+
print(res)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# @Time : 2025/4/19 下午8:37
18+
# @Author : zhaop-l(zhaopuzxjc@126.com)
19+
from .modeling_points_chat import *
+223
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# @Time : 2025/4/19 下午8:37
18+
# @Author : zhaop-l(zhaopuzxjc@126.com)
19+
20+
import os
21+
from typing import List, Tuple
22+
23+
from PIL import Image
24+
25+
from .dynamic_high_resolution import factorize_number
26+
27+
28+
def construct_mapping_dict(max_splits: int = 12) -> dict:
29+
"""Construct a mapping dictionary for the given max_splits.
30+
31+
Args:
32+
max_splits (int, optional): The maximum number of splits.
33+
Defaults to 12.
34+
35+
Returns:
36+
dict: A mapping dictionary for the given max_splits.
37+
"""
38+
mapping_dict = {}
39+
for i in range(1, max_splits + 1):
40+
factor_list = factorize_number(i)
41+
for factor in factor_list:
42+
ratio = factor[0] / factor[1]
43+
if ratio not in mapping_dict:
44+
mapping_dict[ratio] = [factor]
45+
else:
46+
mapping_dict[ratio].append(factor)
47+
return mapping_dict
48+
49+
50+
def save_image_list(image_list: List[Image.Image], save_folder: str) -> None:
51+
"""Save a list of images to a folder.
52+
53+
Args:
54+
image_list (List[Image.Image]): A list of images.
55+
save_folder (str): The folder to save the images to.
56+
"""
57+
os.makedirs(save_folder, exist_ok=True)
58+
for i, image in enumerate(image_list):
59+
image.save(os.path.join(save_folder, f"{i}.png"))
60+
61+
62+
def resize_to_best_size(
63+
image: Image.Image,
64+
best_slices: tuple,
65+
width_slices: int,
66+
height_slices: int,
67+
sub_image_size: int,
68+
) -> Image.Image:
69+
"""Resize an image to the best size for the given number of slices.
70+
71+
Args:
72+
image (Image.Image): The image to resize.
73+
best_slices (tuple): The best number of slices for the image.
74+
width_slices (int): The number of horizontal slices.
75+
height_slices (int): The number of vertical slices.
76+
sub_image_size (int): The size of the sub-images.
77+
78+
Returns:
79+
Image.Image: The resized image.
80+
"""
81+
width, height = image.size
82+
best_width_slices, best_height_slices = best_slices
83+
if width_slices < height_slices:
84+
new_image_width = best_width_slices * sub_image_size
85+
new_image_height = int(height / width * new_image_width)
86+
else:
87+
new_image_height = best_height_slices * sub_image_size
88+
new_image_width = int(width / height * new_image_height)
89+
new_image = image.resize((new_image_width, new_image_height), resample=2)
90+
return new_image
91+
92+
93+
def compute_strides(height: int, width: int, sub_image_size: int, slices: Tuple[int, int]) -> Tuple[int, int]:
94+
"""Compute the strides for the given image size and slices.
95+
96+
Args:
97+
height (int): The height of the image.
98+
width (int): The width of the image.
99+
sub_image_size (int): The size of the sub-images.
100+
slices (Tuple[int, int]): The number of horizontal and vertical slices.
101+
102+
Returns:
103+
Tuple[int, int]: The strides for the given image size and slices.
104+
"""
105+
slice_width, slice_height = slices
106+
if slice_width > 1:
107+
stride_x = (width - sub_image_size) // (slice_width - 1)
108+
else:
109+
stride_x = 0
110+
if slice_height > 1:
111+
stride_y = (height - sub_image_size) // (slice_height - 1)
112+
else:
113+
stride_y = 0
114+
return stride_x, stride_y
115+
116+
117+
def sliding_window_crop(image: Image.Image, window_size: int, slices: Tuple[int, int]) -> List[Image.Image]:
118+
"""Crop an image into sub-images using a sliding window.
119+
120+
Args:
121+
image (Image.Image): The image to crop.
122+
window_size (int): The size of the sub-images.
123+
slices (Tuple[int, int]): The number of horizontal and vertical slices.
124+
125+
Returns:
126+
List[Image]: A list of cropped images.
127+
"""
128+
width, height = image.size
129+
stride_x, stride_y = compute_strides(height, width, window_size, slices)
130+
sub_images = []
131+
if stride_x == 0:
132+
stride_x = window_size
133+
if stride_y == 0:
134+
stride_y = window_size
135+
for y in range(0, height - window_size + 1, stride_y):
136+
for x in range(0, width - window_size + 1, stride_x):
137+
sub_image = image.crop((x, y, x + window_size, y + window_size))
138+
sub_images.append(sub_image)
139+
return sub_images
140+
141+
142+
def find_best_slices(width_slices: int, height_slices: int, aspect_ratio: float, max_splits: int = 12) -> list:
143+
"""Find the best slices for the given image size and aspect ratio.
144+
145+
Args:
146+
width_slices (int): The number of horizontal slices.
147+
height_slices (int): The number of vertical slices.
148+
aspect_ratio (float): The aspect ratio of the image.
149+
max_splits (int, optional): The maximum number of splits.
150+
Defaults to 12.
151+
152+
Returns:
153+
list: the best slices for the given image.
154+
"""
155+
mapping_dict = construct_mapping_dict(max_splits)
156+
if aspect_ratio < 1:
157+
mapping_dict = {k: v for k, v in mapping_dict.items() if k <= aspect_ratio}
158+
elif aspect_ratio > 1:
159+
mapping_dict = {k: v for k, v in mapping_dict.items() if k >= aspect_ratio}
160+
best_ratio = min(mapping_dict.keys(), key=lambda x: abs(x - aspect_ratio))
161+
best_image_sizes = mapping_dict[best_ratio]
162+
best_slices = min(best_image_sizes, key=lambda x: abs(x[0] * x[1] - width_slices * height_slices))
163+
return best_slices
164+
165+
166+
def split_image_with_catty(
167+
pil_image: Image.Image,
168+
image_size: int = 336,
169+
max_crop_slices: int = 8,
170+
save_folder: str = None,
171+
add_thumbnail: bool = True,
172+
do_resize: bool = False,
173+
**kwargs,
174+
) -> List[Image.Image]:
175+
"""Split an image into sub-images using Catty.
176+
177+
Args:
178+
pil_image (Image.Image): The image to split.
179+
image_size (int, optional): The size of the image.
180+
Defaults to 336.
181+
max_crop_slices (int, optional): The maximum number of slices.
182+
Defaults to 8.
183+
save_folder (str, optional): The folder to save the sub-images.
184+
Defaults to None.
185+
add_thumbnail (bool, optional): Whether to add a thumbnail.
186+
Defaults to False.
187+
do_resize (bool, optional): Whether to resize the image to fit the
188+
maximum number of slices. Defaults to False.
189+
190+
Returns:
191+
List[Image.Image]: A list of cropped images.
192+
"""
193+
width, height = pil_image.size
194+
ratio = width / height
195+
if ratio > max_crop_slices or ratio < 1 / max_crop_slices:
196+
if do_resize:
197+
print(f"Resizing image to fit maximum number of slices ({max_crop_slices})")
198+
if width > height:
199+
new_width = max_crop_slices * height
200+
new_height = height
201+
else:
202+
new_width = width
203+
new_height = max_crop_slices * width
204+
pil_image = pil_image.resize((new_width, new_height), resample=2)
205+
width, height = pil_image.size
206+
ratio = width / height
207+
else:
208+
print(
209+
f"Image aspect ratio ({ratio:.2f}) is out of range: ({1 / max_crop_slices:.2f}, {max_crop_slices:.2f})"
210+
)
211+
return None
212+
width_slices = width / image_size
213+
height_slices = height / image_size
214+
best_slices = find_best_slices(width_slices, height_slices, ratio, max_crop_slices)
215+
pil_image = resize_to_best_size(pil_image, best_slices, width_slices, height_slices, image_size)
216+
width, height = pil_image.size
217+
sub_images = sliding_window_crop(pil_image, image_size, best_slices)
218+
if add_thumbnail:
219+
thumbnail_image = pil_image.resize((image_size, image_size), resample=2)
220+
sub_images.append(thumbnail_image)
221+
if save_folder is not None:
222+
save_image_list(sub_images, save_folder)
223+
return sub_images

0 commit comments

Comments
 (0)