Skip to content

[XPU] optimize Qwen2_vl for xpu #1020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions paddlemix/examples/qwen2_vl/qwen2vl_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sys
import traceback
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Any
from typing import Any, Dict, Optional, Sequence

import numpy as np
import paddle
Expand All @@ -31,6 +31,7 @@
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, set_seed
from paddlenlp.trainer.trainer import Trainer
from paddlenlp.trainer.trainer_utils import get_last_checkpoint
from paddlenlp.transformers.processing_utils import ProcessorMixin
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError

from paddlemix.datasets.internvl_dataset import ConcatDataset, WeightedConcatDataset
Expand All @@ -42,7 +43,6 @@
Qwen2VLImageProcessor,
Qwen2VLProcessor,
)
from paddlenlp.transformers.processing_utils import ProcessorMixin

Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
Expand Down Expand Up @@ -355,7 +355,7 @@ def pure_text_get_item(self, data_item):
attention_mask=attention_mask,
images=[],
)

return ret

def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
Expand Down Expand Up @@ -460,7 +460,7 @@ def __post_init__(self):

def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []

for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
Expand All @@ -470,17 +470,15 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"])

if (
self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
):
if self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0:
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
)

if self.tokenizer.padding_side == "right":
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
Expand Down Expand Up @@ -530,7 +528,6 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
return features



def main():
parser = PdArgumentParser((ModelArguments, DataTrainingArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
Expand Down Expand Up @@ -565,6 +562,16 @@ def main():
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)

if paddle.is_compiled_with_xpu() and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
# It's OK, not use accumulate_steps optimization
pass

# Load model
if "npu" in paddle.get_device():
is_bfloat16_supported = True
Expand Down
30 changes: 26 additions & 4 deletions paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down Expand Up @@ -36,7 +36,28 @@ TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'

meta_path="paddlemix/examples/qwen2_vl/configs/baseline_6data_330k.json"
# meta_path="paddlemix/examples/qwen2_vl/configs/baseline_6data_330k.json"
meta_path="paddlemix/examples/qwen2_vl/configs/demo_chartqa_500.json"
export PYTHONPATH="/path/to/PaddleMIX/ppdiffusers:/path/to/PaddleNLP:${PYTHONPATH}"

### XPU ###
export XPU_CDNN_CLUSTER_PARALLEL=1
export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2
export XPU_PADDLE_FUSE_SHARDING_BUFFER=1
export FLAGS_use_stride_kernel="0"
# export XPU_PADDLE_L3_SIZE=98566144 # 94 MB
# export XBLAS_FC_AUTOTUNE_FILE="/zhangyikun02/PaddleMIX/autotune_qwen2_vl_7b"
export BKCL_TREE_THRESHOLD=0

export XPU_FUSE_RMSNorm=1
export XPU_FUSE_ATTN_QKV=1
export XPU_FUSE_FFN=1
export XPU_FUSE_ROPE=1
# export PRINT_TIMMER=1
# export PROFILER=1

# export XPUAPI_DEBUG=1
# export XPURT_DISPATCH_MODE=PROFILING

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes 1 --nproc_per_node ${GPUS} --rank 0 --ips ${TRAINER_INSTANCES} --run_mode=collective"
${TRAINING_PYTHON} --log_dir ${OUTPUT_DIR}/paddle_distributed_logs \
Expand Down Expand Up @@ -73,6 +94,7 @@ ${TRAINING_PYTHON} --log_dir ${OUTPUT_DIR}/paddle_distributed_logs \
--report_to "visualdl" \
--tensor_parallel_degree=${tensor_parallel_degree} \
--sharding_parallel_degree=${sharding_parallel_degree} \
--sharding_parallel_config "split_param" \
--pipeline_parallel_degree=1 \
--sep_parallel_degree=1 \
--sharding="stage1" \
Expand Down
2 changes: 1 addition & 1 deletion paddlemix/models/qwen2_vl/configuration_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
depth=32,
embed_dim=1280,
hidden_size=3584,
hidden_act="quick_gelu",
hidden_act="gelu",
mlp_ratio=4,
num_heads=16,
in_channels=3,
Expand Down
Loading