Commit 8c6dddf

Merge pull request #48 from wjm202/add_from_pretrain
Uploaded the corresponding weight URLs and fixed some errors caused by incompatibility with the latest paddlenlp version.
2 parents ebb3f5b + 6bafb6e commit 8c6dddf

7 files changed: +240 −228 lines changed

paddlevlp/examples/blip2/export.py (+34 −34)

@@ -13,7 +13,8 @@
 # limitations under the License.
 import sys
 import os
-sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
+sys.path.insert(
+    0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
 from dataclasses import dataclass, field
 import paddle
 import requests
@@ -29,6 +30,7 @@
 import os
 import paddle
 
+
 @dataclass
 class DataArguments:
     """
@@ -38,12 +40,14 @@ class DataArguments:
     the command line.
     """
 
-    input_image: str = field( default="http://images.cocodataset.org/val2017/000000039769.jpg",
-        metadata={"help": "The name of input image."}
-    ) # "http://images.cocodataset.org/val2017/000000039769.jpg"
+    input_image: str = field(
+        default="http://images.cocodataset.org/val2017/000000039769.jpg",
+        metadata={"help": "The name of input image."
+                  }) # "http://images.cocodataset.org/val2017/000000039769.jpg"
     prompt: str = field(
-        default=None, metadata={"help": "The prompt of the image to be generated."}
-    ) # "Question: how many cats are there? Answer:"
+        default=None,
+        metadata={"help": "The prompt of the image to be generated."
+                  }) # "Question: how many cats are there? Answer:"
 
 
 @dataclass
@@ -53,50 +57,44 @@ class ModelArguments:
     """
 
     model_name_or_path: str = field(
-        default="Salesforce/blip2-opt-2.7b",
-        metadata={"help": "Path to pretrained model or model identifier"},
-    )
+        default="paddlemix/blip2-caption-opt2.7b",
+        metadata={"help": "Path to pretrained model or model identifier"}, )
     pretrained_model_path: str = field(
         default=None,
         metadata={
-            "help": "The path to pre-trained model that we will use for inference."
-        },)
+            "help":
+            "The path to pre-trained model that we will use for inference."
+        }, )
     fp16: str = field(
         default=True,
-        metadata={
-            "help": "Export with mixed precision."
-        },
-    )
+        metadata={"help": "Export with mixed precision."}, )
 
 
 def main():
     parser = PdArgumentParser((ModelArguments, DataArguments))
     model_args, data_args = parser.parse_args_into_dataclasses()
-    url = (
-        data_args.input_image
-    ) # "http://images.cocodataset.org/val2017/000000039769.jpg"
+    url = (data_args.input_image
+           ) # "http://images.cocodataset.org/val2017/000000039769.jpg"
     image = Image.open(requests.get(url, stream=True).raw)
 
     prompt = "a photo of "
-    processor = Blip2Processor.from_pretrained(
-        model_args.model_name_or_path
-    ) # "Salesforce/blip2-opt-2.7b"
-    model = Blip2ForConditionalGeneration.from_pretrained(model_args.model_name_or_path)
+    processor = Blip2Processor.from_pretrained(model_args.model_name_or_path)
+    model = Blip2ForConditionalGeneration.from_pretrained(
+        model_args.model_name_or_path)
     model.eval()
-    dtype="float32"
+    dtype = "float32"
     if model_args.fp16:
         decorated = paddle.amp.decorate(
-            models=[model.visual_encoder,model.language_model], optimizers=None, level="O2"
-        )
-        model.visual_encoder,model.language_model= decorated
-        dtype="float16"
+            models=[model.visual_encoder, model.language_model],
+            optimizers=None,
+            level="O2")
+        model.visual_encoder, model.language_model = decorated
+        dtype = "float16"
 
-    shape1 = [None,3,None,None]
-    input_spec = [
-        paddle.static.InputSpec(
-            shape=shape1, dtype='float32'),
-    ]
-    image_encoder = paddle.jit.to_static(model.encode_image, input_spec=input_spec)
+    shape1 = [None, 3, None, None]
+    input_spec = [paddle.static.InputSpec(shape=shape1, dtype='float32'), ]
+    image_encoder = paddle.jit.to_static(
+        model.encode_image, input_spec=input_spec)
     save_path = "blip2_export"
     paddle.jit.save(image_encoder, os.path.join(save_path, 'image_encoder'))
 
@@ -106,7 +104,7 @@ def main():
             'model': 'image_encoder.pdmodel',
             'params': 'image_encoder.pdiparams',
             'input_img_shape': shape1,
-            'output_dtype':dtype
+            'output_dtype': dtype
         }
     }
     msg = '\n---------------Deploy Information---------------\n'
@@ -118,5 +116,7 @@ def main():
         yaml.dump(deploy_info, file)
 
     logger.info(f'The inference model is saved in {save_path}')
+
+
 if __name__ == "__main__":
     main()
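
For reference, the snippet below is a minimal sketch of consuming the artifact written by the export script above: it loads blip2_export/image_encoder with paddle.jit.load and runs a dummy forward pass. The 224x224 dummy input and the way the output is printed are assumptions for illustration, not part of this PR.

# Minimal sketch (not part of this PR): load the exported static-graph image
# encoder and run it on a random image-shaped tensor. Assumes export.py above
# has already written blip2_export/image_encoder.pdmodel / .pdiparams.
import numpy as np
import paddle

image_encoder = paddle.jit.load("blip2_export/image_encoder")
image_encoder.eval()

# shape1 in export.py is [None, 3, None, None]; 224x224 is just a plausible size.
dummy_image = paddle.to_tensor(np.random.rand(1, 3, 224, 224).astype("float32"))
features = image_encoder(dummy_image)

# The output may be a single tensor or a list of tensors depending on encode_image.
if isinstance(features, (list, tuple)):
    print([f.shape for f in features])
else:
    print(features.shape)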

paddlevlp/examples/blip2/run_pretrain_stage2.py (+1 −7)

@@ -85,13 +85,6 @@ class PreTrainingArguments(TrainingArguments):
     """
     Arguments pertaining to what training options we are going to use during pretraining.
     """
-
-    pretrained_model_path: str = field(
-        default="https://bj.bcebos.com/v1/paddlenlp/models/community/Salesforce/blip2-opt-2.7b/blip2_pretrained.pdparams",
-        metadata={
-            "help":
-            "The path to pre-trained model that we will use for pretraining."
-        }, )
     weight_decay: float = field(
         default=0.05, metadata={"help": "Weight decay if we apply some."})
     learning_rate: float = field(
@@ -260,6 +253,7 @@ def main():
         eval_processor=eval_processor,
         tokenizer=tokenizer_class)
     # Training
+    checkpoint = None
     if training_args.model_path is not None:
         checkpoint = training_args.model_path
     load_model(

paddlevlp/models/blip2/Qformer.py (+3 −1)

@@ -1104,7 +1104,9 @@ def __init__(self,
                  train_in_satge1=False,
                  **kwargs):
         super().__init__(config)
-        config.mp_degree = kwargs.get('mp_degree')
+        from paddle.distributed import fleet
+        config.mp_degree = fleet.DistributedStrategy().hybrid_configs[
+            'mp_degree']
         config.encoder_width = encoder_width
         config.gradient_checkpointing = False
         self.ln_vision = paddle.nn.LayerNorm(config.encoder_width)
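
The Qformer change above stops passing mp_degree as an ad-hoc kwarg and reads it from fleet instead. The sketch below shows how a launcher might configure that degree; it relies on hybrid_configs behaving as a plain dict of degrees (which the patched line already assumes), and the degree values are placeholders.

# Sketch only: configure the hybrid-parallel degrees that Qformer.py now reads
# via fleet.DistributedStrategy().hybrid_configs['mp_degree']. The values below
# are placeholders; fleet.init is only useful under a distributed launch.
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,  # data parallel
    "mp_degree": 2,  # tensor (model) parallel
    "pp_degree": 1,  # pipeline parallel
}
print(strategy.hybrid_configs["mp_degree"])  # same dict-style access as the patch

# fleet.init(is_collective=True, strategy=strategy)  # run under paddle.distributed.launch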

paddlevlp/models/blip2/configuration.py (+65 −81)

@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """ BLIP-2 model configuration"""
 import copy
 import os
@@ -83,59 +82,52 @@ class Blip2VisionConfig(PretrainedConfig):
     model_type = "blip_2_vision_model"
 
     def __init__(
-        self,
-        img_size=224,
-        patch_size=14,
-        embed_dim=1408,
-        depth=39,
-        num_heads=16,
-        mlp_ratio=4.3637,
-        qkv_bias=True,
-        drop_rate=0,
-        epsilon=1e-6,
-        mp_degree=1,
-        gradient_checkpointing=False,
-        **kwargs,
-    ):
+            self,
+            img_size=224,
+            patch_size=14,
+            embed_dim=1408,
+            depth=39,
+            num_heads=16,
+            mlp_ratio=4.3637,
+            qkv_bias=True,
+            drop_rate=0,
+            epsilon=1e-6,
+            gradient_checkpointing=False,
+            **kwargs, ):
         kwargs["return_dict"] = kwargs.pop("return_dict", True)
         super().__init__(**kwargs)
 
         self.img_size = img_size
         self.patch_size = patch_size
-        self. embed_dim = embed_dim
-        self.depth= depth
+        self.embed_dim = embed_dim
+        self.depth = depth
         self.num_heads = num_heads
         self.mlp_ratio = mlp_ratio
         self.qkv_bias = qkv_bias
         self.drop_rate = drop_rate
         self.epsilon = epsilon
-        self.mp_degree = mp_degree
-        self.gradient_checkpointing= gradient_checkpointing
+        self.gradient_checkpointing = gradient_checkpointing
 
-        self.in_chans =kwargs.get('in_chans', 3)
-        self.class_num = kwargs.get( 'class_num', 1000)
+        self.in_chans = kwargs.get('in_chans', 3)
+        self.class_num = kwargs.get('class_num', 1000)
         self.qk_scale = kwargs.get('qk_scale', None)
-        self.attn_drop_rate = kwargs.get( 'attn_drop_rate=', 0.)
-        self.drop_path_rate = kwargs.get( 'drop_path_rate', 0.)
-        self.norm_layer = kwargs.get( 'norm_layer', 'nn.LayerNorm')
+        self.attn_drop_rate = kwargs.get('attn_drop_rate=', 0.)
+        self.drop_path_rate = kwargs.get('drop_path_rate', 0.)
+        self.norm_layer = kwargs.get('norm_layer', 'nn.LayerNorm')
 
     @classmethod
-    def from_pretrained(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
-    ) -> "PretrainedConfig":
-        config_dict, kwargs = cls.get_config_dict(
-            pretrained_model_name_or_path, **kwargs
-        )
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path: Union[str, os.PathLike],
+                        **kwargs) -> "PretrainedConfig":
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path,
+                                                  **kwargs)
 
         # get the vision config dict if we are loading from Blip2Config
         if config_dict.get("model_type") == "blip-2":
             config_dict = config_dict["vision_config"]
 
-        if (
-            "model_type" in config_dict
-            and hasattr(cls, "model_type")
-            and config_dict["model_type"] != cls.model_type
-        ):
+        if ("model_type" in config_dict and hasattr(cls, "model_type") and
+                config_dict["model_type"] != cls.model_type):
             logger.warning(
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
@@ -204,25 +196,24 @@ class Blip2QFormerConfig(PretrainedConfig):
     model_type = "blip_2_qformer"
 
     def __init__(
-        self,
-        vocab_size=30522,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        initializer_range=0.02,
-        layer_norm_eps=1e-12,
-        pad_token_id=0,
-        position_embedding_type="absolute",
-        classifier_dropout=None,
-        cross_attention_frequency=2,
-        encoder_hidden_size=1408,
-        **kwargs,
-    ):
+            self,
+            vocab_size=30522,
+            hidden_size=768,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            intermediate_size=3072,
+            hidden_act="gelu",
+            hidden_dropout_prob=0.1,
+            attention_probs_dropout_prob=0.1,
+            max_position_embeddings=512,
+            initializer_range=0.02,
+            layer_norm_eps=1e-12,
+            pad_token_id=0,
+            position_embedding_type="absolute",
+            classifier_dropout=None,
+            cross_attention_frequency=2,
+            encoder_hidden_size=1408,
+            **kwargs, ):
         kwargs["return_dict"] = kwargs.pop("return_dict", True)
         super().__init__(pad_token_id=pad_token_id, **kwargs)
 
@@ -243,22 +234,18 @@ def __init__(
         self.encoder_hidden_size = encoder_hidden_size
 
     @classmethod
-    def from_pretrained(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
-    ) -> "PretrainedConfig":
-        config_dict, kwargs = cls.get_config_dict(
-            pretrained_model_name_or_path, **kwargs
-        )
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path: Union[str, os.PathLike],
+                        **kwargs) -> "PretrainedConfig":
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path,
+                                                  **kwargs)
 
         # get the qformer config dict if we are loading from Blip2Config
         if config_dict.get("model_type") == "blip-2":
             config_dict = config_dict["qformer_config"]
 
-        if (
-            "model_type" in config_dict
-            and hasattr(cls, "model_type")
-            and config_dict["model_type"] != cls.model_type
-        ):
+        if ("model_type" in config_dict and hasattr(cls, "model_type") and
+                config_dict["model_type"] != cls.model_type):
             logger.warning(
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
@@ -313,13 +300,12 @@ class Blip2Config(PretrainedConfig):
     is_composition = True
 
     def __init__(
-        self,
-        vision_config=None,
-        qformer_config=None,
-        text_config=None,
-        num_query_tokens=32,
-        **kwargs,
-    ):
+            self,
+            vision_config=None,
+            qformer_config=None,
+            text_config=None,
+            num_query_tokens=32,
+            **kwargs, ):
         super().__init__(**kwargs)
 
         if vision_config is None:
@@ -341,7 +327,7 @@ def __init__(
             )
         self.vision_config = vision_config
         self.qformer_config = qformer_config
-        self.text_config = text_config
+        self.text_config = text_config
 
         # self.use_decoder_only_language_model = (
        #     self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -354,12 +340,11 @@ def __init__(
 
     @classmethod
     def from_vision_qformer_text_configs(
-        cls,
-        vision_config: Blip2VisionConfig,
-        qformer_config: Blip2QFormerConfig,
-        text_config: PretrainedConfig,
-        **kwargs,
-    ):
+            cls,
+            vision_config: Blip2VisionConfig,
+            qformer_config: Blip2QFormerConfig,
+            text_config: PretrainedConfig,
+            **kwargs, ):
         r"""
         Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision model, Q-Former and language model
         configurations.
@@ -371,8 +356,7 @@ def from_vision_qformer_text_configs(
             vision_config=vision_config,
            qformer_config=qformer_config,
             text_config=text_config,
-            **kwargs,
-        )
+            **kwargs, )
 
     def to_dict(self):
         """
