From 9424cfa2ebe413ca265ff5879fe1ce2a81bace73 Mon Sep 17 00:00:00 2001 From: zhangyikun02 <1129622649@qq.com> Date: Thu, 23 Jan 2025 14:45:50 +0800 Subject: [PATCH 1/2] [XPU]optimize Qwen2_vl for xpu --- .../examples/qwen2_vl/qwen2vl_finetune.py | 25 ++- .../qwen2_vl/shell/baseline_7b_bs32_1e8.sh | 26 ++- .../models/qwen2_vl/configuration_qwen2_vl.py | 2 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 196 ++++++++++++------ 4 files changed, 177 insertions(+), 72 deletions(-) diff --git a/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py b/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py index 43a3e976b..402fd9f02 100644 --- a/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py +++ b/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py @@ -20,7 +20,7 @@ import sys import traceback from dataclasses import dataclass, field -from typing import Dict, Optional, Sequence, Any +from typing import Any, Dict, Optional, Sequence import numpy as np import paddle @@ -31,6 +31,7 @@ from paddlenlp.trainer import PdArgumentParser, TrainingArguments, set_seed from paddlenlp.trainer.trainer import Trainer from paddlenlp.trainer.trainer_utils import get_last_checkpoint +from paddlenlp.transformers.processing_utils import ProcessorMixin from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError from paddlemix.datasets.internvl_dataset import ConcatDataset, WeightedConcatDataset @@ -42,7 +43,6 @@ Qwen2VLImageProcessor, Qwen2VLProcessor, ) -from paddlenlp.transformers.processing_utils import ProcessorMixin Image.MAX_IMAGE_PIXELS = None ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -355,7 +355,7 @@ def pure_text_get_item(self, data_item): attention_mask=attention_mask, images=[], ) - + return ret def __getitem__(self, i) -> Dict[str, paddle.Tensor]: @@ -460,7 +460,7 @@ def __post_init__(self): def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]: batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], [] - + for feature in 
features: images = feature.pop("images", None) or [] videos = feature.pop("videos", None) or [] @@ -470,9 +470,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens batch_vidlens.append(len(videos)) batch_input_ids.append(feature["input_ids"]) - if ( - self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0 - ): + if self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0: fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}] fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))] fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor) @@ -480,7 +478,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens fake_input_ids, _ = self.template.mm_plugin.process_token_ids( fake_input_ids, None, fake_images, [], self.tokenizer, self.processor ) - + if self.tokenizer.padding_side == "right": features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids) @@ -530,7 +528,6 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens return features - def main(): parser = PdArgumentParser((ModelArguments, DataTrainingArguments, PreTrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): @@ -565,6 +562,16 @@ def main(): "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
) + if paddle.is_compiled_with_xpu() and training_args.gradient_accumulation_steps > 1: + try: + from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 + + LinearConfig.enable_accumulate_steps_opt() + LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) + except ImportError: + # It's OK; continue without the accumulate_steps optimization + pass + # Load model if "npu" in paddle.get_device(): is_bfloat16_supported = True diff --git a/paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh b/paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh index 76717d960..f53aa8ab3 100644 --- a/paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh +++ b/paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh @@ -1,11 +1,11 @@ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -38,6 +38,25 @@ MASTER='127.0.0.1:8080' meta_path="paddlemix/examples/qwen2_vl/configs/baseline_6data_330k.json" +### XPU ### +export XPU_CDNN_CLUSTER_PARALLEL=1 +export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2 +export XPU_PADDLE_FUSE_SHARDING_BUFFER=1 +export FLAGS_use_stride_kernel="0" +# export XPU_PADDLE_L3_SIZE=98566144 # 94 MB +# export XBLAS_FC_AUTOTUNE_FILE="/zhangyikun02/PaddleMIX/autotune_qwen2_vl_7b" +export BKCL_TREE_THRESHOLD=0 + +export XPU_FUSE_RMSNorm=1 +export XPU_FUSE_ATTN_QKV=1 +export XPU_FUSE_FFN=1 +export XPU_FUSE_ROPE=1 +# export PRINT_TIMMER=1 +# export PROFILER=1 + +# export XPUAPI_DEBUG=1 +# export XPURT_DISPATCH_MODE=PROFILING + TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes 1 --nproc_per_node ${GPUS} --rank 0 --ips ${TRAINER_INSTANCES} --run_mode=collective" ${TRAINING_PYTHON} --log_dir ${OUTPUT_DIR}/paddle_distributed_logs \ paddlemix/examples/qwen2_vl/qwen2vl_finetune.py \ @@ -73,6 +92,7 @@ ${TRAINING_PYTHON} --log_dir ${OUTPUT_DIR}/paddle_distributed_logs \ --report_to "visualdl" \ --tensor_parallel_degree=${tensor_parallel_degree} \ --sharding_parallel_degree=${sharding_parallel_degree} \ + --sharding_parallel_config "split_param" \ --pipeline_parallel_degree=1 \ --sep_parallel_degree=1 \ --sharding="stage1" \ diff --git a/paddlemix/models/qwen2_vl/configuration_qwen2_vl.py b/paddlemix/models/qwen2_vl/configuration_qwen2_vl.py index 1e49ca28c..cceb99304 100644 --- a/paddlemix/models/qwen2_vl/configuration_qwen2_vl.py +++ b/paddlemix/models/qwen2_vl/configuration_qwen2_vl.py @@ -31,7 +31,7 @@ def __init__( depth=32, embed_dim=1280, hidden_size=3584, - hidden_act="quick_gelu", + hidden_act="gelu", mlp_ratio=4, num_heads=16, in_channels=3, diff --git a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py index 0a06f6c11..930135ee7 100644 --- a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py +++ b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py @@ -20,23 
+20,26 @@ """Paddle Qwen2-VL model.""" import math +import os from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union + import paddle +import paddle.distributed.fleet.meta_parallel as mpu import paddle.nn as nn import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddlenlp.transformers import linear_utils from paddlenlp.transformers.configuration_utils import PretrainedConfig +from paddlenlp.transformers.linear_utils import Linear from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast, ModelOutput from paddlenlp.transformers.model_utils import PretrainedModel -from paddlenlp.transformers import linear_utils -from paddlenlp.transformers.linear_utils import Linear -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker -import paddle.distributed.fleet.meta_parallel as mpu - -from paddle import Tensor, nn +# import paddlenlp.transformers.linear_utils as linear_utils +from paddlenlp.utils.tools import get_env_device from paddlemix.models.flash_attn_utils import ( create_attention_module, @@ -48,6 +51,15 @@ from .bert_padding import index_first_axis, pad_input, unpad_input from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig +if get_env_device() == "xpu": + from paddle_xpu.layers.linear_utils.Linear import xpu_matmul +else: + xpu_matmul = None +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + logger = logging.get_logger(__name__) flash_attn_func, flash_attn_varlen_func = has_flash_attn_func() @@ -66,6 +78,7 @@ def get_triangle_upper_mask(x, mask=None): mask.stop_gradient = True return mask + def parallel_matmul(x: Tensor, y: Tensor, 
transpose_y=True, tensor_parallel_output=True): is_fleet_init = True tensor_parallel_degree = 1 @@ -82,7 +95,7 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_outp y_is_distributed = tensor_parallel_degree > 1 if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: - + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) @@ -94,7 +107,7 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_outp else: logits = paddle.matmul(x, y, transpose_y=transpose_y) return logits - + def _compute_default_rope_parameters( config: Optional[PretrainedConfig] = None, @@ -307,6 +320,7 @@ def _dynamic_frequency_update(self, position_ids, device): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset self.inv_freq = self.original_inv_freq self.max_seq_len_cached = self.original_max_seq_len + @paddle.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: @@ -410,6 +424,31 @@ def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) -> return output +def fused_rotary_pos_emb_vision( + q: paddle.Tensor, k: paddle.Tensor, freqs: paddle.Tensor +) -> tuple[paddle.Tensor, paddle.Tensor]: + orig_dtype = q.dtype + + with paddle.amp.auto_cast(False): + q = q.astype(dtype="float32") + k = k.astype(dtype="float32") + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32") + sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32") + if fused_rotary_position_embedding is not None: + output_q, output_k, _ = fused_rotary_position_embedding( + q, k, sin=sin, cos=cos, use_neox_rotary_style=False + ) + else: + output_q = q * cos 
+ rotate_half(q) * sin + output_k = k * cos + rotate_half(k) * sin + + output_q = paddle.cast(output_q, orig_dtype).squeeze(axis=0) + output_k = paddle.cast(output_k, orig_dtype).squeeze(axis=0) + return output_q, output_k + + class VisionRotaryEmbedding(nn.Layer): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() @@ -457,9 +496,9 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> N self.hidden_size = context_dim * (spatial_merge_size**2) self.ln_q = nn.LayerNorm(context_dim, epsilon=1e-6) self.mlp = nn.Sequential( - nn.Linear(self.hidden_size, self.hidden_size), + linear_utils.Linear(self.hidden_size, self.hidden_size), nn.GELU(), - nn.Linear(self.hidden_size, dim), + linear_utils.Linear(self.hidden_size, dim), ) def forward(self, x: paddle.Tensor) -> paddle.Tensor: @@ -470,9 +509,9 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor: class VisionMlp(nn.Layer): def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: super().__init__() - self.fc1 = nn.Linear(dim, hidden_dim) + self.fc1 = linear_utils.Linear(dim, hidden_dim) self.act = ACT2FN[hidden_act] - self.fc2 = nn.Linear(hidden_dim, dim) + self.fc2 = linear_utils.Linear(hidden_dim, dim) def forward(self, x) -> paddle.Tensor: return self.fc2(self.act(self.fc1(x))) @@ -482,8 +521,8 @@ class VisionAttention(nn.Layer): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias_attr=True) - self.proj = nn.Linear(dim, dim) + self.qkv = linear_utils.Linear(dim, dim * 3, bias_attr=True) + self.proj = linear_utils.Linear(dim, dim) self.head_dim = dim // num_heads # must added def forward( @@ -521,8 +560,8 @@ class VisionFlashAttention2(nn.Layer): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias_attr=True) - self.proj = nn.Linear(dim, dim) + self.qkv = 
linear_utils.Linear(dim, dim * 3, bias_attr=True) + self.proj = linear_utils.Linear(dim, dim) self.head_dim = dim // num_heads # must added def forward( @@ -531,8 +570,11 @@ def forward( seq_length = tuple(hidden_states.shape)[0] qkv = self.qkv(hidden_states).reshape([seq_length, 3, self.num_heads, -1]).transpose(perm=[1, 0, 2, 3]) q, k, v = qkv.unbind(axis=0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) + if paddle.is_compiled_with_xpu() and os.getenv("XPU_FUSE_ROPE"): + q, k = fused_rotary_pos_emb_vision(q.unsqueeze(axis=0), k.unsqueeze(axis=0), rotary_pos_emb) + else: + q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) if _IS_NPU: attn_output = paddle.nn.functional.flash_attention_npu( # TODO: flash_attn_unpadded @@ -650,6 +692,15 @@ def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): + if paddle.is_compiled_with_xpu() and os.getenv("XPU_FUSE_RMSNorm"): + try: + import paddle_xpu_nn # noqa: F821 + + return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + except ImportError: + raise NotImplementedError( + f"Implementation of fused_rms_norm is not available on xpu. 
Please install paddle_xpu to use this feature" + ) if paddle.in_dynamic_mode(): with paddle.amp.auto_cast(False): variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) @@ -672,14 +723,12 @@ def __init__(self, config): self.fuse_attention_ffn = config.fuse_attention_ffn self.tensor_parallel_degree = config.tensor_parallel_degree - # else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear - if config.tensor_parallel_degree > 1: - + self.gate_proj = ColumnParallelLinear( self.hidden_size, self.intermediate_size, @@ -699,20 +748,28 @@ def __init__(self, config): has_bias=False, ) else: - self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) # w1 - self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) # w3 - self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) # w2 - + if paddle.is_compiled_with_xpu() and os.getenv("XPU_FUSE_FFN"): + self.gate_up_fused_proj = linear_utils.Linear( + self.hidden_size, self.intermediate_size * 2, bias_attr=False + ) + else: + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) # w1 + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) # w3 + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) # w2 + self.act_fn = ACT2FN[config.hidden_act] - self.fuse_swiglu = False + self.fuse_swiglu = False def forward(self, x): - x, y = self.gate_proj(x), self.up_proj(x) - if self.fuse_swiglu: - x = self.act_fn(x, y) + if paddle.is_compiled_with_xpu() and os.getenv("XPU_FUSE_FFN"): + x = self.gate_up_fused_proj(x) + x = paddle.incubate.nn.functional.swiglu(x) else: - x = self.act_fn(x) * y - + x, y = self.gate_proj(x), self.up_proj(x) + if self.fuse_swiglu: + x = self.act_fn(x, y) + else: + x = self.act_fn(x) * y return self.down_proj(x) @@ -768,23 +825,28 @@ def __init__(self, config: Qwen2VLConfig, 
layer_idx: Optional[int] = None): self.num_key_value_heads % config.tensor_parallel_degree == 0 ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree - + ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear if config.tensor_parallel_degree > 1: - self.q_proj = ColumnParallelLinear( - self.hidden_size, self.hidden_size, has_bias=True, gather_output=False - ) + self.q_proj = ColumnParallelLinear(self.hidden_size, self.hidden_size, has_bias=True, gather_output=False) self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip self.o_proj = RowParallelLinear(self.hidden_size, self.hidden_size, has_bias=False, input_is_parallel=True) else: - self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True) - self.k_proj = Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True) - self.v_proj = Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True) + if paddle.is_compiled_with_xpu() and os.getenv("XPU_FUSE_ATTN_QKV"): + self.qkv_proj = linear_utils.Linear( + self.hidden_size, + (self.num_heads * self.head_dim + self.num_key_value_heads * self.head_dim * 2), + bias_attr=True, + ) + else: + self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True) + self.k_proj = Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True) + self.v_proj = Linear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=True) self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False) - + self.rotary_emb = 
Qwen2VLRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, @@ -812,14 +874,12 @@ def forward( query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - - + target_query_shape = [0, 0, self.num_heads, self.head_dim] target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] query_states = query_states.reshape(shape=target_query_shape) key_states = key_states.reshape(shape=target_key_value_shape) value_states = value_states.reshape(shape=target_key_value_shape) - new_perm = [0, 2, 1, 3] query_states = query_states.transpose(new_perm) @@ -898,22 +958,33 @@ def forward( ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: bsz, q_len, _ = tuple(hidden_states.shape) - try: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - except: - hidden_states = hidden_states.astype("bfloat16") - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if paddle.is_compiled_with_xpu() and os.getenv("XPU_FUSE_ATTN_QKV"): + mix_layer = self.qkv_proj(hidden_states) + query_states, key_states, value_states = paddle.split( + mix_layer, + num_or_sections=[ + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ], + axis=-1, + ) + else: + try: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + except: + hidden_states = hidden_states.astype("bfloat16") + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) target_query_shape = [0, 0, self.num_heads, self.head_dim] target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] query_states = 
query_states.reshape(shape=target_query_shape) key_states = key_states.reshape(shape=target_key_value_shape) value_states = value_states.reshape(shape=target_key_value_shape) - new_perm = [0, 2, 1, 3] # [1, 3599, 1536] [bsz, q_len, self.num_heads * self.head_dim] @@ -1189,7 +1260,7 @@ class Qwen2VLPreTrainedModel(PretrainedModel): def _init_weights(self, layer): std = 0.2 - if isinstance(layer, (nn.Linear, nn.Conv3D)): + if isinstance(layer, (linear_utils.Linear, nn.Conv3D)): nn.initializer.Normal(mean=0.0, std=std)(layer.weight) if layer.bias is not None: nn.initializer.Constant(0.0)(layer.bias) @@ -1488,6 +1559,10 @@ def __init__(self, config, embedding_weights=None, transpose_y=False): if self.weight.is_distributed: # for tie_word_embeddings self.weight.split_axis = 0 if self.transpose_y else 1 + if get_env_device() == "xpu": + self.matmul = xpu_matmul() + else: + self.matmul = None def forward(self, hidden_states, tensor_parallel_output=None): if tensor_parallel_output is None: @@ -1497,9 +1572,12 @@ def forward(self, hidden_states, tensor_parallel_output=None): if self.weight.dtype != hidden_states.dtype: hidden_states = paddle.cast(hidden_states, self.weight.dtype) - logits = parallel_matmul( - hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output - ) + if get_env_device() == "xpu": + logits = self.matmul(hidden_states, self.weight, transpose_y=self.transpose_y, training=self.training) + else: + logits = parallel_matmul( + hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output + ) return logits @@ -1912,7 +1990,7 @@ def forward( tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) - + logits = paddle.cast(logits, "float32") loss = None From c5ae3197a168735cc9bdeba0383c1c494bea9262 Mon Sep 17 00:00:00 2001 From: zhangyikun02 
<1129622649@qq.com> Date: Thu, 23 Jan 2025 14:52:06 +0800 Subject: [PATCH 2/2] [xpu] add script to convert safetensors checkpoint files for fused ops --- .../qwen2_vl/shell/baseline_7b_bs32_1e8.sh | 4 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 2 +- scripts/conver_safetensor.py | 60 +++++++++++++++++++ 3 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 scripts/conver_safetensor.py diff --git a/paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh b/paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh index f53aa8ab3..b430bb6dc 100644 --- a/paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh +++ b/paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh @@ -36,7 +36,9 @@ TRAINING_MODEL_RESUME="None" TRAINER_INSTANCES='127.0.0.1' MASTER='127.0.0.1:8080' -meta_path="paddlemix/examples/qwen2_vl/configs/baseline_6data_330k.json" +# meta_path="paddlemix/examples/qwen2_vl/configs/baseline_6data_330k.json" +meta_path="paddlemix/examples/qwen2_vl/configs/demo_chartqa_500.json" +export PYTHONPATH="/path/to/PaddleMIX/ppdiffusers:/path/to/PaddleNLP:${PYTHONPATH}" ### XPU ### export XPU_CDNN_CLUSTER_PARALLEL=1 diff --git a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py index 930135ee7..d51992421 100644 --- a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py +++ b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py @@ -52,7 +52,7 @@ from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig if get_env_device() == "xpu": - from paddle_xpu.layers.linear_utils.Linear import xpu_matmul + from paddle_xpu.layers.nn.linear import xpu_matmul else: xpu_matmul = None try: from paddle.incubate.nn.functional import fused_rotary_position_embedding except ImportError: fused_rotary_position_embedding = None diff --git a/scripts/conver_safetensor.py b/scripts/conver_safetensor.py new file mode 100644 index 000000000..885adc615 --- /dev/null +++ b/scripts/conver_safetensor.py @@ -0,0 +1,60 @@ +import paddle +import safetensors.paddle +import numpy as np + +new_safetensors = {} +metadata = {"total_size": "16582751232",} +layer11_gate_weight = None
+layer11_up_weight = None +for idx in range(1, 6): + file_path = "/path/to/Qwen2-VL-7B-Instruct/model-0000" + str(idx) + "-of-00005.safetensors" + new_file_path="/new_path/to/Qwen2-VL-7B-Instruct-fuse_qkv/model-0000" + str(idx) + "-of-00005.safetensors" + theta = ( + safetensors.paddle.load_file(file_path) + ) + for key, val in theta.items(): + # print("key = ", key, " val.shape = ", val.shape) + if len(key.split('.')) == 6 and key.split('.', 4)[4] == 'q_proj.weight': + q_weight = val + k_weight = theta[key.replace('q_proj', 'k_proj')] + v_weight = theta[key.replace('q_proj', 'v_proj')] + qkv_weight = paddle.concat([q_weight, k_weight, v_weight], axis=-1) + # print(qkv_weight.shape) + new_safetensors[key.replace('q_proj', 'qkv_proj')] = qkv_weight + elif len(key.split('.')) == 6 and key.split('.', 4)[4] == 'q_proj.bias': + q_bias = val + k_bias = theta[key.replace('q_proj', 'k_proj')] + v_bias = theta[key.replace('q_proj', 'v_proj')] + qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1) + # print(qkv_bias.shape) + new_safetensors[key.replace('q_proj', 'qkv_proj')] = qkv_bias + elif len(key.split('.')) == 6 and key.split('.', 4)[4] == 'k_proj.weight': + continue + elif len(key.split('.')) == 6 and key.split('.', 4)[4] == 'k_proj.bias': + continue + elif len(key.split('.')) == 6 and key.split('.', 4)[4] == 'v_proj.weight': + continue + elif len(key.split('.')) == 6 and key.split('.', 4)[4] == 'v_proj.bias': + continue + elif len(key.split('.')) == 6 and key.split('.', 2)[2] == '11.mlp.up_proj.weight': + layer11_up_weight = val + elif len(key.split('.')) == 6 and key.split('.', 2)[2] == '11.mlp.gate_proj.weight': + layer11_gate_weight = val + gate_up_weight = paddle.concat([layer11_gate_weight, layer11_up_weight], axis=-1) + new_safetensors[key.replace('gate_proj', 'gate_up_fused_proj')] = gate_up_weight + elif len(key.split('.')) == 6 and key.split('.', 4)[4] == 'gate_proj.weight': + gate_weight = val + up_weight = theta[key.replace('gate_proj', 'up_proj')] + 
gate_up_weight = paddle.concat([gate_weight, up_weight], axis=-1) + new_safetensors[key.replace('gate_proj', 'gate_up_fused_proj')] = gate_up_weight + elif len(key.split('.')) == 6 and key.split('.', 4)[4] == 'up_proj.weight': + continue + else: + new_safetensors[key] = val + # save new safetensors + safetensors.paddle.save_file(new_safetensors, new_file_path, metadata=metadata) + print("save new safetensors for ", new_file_path) + new_safetensors.clear() + + +