Skip to content

Commit 843f05d

Browse files
cocosheLokeZhou
andauthored
add Cogagent (PaddlePaddle#445)
Co-authored-by: LokeZhou <aishenghuoaiqq@163.com>
1 parent fd5135a commit 843f05d

File tree

6 files changed

+2360
-0
lines changed

6 files changed

+2360
-0
lines changed

paddlemix/examples/cogagent/README.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# CogAgent
2+
3+
## 1. 模型介绍
4+
5+
该模型是 [CogAgent](https://arxiv.org/abs/2312.08914) 的 paddle 实现。
6+
7+
[CogAgent](https://arxiv.org/abs/2312.08914)是一个基于CogVLM改进的开源视觉语言模型。CogAgent-18B拥有110亿的视觉参数和70亿的语言参数。
8+
9+
CogAgent-18B在9个经典的跨模态基准测试中实现了最先进的全能性能,包括VQAv2、OK-VQ、TextVQA、ST-VQA、ChartQA、infoVQA、DocVQA、MM-Vet和POPE。
10+
11+
除了CogVLM已有的所有功能(视觉多轮对话,视觉定位)之外,CogAgent:
12+
13+
1. 支持更高分辨率的视觉输入和对话式问答。它支持超高分辨率的图像输入,达到1120x1120。
14+
15+
2. 拥有视觉Agent的能力,能够在任何图形用户界面截图上,为任何给定任务返回一个计划,下一步行动,以及带有坐标的特定操作。
16+
17+
3. 增强了与图形用户界面相关的问答能力,使其能够处理关于任何图形用户界面截图的问题,例如网页、PC应用、移动应用等。
18+
19+
4. 通过改进预训练和微调,提高了OCR相关任务的能力。
20+
21+
本仓库提供paddle版本的 cogagent-chat 模型
22+
23+
## 2. 环境准备
24+
25+
1) [安装PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP?tab=readme-ov-file#%E5%AE%89%E8%A3%85)
26+
27+
2)[安装 PaddleMix 环境依赖包](https://github.com/PaddlePaddle/PaddleMIX/tree/b4f97ff859e1964c839fc5fab94f7ba63b1e5959?tab=readme-ov-file#%E5%AE%89%E8%A3%85)
28+
29+
## 3. 快速开始
30+
完成环境准备后,我们目前提供多轮对话方式使用:
31+
32+
```bash
33+
python paddlemix/examples/cogagent/chat_demo.py \
34+
--from_pretrained "THUDM/cogagent-chat"
35+
```
36+
37+
可配置参数说明:
38+
* `from_pretrained`: 指定CogAgent的模型名字或权重路径以及tokenizer,默认 THUDM/cogagent-chat
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
17+
import numpy as np
18+
import paddle
19+
20+
seed = 2024
21+
paddle.seed(seed)
22+
np.random.seed(seed)
23+
random.seed(seed)
24+
25+
import argparse
26+
27+
from PIL import Image
28+
29+
from paddlemix.auto.modeling import AutoModelMIX
30+
from paddlemix.auto.tokenizer import AutoTokenizerMIX
31+
32+
parser = argparse.ArgumentParser()
33+
34+
parser.add_argument("--from_pretrained", type=str, default="THUDM/cogagent-chat", help="pretrained ckpt and tokenizer")
35+
args = parser.parse_args()
36+
MODEL_PATH = args.from_pretrained
37+
TOKENIZER_PATH = MODEL_PATH
38+
39+
tokenizer = AutoTokenizerMIX.from_pretrained(TOKENIZER_PATH)
40+
41+
data_type = "float32"
42+
43+
model = AutoModelMIX.from_pretrained(
44+
MODEL_PATH,
45+
dtype=data_type,
46+
low_cpu_mem_usage=False,
47+
)
48+
model.eval()
49+
50+
text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
51+
while True:
52+
image_path = input("image path >>>>> ")
53+
if image_path == "":
54+
print("You did not enter image path, the following will be a plain text conversation.")
55+
image = None
56+
text_only_first_query = True
57+
else:
58+
image = Image.open(image_path).convert("RGB")
59+
history = []
60+
while True:
61+
query = input("Human:")
62+
if query == "clear":
63+
break
64+
if image is None:
65+
if text_only_first_query:
66+
query = text_only_template.format(query)
67+
text_only_first_query = False
68+
else:
69+
old_prompt = ""
70+
for _, (old_query, response) in enumerate(history):
71+
old_prompt += old_query + " " + response + "\n"
72+
query = old_prompt + "USER: {} ASSISTANT:".format(query)
73+
if image is None:
74+
input_by_model = model.build_conversation_input_ids(
75+
tokenizer, query=query, history=history, template_version="base"
76+
)
77+
else:
78+
input_by_model = model.build_conversation_input_ids(
79+
tokenizer, query=query, history=history, images=[image]
80+
)
81+
inputs = {
82+
"input_ids": input_by_model["input_ids"].unsqueeze(axis=0),
83+
"token_type_ids": input_by_model["token_type_ids"].unsqueeze(axis=0),
84+
"attention_mask": input_by_model["attention_mask"].unsqueeze(axis=0),
85+
"images": [[input_by_model["images"][0].to(data_type)]] if image is not None else None,
86+
}
87+
if "cross_images" in input_by_model and input_by_model["cross_images"]:
88+
inputs["cross_images"] = [[input_by_model["cross_images"][0].to(data_type)]]
89+
gen_kwargs = {"max_new_tokens": 2048, "do_sample": False}
90+
with paddle.no_grad():
91+
outputs, _ = model.generate(**inputs, **gen_kwargs)
92+
outputs = outputs[:, inputs["input_ids"].shape[1] :]
93+
response = tokenizer.decode(outputs[0])
94+
response = response.split("</s>")[0]
95+
print("\nCog:", response)
96+
history.append((query, response))
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Literal
16+
17+
from paddlenlp import transformers
18+
19+
20+
class CogAgentConfig(transformers.PretrainedConfig):
21+
_auto_class = "AutoConfig"
22+
23+
def __init__(
24+
self,
25+
vocab_size=32000,
26+
hidden_size=4096,
27+
cross_hidden_size=1024,
28+
cross_compute_hidden_size=1024,
29+
cross_image_size=1120,
30+
intermediate_size=11008,
31+
num_hidden_layers=32,
32+
num_attention_heads=32,
33+
hidden_act="silu",
34+
max_position_embeddings=2048,
35+
initializer_range=0.02,
36+
rms_norm_eps=1e-06,
37+
template_version: Literal["base", "chat"] = "chat",
38+
pad_token_id=0,
39+
bos_token_id=1,
40+
eos_token_id=2,
41+
tie_word_embeddings=False,
42+
use_cache=True,
43+
**kwargs
44+
):
45+
self.hidden_size = hidden_size
46+
self.cross_hidden_size = cross_hidden_size
47+
self.cross_compute_hidden_size = cross_compute_hidden_size
48+
self.cross_image_size = cross_image_size
49+
self.intermediate_size = intermediate_size
50+
self.num_attention_heads = num_attention_heads
51+
self.max_position_embeddings = max_position_embeddings
52+
self.rms_norm_eps = rms_norm_eps
53+
self.initializer_range = initializer_range
54+
self.vocab_size = vocab_size
55+
self.num_hidden_layers = num_hidden_layers
56+
self.hidden_act = hidden_act
57+
self.template_version = template_version
58+
self.use_cache = use_cache
59+
super().__init__(
60+
pad_token_id=pad_token_id,
61+
bos_token_id=bos_token_id,
62+
eos_token_id=eos_token_id,
63+
tie_word_embeddings=tie_word_embeddings,
64+
**kwargs,
65+
)

0 commit comments

Comments
 (0)