From e87af21963244678e4b03c402a77fba91cae6abc Mon Sep 17 00:00:00 2001
From: Pioneer-wxh <2274246074@qq.com>
Date: Fri, 26 Sep 2025 21:43:29 +0800
Subject: [PATCH 1/4] Add DisLoRA fine-tuning support
---
.gitignore | 3 +-
docs/zh/llm/benchmark/rl/README.md | 1 +
llm/config/llama/dislora_argument.json | 37 ++
llm/config/qwen/dislora_argument.json | 36 ++
llm/run_finetune.py | 89 ++++-
llm/tools/merge_dislora_params.py | 290 +++++++++++++++
paddlenlp/peft/__init__.py | 1 +
paddlenlp/peft/dislora/__init__.py | 19 +
paddlenlp/peft/dislora/dislora_config.py | 160 ++++++++
paddlenlp/peft/dislora/dislora_layer.py | 312 ++++++++++++++++
paddlenlp/peft/dislora/dislora_model.py | 446 +++++++++++++++++++++++
paddlenlp/trainer/trainer.py | 15 +-
paddlenlp/trl/model_config.py | 23 +-
paddlenlp/trl/sft_config.py | 4 +
paddlenlp/trl/sft_trainer.py | 67 ++++
paddlenlp/utils/env.py | 3 +
tests/fixtures/llm/dislora.yaml | 78 ++++
tests/llm/test_dislora.py | 82 +++++
tests/peft/test_dislora.py | 232 ++++++++++++
19 files changed, 1893 insertions(+), 5 deletions(-)
create mode 120000 docs/zh/llm/benchmark/rl/README.md
create mode 100644 llm/config/llama/dislora_argument.json
create mode 100644 llm/config/qwen/dislora_argument.json
create mode 100644 llm/tools/merge_dislora_params.py
create mode 100644 paddlenlp/peft/dislora/__init__.py
create mode 100644 paddlenlp/peft/dislora/dislora_config.py
create mode 100644 paddlenlp/peft/dislora/dislora_layer.py
create mode 100644 paddlenlp/peft/dislora/dislora_model.py
create mode 100644 tests/fixtures/llm/dislora.yaml
create mode 100644 tests/llm/test_dislora.py
create mode 100644 tests/peft/test_dislora.py
diff --git a/.gitignore b/.gitignore
index 5f29895be914..47510447a842 100644
--- a/.gitignore
+++ b/.gitignore
@@ -140,4 +140,5 @@ autogen/
#fp8
ops/csrc/fp8/deep_gemm/include/cutlass
ops/csrc/fp8/deep_gemm/include/cute
-.ccls-cache
\ No newline at end of file
+.ccls-cache
+llm/log
diff --git a/docs/zh/llm/benchmark/rl/README.md b/docs/zh/llm/benchmark/rl/README.md
new file mode 120000
index 000000000000..c8ff8b971399
--- /dev/null
+++ b/docs/zh/llm/benchmark/rl/README.md
@@ -0,0 +1 @@
+../../../../../llm/benchmark/rl/README.md
\ No newline at end of file
diff --git a/llm/config/llama/dislora_argument.json b/llm/config/llama/dislora_argument.json
new file mode 100644
index 000000000000..35dc6ed7bcd8
--- /dev/null
+++ b/llm/config/llama/dislora_argument.json
@@ -0,0 +1,37 @@
+{
+    "model_name_or_path": "meta-llama/Meta-Llama-3-8B",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/llama_dislora_ckpts",
+ "dislora": true,
+ "per_device_train_batch_size": 1,
+ "gradient_accumulation_steps": 5,
+ "num_train_epochs": 1,
+ "learning_rate": 2e-05,
+ "lr_scheduler_type": "linear",
+ "warmup_steps": 30,
+ "logging_steps": 1,
+ "evaluation_strategy": "no",
+ "save_strategy": "steps",
+ "save_steps": 500,
+ "src_length": 256,
+ "max_length": 512,
+ "bf16": true,
+ "do_train": true,
+ "do_eval": false,
+ "disable_tqdm": false,
+ "load_best_model_at_end": false,
+ "eval_with_do_generation": false,
+ "recompute": false,
+ "save_total_limit": 5,
+ "fp16_opt_level": "O2",
+ "sharding": "stage3",
+ "zero_padding": false,
+ "use_flash_attention": false,
+ "unified_checkpoint": false,
+ "dislora_rank": 8,
+ "dislora_dropout": 0.05,
+ "target_modules": [".*q_proj.*", ".*v_proj.*", ".*k_proj.*", ".*o_proj.*"],
+ "s_tsd": 8,
+ "ortho_lambda": 1.0,
+ "prefer_small_sigma": true
+}
\ No newline at end of file
diff --git a/llm/config/qwen/dislora_argument.json b/llm/config/qwen/dislora_argument.json
new file mode 100644
index 000000000000..f1383adaa163
--- /dev/null
+++ b/llm/config/qwen/dislora_argument.json
@@ -0,0 +1,36 @@
+{
+ "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/qwen_dislora_ckpts",
+ "dislora": true,
+ "per_device_train_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "num_train_epochs": 1,
+ "learning_rate": 2e-05,
+ "lr_scheduler_type": "linear",
+ "warmup_steps": 30,
+ "logging_steps": 1,
+ "evaluation_strategy": "no",
+ "save_strategy": "steps",
+ "save_steps": 500,
+ "src_length": 256,
+ "max_length": 512,
+ "bf16": true,
+ "do_train": true,
+ "do_eval": false,
+ "disable_tqdm": false,
+ "load_best_model_at_end": false,
+ "eval_with_do_generation": false,
+ "recompute": false,
+ "save_total_limit": 5,
+ "fp16_opt_level": "O2",
+ "sharding": "stage3",
+ "zero_padding": false,
+ "use_flash_attention": false,
+ "unified_checkpoint": false,
+ "dislora_rank": 8,
+ "dislora_dropout": 0.05,
+ "s_tsd": 8,
+ "ortho_lambda": 1.0,
+ "prefer_small_sigma": true
+}
\ No newline at end of file
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index 31427a516f2d..37afa9e4528f 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -31,6 +31,8 @@
)
from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
from paddlenlp.peft import (
+ DisLoRAConfig,
+ DisLoRAModel,
LoKrConfig,
LoKrModel,
LoRAConfig,
@@ -311,6 +313,15 @@ def neft_post_hook(module, input, output):
tokenizer.pad_token_id = tokenizer.eos_token_id
train_ds, dev_ds, test_ds = create_dataset(data_args, training_args)
+
+ train_dataset_size = None
+ if train_ds is not None and model_args.dislora:
+ train_dataset_size = get_dataset_size(train_ds)
+ if train_dataset_size is not None:
+ logger.info(f"Original training dataset size: {train_dataset_size}")
+ else:
+ logger.warning("Unable to determine training dataset size for dynamic dash_flag calculation")
+
# TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later.
if training_args.resume_from_checkpoint is not None and data_args.lazy:
logger.info(
@@ -377,7 +388,9 @@ def neft_post_hook(module, input, output):
if eval_zero_padding and test_ds is not None:
test_ds = intoken_dataset(test_ds, tokenizer=tokenizer, max_length=data_args.max_length)
- model = create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers)
+ model = create_peft_model(
+ model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size
+ )
def compute_metrics_do_generation(eval_preds):
rouge1 = Rouge1()
@@ -441,6 +454,10 @@ def compute_metrics_do_generation(eval_preds):
return_attention_mask=not model_args.flash_mask,
pad_to_multiple_of=data_args.pad_to_multiple_of,
)
+
+ if model_args.dislora and hasattr(model_args, "ortho_lambda"):
+ training_args.dislora_ortho_lambda = model_args.ortho_lambda
+
trainer = SFTTrainer(
model=model,
args=training_args,
@@ -531,7 +548,9 @@ def save_to_aistudio(model_args, training_args, trainer):
)
-def create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers):
+def create_peft_model(
+ model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size
+):
if model_args.prefix_tuning:
if training_args.pipeline_parallel_degree > 1:
raise NotImplementedError("Prefix tuning is not implemented for pipeline parallelism.")
@@ -606,6 +625,53 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
else:
model = LoKrModel.from_pretrained(model=model, lokr_path=model_args.lokr_path)
+ if model_args.dislora:
+ # Calculate dynamic dash_flag based on training configuration
+ if train_dataset_size is not None and training_args.do_train:
+            # Warm-up steps: (len(train_data) * num_epochs) // (per_device_batch_size * gradient_accumulation_steps * dp_degree * 3)
+ effective_batch_size = (
+ training_args.per_device_train_batch_size
+ * training_args.gradient_accumulation_steps
+ * training_args.dataset_world_size # Consider data parallel
+ )
+ calculated_dash_flag = (train_dataset_size * training_args.num_train_epochs) // (effective_batch_size * 3)
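+            # Illustrative arithmetic (example numbers only): 3000 samples, 1 epoch,
+            # per-device batch 1 * gradient accumulation 5 * dp degree 1 = 5
+            # -> dash_flag = (3000 * 1) // (5 * 3) = 200 warm-up steps.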
+
+ # Use calculated value if it's reasonable, otherwise fall back to model_args
+ if calculated_dash_flag > 0:
+ dash_flag = calculated_dash_flag
+ logger.info(
+ f"Calculated dynamic dash_flag: {dash_flag} based on dataset size: {train_dataset_size}, "
+ f"epochs: {training_args.num_train_epochs}, effective batch size: {effective_batch_size}"
+ )
+ else:
+ dash_flag = model_args.dash_flag
+ logger.warning(
+ f"Calculated dash_flag was {calculated_dash_flag}, using model_args.dash_flag: {dash_flag}"
+ )
+ else:
+ dash_flag = getattr(model_args, "dash_flag", 50)
+ if train_dataset_size is None:
+ logger.info(
+ f"Unable to calculate dynamic dash_flag (dataset size unknown), using configured dash_flag: {dash_flag}"
+ )
+ else:
+ logger.info(f"Not in training mode, using configured dash_flag: {dash_flag}")
+        if model_args.dislora_path is None:
+            dislora_config = DisLoRAConfig(
+                target_modules=model_args.target_modules
+                if model_args.target_modules
+                else get_lora_target_modules(model),
+                r=model_args.dislora_rank,
+                dislora_alpha=1.5 * model_args.dislora_rank,
+                dislora_dropout=model_args.dislora_dropout,
+                dtype=dtype,
+                base_model_name_or_path=model_args.model_name_or_path,
+                s_tsd=model_args.s_tsd,
+                dash_flag=dash_flag,  # Use the dynamically calculated dash_flag
+                ortho_lambda=model_args.ortho_lambda,
+            )
+            model = DisLoRAModel(model, dislora_config)
+        else:
+            model = DisLoRAModel.from_pretrained(model=model, dislora_path=model_args.dislora_path)
+
if model_args.reft:
intervention_dtype = dtype
intervention_params = {
@@ -745,5 +811,24 @@ def create_dataset(data_args, training_args):
return train_ds, dev_ds, test_ds
+def get_dataset_size(dataset):
+ """Get the size of a dataset, handling both lazy and regular datasets"""
+ if dataset is None:
+ return None
+
+ try:
+ if hasattr(dataset, "__len__"):
+ return len(dataset)
+ elif hasattr(dataset, "_length"):
+ return dataset._length
+ else:
+ # For lazy datasets, we might need to iterate once to count
+ logger.warning("Unable to determine dataset size directly for lazy loading dataset")
+ return None
+ except Exception as e:
+ logger.warning(f"Error getting dataset size: {e}")
+ return None
+
+
if __name__ == "__main__":
main()
diff --git a/llm/tools/merge_dislora_params.py b/llm/tools/merge_dislora_params.py
new file mode 100644
index 000000000000..f393b3d4971a
--- /dev/null
+++ b/llm/tools/merge_dislora_params.py
@@ -0,0 +1,290 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import paddle
+
+from paddlenlp.peft import DisLoRAConfig, DisLoRAModel
+from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.utils.env import CONFIG_NAME
+
+
+def parse_arguments():
+    """Parse command-line arguments."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name_or_path", default=None, help="The directory of pretrained model.")
+ parser.add_argument("--dislora_path", default="", help="The directory of dislora parameters. Default to None")
+ parser.add_argument(
+ "--merge_dislora_model_path",
+ default="",
+ help="The directory of merged parameters. Default to None",
+ )
+ parser.add_argument("--device", type=str, default="gpu", help="Device")
+ parser.add_argument(
+        "--low_gpu_mem", type=bool, default=True, help="Whether to use low GPU memory. Default to True."
+ )
+ return parser.parse_args()
+
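+# Example invocation (an illustrative sketch; the checkpoint paths are placeholders):
+#   python llm/tools/merge_dislora_params.py \
+#       --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
+#       --dislora_path ./checkpoints/qwen_dislora_ckpts \
+#       --merge_dislora_model_path ./checkpoints/qwen_dislora_merged \
+#       --device gpu
+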
+
+def weight_process(name, dislora_config, state_dict):
+ """
+    Merge weights following the DisLoRA formulation: the final weight is
+    W_prin + W_res + W_TSD. Instead of adding the adapters onto the base model
+    in place, the full weight matrix is reconstructed here.
+    Args:
+        name: Layer name (e.g. "model.layers.0.self_attn.q_proj")
+        dislora_config: DisLoRA configuration
+        state_dict: Model state dictionary
+ """
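+    # Reconstruction performed below (Paddle Linear weights are [in_features, out_features]):
+    #   W_merged = W_prin + scaling * (Vhr @ diag(Sr) @ Ur) + scaling * (Vhtsd @ diag(Stsd) @ Utsd)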
+
+ weight_key = name + ".weight"
+
+ if weight_key not in state_dict:
+ print(f"Warning: {weight_key} not found in state_dict")
+ return
+
+ w_prin = state_dict[weight_key]
+ print(f"Processing layer: {name}")
+ print(f" W_prin shape: {w_prin.shape}")
+
+ scaling = dislora_config.dislora_alpha / dislora_config.r
+
+ final_weight = w_prin.clone()
+
+ ur_key = name + ".Direc_Ur.weight"
+ sr_key = name + ".Direc_Sr"
+ vhr_key = name + ".Direc_Vhr.weight"
+
+ w_res_added = False
+
+ if all(key in state_dict for key in [ur_key, sr_key, vhr_key]):
+
+ direc_ur = state_dict[ur_key] # [r, out_features]
+ direc_sr = state_dict[sr_key] # [r]
+ direc_vhr = state_dict[vhr_key] # [in_features, r]
+
+ s_diag = paddle.diag(direc_sr) # [r, r]
+
+ w_res = direc_vhr @ s_diag @ direc_ur * scaling # [in_features, out_features]
+
+ if w_res.shape != w_prin.shape:
+ print(f" Error: W_res shape {w_res.shape} doesn't match W_prin shape {w_prin.shape}")
+ return
+
+ final_weight += w_res
+ w_res_added = True
+ print(f" ✓ Added W_res with scaling factor: {scaling}")
+ else:
+ print(f" ⚠ W_res components not found for {name}")
+
+ utsd_key = name + ".Direc_Utsd.weight"
+ stsd_key = name + ".Direc_Stsd"
+ vhtsd_key = name + ".Direc_Vhtsd.weight"
+
+ w_tsd_added = False
+ if all(key in state_dict for key in [utsd_key, stsd_key, vhtsd_key]):
+
+ direc_utsd = state_dict[utsd_key] # [s_tsd, out_features]
+ direc_stsd = state_dict[stsd_key] # [s_tsd]
+ direc_vhtsd = state_dict[vhtsd_key] # [in_features, s_tsd]
+
+ if not paddle.all(direc_stsd == 0.0):
+
+ s_diag_tsd = paddle.diag(direc_stsd) # [s_tsd, s_tsd]
+
+ w_tsd = direc_vhtsd @ s_diag_tsd @ direc_utsd * scaling # [in_features, out_features]
+
+ if w_tsd.shape != w_prin.shape:
+ print(f" Error: W_TSD shape {w_tsd.shape} doesn't match W_prin shape {w_prin.shape}")
+ return
+
+ final_weight += w_tsd
+ w_tsd_added = True
+ print(f" ✓ Added W_TSD with scaling factor: {scaling}")
+ else:
+ print(f" ⚠ W_TSD parameters are uninitialized (all zeros) for {name}")
+ else:
+ print(f" ⚠ W_TSD components not found for {name}")
+
+ state_dict[weight_key] = final_weight
+
+ keys_to_remove = []
+ for key in state_dict.keys():
+ if key.startswith(name + ".Direc_") or key == name + ".step":
+ keys_to_remove.append(key)
+
+ for key in keys_to_remove:
+ removed_param = state_dict.pop(key)
+ print(f" ✓ Removed DisLoRA parameter: {key} (shape: {removed_param.shape})")
+
+ components = []
+ if w_res_added:
+ components.append("W_res")
+ if w_tsd_added:
+ components.append("W_TSD")
+
+ if components:
+ print(f" ✓ Successfully merged: W_prin + {' + '.join(components)}")
+ else:
+ print(" ✓ Kept original W_prin (no adaptations found)")
+ print()
+
+
+def merge():
+
+ args = parse_arguments()
+ paddle.set_device(args.device)
+
+ print("Loading DisLoRA configuration...")
+ dislora_config = DisLoRAConfig.from_pretrained(args.dislora_path)
+ if dislora_config.base_model_name_or_path is None:
+ if args.model_name_or_path is None:
+ raise ValueError("We can not find a valid model_name_or_path.")
+ else:
+ dislora_config.base_model_name_or_path = args.model_name_or_path
+
+ print("Loading model configuration...")
+ if os.path.isfile(os.path.join(args.dislora_path, CONFIG_NAME)):
+ config = AutoConfig.from_pretrained(args.dislora_path)
+ elif args.model_name_or_path is not None:
+ config = AutoConfig.from_pretrained(args.model_name_or_path)
+ else:
+ raise ValueError(
+ f"We can not find config.json in dislora_path: {args.dislora_path} or find a valid model_name_or_path."
+ )
+
+ config.dtype = dislora_config.dtype
+
+ if (
+ dislora_config.dtype == "bfloat16"
+ or (
+ hasattr(config, "quantization_config")
+ and hasattr(config.quantization_config, "weight_quantize_algo")
+ and config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+ )
+ ) and args.device == "cpu":
+ raise ValueError("We can not apply bfloat16 or nf4/fp4 dislora merge on cpu.")
+
+ print("Loading base model...")
+ model = AutoModelForCausalLM.from_pretrained(
+ dislora_config.base_model_name_or_path,
+ config=config,
+ low_cpu_mem_usage=args.low_gpu_mem,
+ )
+
+ print("Loading DisLoRA model...")
+ model = DisLoRAModel.from_pretrained(model=model, dislora_path=args.dislora_path, dislora_config=dislora_config)
+
+ model.eval()
+ model_state_dict = model.model.state_dict()
+
+ print(f"Total parameters in state_dict: {len(model_state_dict)}")
+
+ step_keys = [key for key in model_state_dict.keys() if key.endswith(".step")]
+ if step_keys:
+ print(f"Found {len(step_keys)} step parameters in loaded model:")
+ for key in step_keys[:5]:
+ print(f" {key}")
+ if len(step_keys) > 5:
+ print(f" ... and {len(step_keys) - 5} more")
+ else:
+ print("No step parameters found in loaded model")
+ print()
+
+ print("Identifying DisLoRA layers...")
+ dislora_name_set = set()
+ for key in model_state_dict.keys():
+ if any(
+ dislora_param in key
+ for dislora_param in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Utsd", "Direc_Stsd", "Direc_Vhtsd"]
+ ):
+
+ for param_type in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Utsd", "Direc_Stsd", "Direc_Vhtsd"]:
+ if f".{param_type}" in key:
+ layer_name = key.split(f".{param_type}")[0]
+ dislora_name_set.add(layer_name)
+ break
+
+ dislora_name_list = sorted(list(dislora_name_set))
+
+ print(f"Found {len(dislora_name_list)} DisLoRA layers:")
+ for i, name in enumerate(dislora_name_list, 1):
+ print(f" {i:2d}. {name}")
+ print()
+
+ print("Merging DisLoRA parameters...")
+
+ for i, name in enumerate(dislora_name_list, 1):
+ print(f"[{i}/{len(dislora_name_list)}] Processing: {name}")
+ weight_process(name, dislora_config, model_state_dict)
+
+ print("Cleaning up remaining step parameters...")
+ step_keys_to_remove = [key for key in model_state_dict.keys() if key.endswith(".step")]
+ for key in step_keys_to_remove:
+ removed_param = model_state_dict.pop(key)
+ print(f" ✓ Removed step parameter: {key} (shape: {removed_param.shape})")
+
+ if step_keys_to_remove:
+ print(f"✓ Removed {len(step_keys_to_remove)} step parameters")
+ else:
+ print("✓ No step parameters found")
+ print()
+
+ print("Verifying parameter cleanup...")
+ remaining_dislora_params = []
+ remaining_step_params = []
+ for key in model_state_dict.keys():
+ if any(
+ dislora_param in key
+ for dislora_param in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Utsd", "Direc_Stsd", "Direc_Vhtsd"]
+ ):
+ remaining_dislora_params.append(key)
+ if key.endswith(".step"):
+ remaining_step_params.append(key)
+
+ if remaining_dislora_params:
+ print(f"Warning: {len(remaining_dislora_params)} DisLoRA parameters still remain:")
+ for param in remaining_dislora_params:
+ print(f" - {param}")
+ else:
+ print("✓ All DisLoRA parameters successfully removed")
+
+ if remaining_step_params:
+ print(f"Warning: {len(remaining_step_params)} step parameters still remain:")
+ for param in remaining_step_params:
+ print(f" - {param}")
+ else:
+ print("✓ All step parameters successfully removed")
+ print()
+
+ print("Saving merged model...")
+ os.makedirs(args.merge_dislora_model_path, exist_ok=True)
+ model.model.save_pretrained(args.merge_dislora_model_path, state_dict=model_state_dict)
+
+ print("Saving tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(dislora_config.base_model_name_or_path)
+ tokenizer.save_pretrained(args.merge_dislora_model_path)
+
+ print("=" * 80)
+ print("✓ DisLoRA merge completed successfully!")
+ print(f"✓ Merged model saved to: {args.merge_dislora_model_path}")
+ print(f"✓ Processed {len(dislora_name_list)} DisLoRA layers")
+ print("=" * 80)
+
+
+if __name__ == "__main__":
+ merge()
diff --git a/paddlenlp/peft/__init__.py b/paddlenlp/peft/__init__.py
index 85c61ffc793b..331488e56dc0 100644
--- a/paddlenlp/peft/__init__.py
+++ b/paddlenlp/peft/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from .dislora import DisLoRAConfig, DisLoRALinear, DisLoRAModel
from .lokr import LoKrConfig, LoKrModel
from .lora import LoRAAutoConfig, LoRAAutoModel, LoRAConfig, LoRAModel
from .prefix import PrefixConfig, PrefixModelForCausalLM
diff --git a/paddlenlp/peft/dislora/__init__.py b/paddlenlp/peft/dislora/__init__.py
new file mode 100644
index 000000000000..c1bdb6cd810e
--- /dev/null
+++ b/paddlenlp/peft/dislora/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .dislora_config import DisLoRAConfig
+from .dislora_layer import DisLoRALinear
+from .dislora_model import DisLoRAModel
+
+__all__ = ["DisLoRAConfig", "DisLoRAModel", "DisLoRALinear"]
diff --git a/paddlenlp/peft/dislora/dislora_config.py b/paddlenlp/peft/dislora/dislora_config.py
new file mode 100644
index 000000000000..b9ff8b1b36a7
--- /dev/null
+++ b/paddlenlp/peft/dislora/dislora_config.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from dataclasses import asdict, dataclass, field
+from typing import List, Optional, Union
+
+from ...utils.env import DISLORA_CONFIG_NAME
+
+
+@dataclass
+class DisLoRAConfig:
+ """
+ This is the configuration class to store the configuration of a [`DisLoRAModel`].
+ Args:
+ target_modules (`Union[List[str],str]`): The names of the modules to apply DisLoRA to.
+ trainable_modules (`List[str]`): The names of the modules to train when applying DisLoRA.
+ dislora_alpha (`float`): The alpha parameter for DisLoRA scaling.
+ merge_weights (`bool`):
+        Whether to merge the weights of the DisLoRA layers with the base transformer model in `eval` mode.
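+
+    Example (an illustrative sketch; the regex patterns mirror the shipped JSON configs):
+        config = DisLoRAConfig(
+            target_modules=[".*q_proj.*", ".*v_proj.*", ".*k_proj.*", ".*o_proj.*"],
+            r=8,
+            dislora_alpha=12,
+            s_tsd=8,
+            dash_flag=50,
+        )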
+ """
+
+ base_model_name_or_path: Optional[str] = field(
+ default=None, metadata={"help": "The name of the base model to use."}
+ )
+ r: int = field(default=8, metadata={"help": "DisLoRA attention dimension"})
+ target_modules: Optional[Union[List[str], str]] = field(
+ default=None,
+ metadata={
+ "help": "List of module names or regex expression of the module names to replace with DisLoRA."
+ "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
+ },
+ )
+ trainable_modules: Optional[List[str]] = field(
+ default=None,
+ metadata={
+ "help": "List of module names or regex expression of the module names to train when applying with DisLoRA."
+ "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
+ },
+ )
+ dislora_alpha: int = field(default=12, metadata={"help": "DisLoRA alpha"})
+ dislora_dropout: float = field(default=0.0, metadata={"help": "DisLoRA dropout"})
+ merge_weights: bool = field(
+ default=False, metadata={"help": "Merge weights of the original model and the DisLoRA model"}
+ )
+ trainable_bias: Optional[str] = field(
+ default=None, metadata={"help": "Define trainable bias parameters for the DisLoRA model."}
+ )
+
+    tensor_parallel_degree: int = field(default=-1, metadata={"help": "-1 means tensor parallelism is not used"})
+ dtype: Optional[str] = field(default=None, metadata={"help": "The data type of tensor"})
+
+    dash_flag: int = field(
+        default=50,
+        metadata={"help": "The number of warm-up steps before the additional low-rank (TSD) update is introduced"},
+    )
+
+    s_tsd: int = field(
+        default=8,
+        metadata={"help": "The number of top-k singular vectors dynamically selected after the warm-up"},
+    )
+
+    ortho_lambda: float = field(
+        default=1.0,
+        metadata={"help": "The weight of the orthogonal regularization loss"},
+    )
+    prefer_small_sigma: bool = field(
+        default=True,
+        metadata={"help": "Whether to prefer the smallest singular values during top-k selection"},
+    )
+
+ def __post_init__(self):
+
+ if self.target_modules is None:
+ raise ValueError("The target_modules must be specified as a string or a list of strings.")
+        if self.r <= 0:
+            raise ValueError("The rank r of DisLoRA must be greater than 0.")
+        if self.dislora_alpha <= 0:
+            raise ValueError("dislora_alpha must be greater than 0.")
+        if self.r < self.s_tsd:
+            raise ValueError("The rank r of DisLoRA must be greater than or equal to s_tsd (the number of selected top-k singular directions).")
+
+ @property
+ def scaling(self):
+ return self.dislora_alpha / self.r
+
+ @property
+ def __dict__(self):
+ return asdict(self)
+
+ def to_dict(self):
+ return self.__dict__
+
+ def save_pretrained(self, save_directory):
+ r"""
+ This method saves the configuration of your adapter model in a directory.
+ Args:
+ save_directory (`str`):
+ The directory where the configuration will be saved.
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ output_dict = self.__dict__
+ output_dict["scaling"] = self.scaling
+ output_path = os.path.join(save_directory, DISLORA_CONFIG_NAME)
+
+ # save it
+ with open(output_path, "w") as writer:
+ writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ r"""
+ This method loads the configuration of your adapter model from a directory.
+ Args:
+ pretrained_model_name_or_path (`str`):
+ The directory or the hub-id where the configuration is saved.
+ **kwargs:
+ Additional keyword arguments passed along to the child class initialization.
+ """
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, DISLORA_CONFIG_NAME)):
+ config_file = os.path.join(pretrained_model_name_or_path, DISLORA_CONFIG_NAME)
+ else:
+ raise ValueError(f"Can't find dislora_config.json at '{pretrained_model_name_or_path}'")
+
+ loaded_attributes = cls.from_json_file(config_file)
+ loaded_attributes.pop("scaling", None)
+
+ merged_kwargs = {**loaded_attributes, **kwargs}
+ config = cls(**merged_kwargs)
+
+ return config
+
+ @classmethod
+ def from_json_file(cls, path_json_file):
+ r"""
+ Loads a configuration file from a json file.
+ Args:
+ path_json_file (`str`):
+ The path to the json file.
+ """
+ with open(path_json_file, "r") as file:
+ json_object = json.load(file)
+
+ return json_object
diff --git a/paddlenlp/peft/dislora/dislora_layer.py b/paddlenlp/peft/dislora/dislora_layer.py
new file mode 100644
index 000000000000..990d7629816f
--- /dev/null
+++ b/paddlenlp/peft/dislora/dislora_layer.py
@@ -0,0 +1,312 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+
+
+import warnings
+from typing import Union
+
+import paddle
+import paddle.nn as nn
+
+
+class DisLoRALinear(nn.Linear):
+ """
+    Paddle implementation of a Direct Low-Rank Adaptation (DisLoRA) layer.
+    DisLoRA decomposes W into backbone (W_prin) and task-specific (W_res) subspaces via SVD,
+    then identifies task-specific directions (W_TSD) during fine-tuning.
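+
+    A minimal construction sketch (the sizes are illustrative; Paddle Linear weights use the
+    [in_features, out_features] layout):
+        layer = DisLoRALinear(in_features=1024, out_features=1024, r=8, s_tsd=8, dash_flag=50)
+        y = layer(paddle.randn([2, 1024]))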
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 8,
+ dislora_alpha: int = 8,
+ dislora_dropout: float = 0.0,
+ dash_flag: int = 50,
+ s_tsd: int = 8,
+ prefer_small_sigma: bool = True,
+ merge_weights: bool = False,
+ init_lora_weights: Union[bool, str] = True,
+ **kwargs
+ ):
+
+ if r <= 0:
+ raise ValueError(f"`r` must be a positive integer, got {r}")
+ if s_tsd <= 0:
+ raise ValueError(f"`s_tsd` must be a positive integer, got {s_tsd}")
+
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+
+ original_weight = self.weight.clone()
+ original_bias = self.bias.clone() if self.bias is not None else None
+
+ self.base_dtype = original_weight.dtype
+
+ delattr(self, "weight")
+ if hasattr(self, "bias") and self.bias is not None:
+ delattr(self, "bias")
+
+ self.weight = self.create_parameter(
+ shape=[in_features, out_features],
+ default_initializer=nn.initializer.Assign(original_weight),
+ dtype=self.base_dtype,
+ attr=paddle.ParamAttr(trainable=False),
+ )
+
+ if original_bias is not None:
+ self.bias = self.create_parameter(
+ shape=[out_features],
+ default_initializer=nn.initializer.Assign(original_bias),
+ dtype=self.base_dtype,
+ attr=paddle.ParamAttr(trainable=True),
+ )
+ else:
+ self.bias = None
+
+ self.r = r
+ self.dislora_alpha = dislora_alpha
+ self.scaling = dislora_alpha / r
+ self.dislora_dropout = nn.Dropout(p=dislora_dropout) if dislora_dropout > 0.0 else nn.Identity()
+ self.dash_flag = dash_flag
+ self.s_tsd = s_tsd
+ self.prefer_small_sigma = prefer_small_sigma
+ self.merge_weights = merge_weights
+ self.init_lora_weights = init_lora_weights
+
+ self._disable_adapters = False
+ self.merged = False
+
+ self.register_buffer("step", paddle.to_tensor(0, dtype="int64"))
+
+ self.U = None
+ self.S = None
+ self.Vh = None
+
+ self.Direc_Ur = nn.Linear(r, out_features, bias_attr=False)
+ self.Direc_Sr = self.create_parameter(
+ shape=[r], default_initializer=nn.initializer.Constant(0.0), dtype=self.base_dtype
+ )
+ self.Direc_Vhr = nn.Linear(in_features, r, bias_attr=False)
+ self.Direc_Ur.weight.stop_gradient = False
+ self.Direc_Sr.stop_gradient = False
+ self.Direc_Vhr.weight.stop_gradient = False
+
+ self.Direc_Utsd = nn.Linear(s_tsd, out_features, bias_attr=False)
+ self.Direc_Stsd = self.create_parameter(
+ shape=[s_tsd], default_initializer=nn.initializer.Constant(0.0), dtype=self.base_dtype
+ )
+ self.Direc_Vhtsd = nn.Linear(in_features, s_tsd, bias_attr=False)
+
+ self.Direc_Utsd.weight.stop_gradient = True
+ self.Direc_Vhtsd.weight.stop_gradient = True
+
+ self._align_dtypes()
+
+ if init_lora_weights:
+ self._init_lora_weights()
+
+ def _align_dtypes(self):
+ """Ensure that the data types of all parameters are consistent with those of the base layer."""
+ target_dtype = self.base_dtype
+
+ if self.Direc_Ur.weight.dtype != target_dtype:
+ self.Direc_Ur.weight.set_value(self.Direc_Ur.weight.astype(target_dtype))
+ if self.Direc_Vhr.weight.dtype != target_dtype:
+ self.Direc_Vhr.weight.set_value(self.Direc_Vhr.weight.astype(target_dtype))
+ if self.Direc_Utsd.weight.dtype != target_dtype:
+ self.Direc_Utsd.weight.set_value(self.Direc_Utsd.weight.astype(target_dtype))
+ if self.Direc_Vhtsd.weight.dtype != target_dtype:
+ self.Direc_Vhtsd.weight.set_value(self.Direc_Vhtsd.weight.astype(target_dtype))
+ if self.Direc_Sr.dtype != target_dtype:
+ self.Direc_Sr.set_value(self.Direc_Sr.astype(target_dtype))
+ if self.Direc_Stsd.dtype != target_dtype:
+ self.Direc_Stsd.set_value(self.Direc_Stsd.astype(target_dtype))
+
+ def _init_lora_weights(self):
+ """
+        Initialize DisLoRA weights using SVD.
+        Decompose the original weight W into W_prin (frozen backbone) + W_res (trainable residual).
+        Note: the shape of a Linear weight in PaddlePaddle is [in_features, out_features].
+ """
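+        # Sketch of the split performed below: with W^T = U @ diag(S) @ Vh (reduced SVD of the
+        # transposed weight), W_res is rebuilt from the r selected singular triplets (smallest
+        # sigma when prefer_small_sigma=True, scaling already applied), and the frozen backbone
+        # becomes W_prin = W - W_res.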
+ weight_float32 = self.weight.astype("float32")
+
+ weight_transposed = weight_float32.T
+
+ U, S, Vh = paddle.linalg.svd(weight_transposed, full_matrices=False)
+
+ self.U = U.astype(self.base_dtype)
+ self.S = S.astype(self.base_dtype)
+ self.Vh = Vh.astype(self.base_dtype)
+
+ if self.prefer_small_sigma:
+ _, indices = paddle.topk(S, self.r, largest=False)
+ else:
+ _, indices = paddle.topk(S, self.r, largest=True)
+
+ self.Direc_Ur.weight.set_value(U[:, indices].T.astype(self.base_dtype))
+ self.Direc_Sr.set_value(S[indices].astype(self.base_dtype))
+
+ self.Direc_Vhr.weight.set_value(Vh[indices, :].T.astype(self.base_dtype))
+ self.Direc_Ur.weight.stop_gradient = False
+ self.Direc_Sr.stop_gradient = False
+ self.Direc_Vhr.weight.stop_gradient = False
+ self.Direc_Stsd.stop_gradient = False
+
+ S_diag = paddle.diag(self.Direc_Sr) # [r, r]
+ W_res_T = self.Direc_Ur.weight.T @ S_diag @ self.Direc_Vhr.weight.T # [out_features, in_features]
+ W_res = W_res_T.T * self.scaling # [in_features, out_features]
+
+ if W_res.shape != self.weight.shape:
+ raise ValueError(f"Expected W_res shape {self.weight.shape}, but got {W_res.shape}.")
+
+ self.weight.set_value(self.weight - W_res.astype(self.base_dtype))
+ self.weight.stop_gradient = True
+
+ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+ """
+        Forward pass: output = base(x) + W_res branch(x) + W_TSD branch(x)
+        - the base output uses the frozen backbone weight W_prin (super().forward)
+        - the W_res branch goes through the trainable low-rank structure
+        - the W_TSD branch uses the frozen dynamic singular directions (active after the warm-up)
+ """
+ if self._disable_adapters:
+ if self.merged:
+ self.unmerge()
+ return super().forward(x)
+
+ if self.merged:
+ return super().forward(x)
+
+ result = super().forward(x)
+
+ temp = self.dislora_dropout(x)
+ temp = self.Direc_Vhr(temp)
+ temp = temp * self.Direc_Sr
+ temp = self.Direc_Ur(temp)
+ result += temp * self.scaling
+
+ if self.step < self.dash_flag:
+ pass
+ elif self.step == self.dash_flag:
+ self._initialize_dynamic_vectors()
+ else:
+ temp = self.dislora_dropout(x)
+ temp = self.Direc_Vhtsd(temp)
+ temp = temp * self.Direc_Stsd
+ temp = self.Direc_Utsd(temp)
+ result += temp * self.scaling
+
+ if self.training:
+ with paddle.no_grad():
+ self.step += 1
+
+ return result
+
+ def _initialize_dynamic_vectors(self):
+ """
+ After the warm-up steps, initialize the dynamic singular vector W_TSD.
+ Based on the current change of W_res, select the most important s_tsd directions.
+ """
+ with paddle.no_grad():
+
+ S_diag = paddle.diag(self.Direc_Sr) # [r, r]
+ deltaW_T = self.Direc_Ur.weight.T @ S_diag @ self.Direc_Vhr.weight.T # [out_features, in_features]
+
+ delta_sigma = paddle.diag(self.U.T @ deltaW_T @ self.Vh.T)
+
+ top_indices = self.calculate_change_rate(
+ self.S, delta_sigma, self.s_tsd, largest=not self.prefer_small_sigma
+ )
+
+ self.Direc_Utsd.weight.set_value(self.U[:, top_indices].T.astype(self.base_dtype))
+ self.Direc_Stsd.set_value(self.S[top_indices].astype(self.base_dtype))
+ self.Direc_Vhtsd.weight.set_value(self.Vh[top_indices, :].T.astype(self.base_dtype))
+
+ self.Direc_Utsd.weight.stop_gradient = True
+ self.Direc_Vhtsd.weight.stop_gradient = True
+
+ def calculate_change_rate(self, a: paddle.Tensor, b: paddle.Tensor, s: int, largest: bool = True) -> paddle.Tensor:
+ """
+        Calculate the rate of change of the singular values, change_rate = |b| / (|a| + eps),
+        and return the indices of the top-s entries.
+ """
+ with paddle.no_grad():
+
+ change_rate = paddle.abs(b) / (paddle.abs(a) + 1e-8)
+
+ _, top_s_indices = paddle.topk(change_rate, s, largest=largest)
+ return top_s_indices
+
+ def merge(self):
+ """
+ Merge the trainable W_res into the base weights.
+ After merging: base_layer.weight = W_prin + W_res
+ Note: W_TSD remains frozen and does not participate in the merge.
+ """
+ if self.merged:
+ warnings.warn("Already merged. Nothing to do.")
+ return
+
+ if self.r > 0:
+
+ delta_weight = self.get_delta_weight()
+ orig_weights = self.weight.clone()
+ orig_weights += delta_weight
+ self.weight.set_value(orig_weights)
+
+ self.merged = True
+
+ def unmerge(self):
+ """
+ Remove the merging of W_res from the base weights.
+ After the merging is removed: base_layer.weight = W_prin
+ """
+ if not self.merged:
+ warnings.warn("Already unmerged. Nothing to do.")
+ return
+
+ if self.r > 0:
+ delta_weight = self.get_delta_weight()
+ self.weight.set_value(self.weight - delta_weight)
+
+ self.merged = False
+
+ def get_delta_weight(self) -> paddle.Tensor:
+ """
+        Calculate the combined DisLoRA delta weights, made up of two parts:
+        1. W_res = Ur @ diag(Sr) @ Vhr * scaling (transposed)
+        2. W_tsd = Utsd @ diag(Stsd) @ Vhtsd * scaling (transposed)
+        Return the delta weights with shape [in_features, out_features].
+ """
+
+ S_diag_r = paddle.diag(self.Direc_Sr) # [r, r]
+ delta_weight_T = self.Direc_Ur.weight.T @ S_diag_r @ self.Direc_Vhr.weight.T # [out_features, in_features]
+ delta_weight = delta_weight_T.T * self.scaling # [in_features, out_features]
+
+ if not paddle.all(self.Direc_Stsd == 0.0):
+ S_diag_tsd = paddle.diag(self.Direc_Stsd) # [s_tsd, s_tsd]
+ delta_weight_tsd_T = (
+ self.Direc_Utsd.weight.T @ S_diag_tsd @ self.Direc_Vhtsd.weight.T
+ ) # [out_features, in_features]
+ delta_weight += delta_weight_tsd_T.T * self.scaling # [in_features, out_features]
+
+ return delta_weight.astype(self.base_dtype)
+
+ def enable_adapters(self):
+ """Enable the adapter"""
+ self._disable_adapters = False
+
+ def disable_adapters(self):
+        """Disable the adapter"""
+ self._disable_adapters = True
+
+ def __repr__(self) -> str:
+ rep = super().__repr__()
+ return rep
diff --git a/paddlenlp/peft/dislora/dislora_model.py b/paddlenlp/peft/dislora/dislora_model.py
new file mode 100644
index 000000000000..9be94c0e90bc
--- /dev/null
+++ b/paddlenlp/peft/dislora/dislora_model.py
@@ -0,0 +1,446 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import re
+from collections import OrderedDict
+from typing import Dict, Union
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.distributed.fleet.meta_parallel import PipelineLayer
+
+from paddlenlp.transformers import AutoConfig, PretrainedModel
+from paddlenlp.transformers.model_utils import _add_variant, dtype_guard
+from paddlenlp.utils.log import logger
+
+from ...utils.env import DISLORA_WEIGHTS_NAME
+from .dislora_config import DisLoRAConfig
+
+
+def get_dislora_layers():
+ from .dislora_layer import DisLoRALinear
+
+ return {
+ "DisLoRALinear": DisLoRALinear,
+ }
+
+
+dislora_layers = get_dislora_layers()
+DisLoRALinear = dislora_layers["DisLoRALinear"]
+AVAILABLE_LAYERS = [
+ DisLoRALinear,
+]
+
+
+class DisLoRAModel(nn.Layer):
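+    """
+    Wrap a base model and replace every matched nn.Linear target module with DisLoRALinear.
+
+    A minimal usage sketch (the model name is illustrative):
+        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+        config = DisLoRAConfig(
+            base_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",
+            target_modules=[".*q_proj.*", ".*v_proj.*"],
+        )
+        model = DisLoRAModel(model, config)
+        model.print_trainable_parameters()
+    """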
+ restore_layer_map: Dict[nn.Layer, nn.Layer] = {
+ DisLoRALinear: nn.Linear,
+ }
+
+ def __init__(self, model, dislora_config: DisLoRAConfig) -> None:
+ super().__init__()
+ self.model_config = AutoConfig.from_pretrained(dislora_config.base_model_name_or_path)
+ self.quantized = False
+ self.dislora_config = dislora_config
+ self.dislora_split_mapping = {}
+ if self.dislora_config.dtype is None:
+ self.dislora_config.dtype = paddle.get_default_dtype()
+ with dtype_guard(self.dislora_config.dtype):
+ self.model = self.get_dislora_model(model, dislora_config)
+ self.is_pipelinemodel = False
+ if issubclass(type(self.model), PipelineLayer):
+            raise NotImplementedError("DisLoRA does not support pipeline parallelism yet")
+ if dislora_config.tensor_parallel_degree > 1:
+ self.dislora_config.tensor_parallel_degree = -1
+ self.model.config.tensor_parallel_degree = -1
+            raise NotImplementedError("DisLoRA does not support tensor parallelism yet")
+ # currently tensor_parallel_degree should all be set to -1.
+ self.forward = self.model.forward
+
+ logger.info("Mark only dislora and trainable_module as trainable.")
+ self.mark_only_dislora_as_trainable()
+
+ @classmethod
+ def from_pretrained(cls, model, dislora_path, **kwargs):
+ dislora_config = kwargs.pop("dislora_config", None)
+ # init dislora config & dislora model
+ if not isinstance(dislora_config, DisLoRAConfig):
+ dislora_config = DisLoRAConfig.from_pretrained(dislora_path)
+ # define a new variable to conserve original lora_config.tensor_parallel_degree value which will update while initializing lora model
+ dislora_config_tensor_parallel_degree = dislora_config.tensor_parallel_degree
+ dislora_model = cls(model, dislora_config)
+
+ # define dislora weight name
+ dislora_weight_name = DISLORA_WEIGHTS_NAME
+
+ # load and set dislora weight parameter
+ dislora_weight_path = os.path.join(dislora_path, dislora_weight_name)
+ if os.path.exists(dislora_weight_path):
+ # load dislora weight parameter
+ dislora_state_dict = paddle.load(dislora_weight_path, return_numpy=True)
+ logger.info(f"Loading the DisLoRA weights from {dislora_weight_path}")
+
+ if (
+ dislora_config_tensor_parallel_degree > 1
+ and dislora_config_tensor_parallel_degree != model.config.tensor_parallel_degree
+ ):
+ raise NotImplementedError(
+ f"{dislora_config_tensor_parallel_degree} is not equal to {model.config.tensor_parallel_degree}. Please merge DisLoRA weights first."
+ )
+ # set dislora state dict
+ dislora_model.set_state_dict(dislora_state_dict)
+ else:
+ logger.error(f"DisLoRA weights not found under {dislora_path}, creating DisLoRA weights from scratch")
+
+ return dislora_model
+
+ def set_state_dict(self, state_dict):
+ import warnings
+
+ warnings.filterwarnings(
+ action="ignore", message=".*Skip loading for.*", category=Warning, lineno=0, append=False
+ )
+ self.model.set_state_dict(state_dict)
+ logger.info("Load dislora weight successfully")
+
+ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
+ logger.info("save dislora pretrained")
+ save_model_config = kwargs.get("save_model_config", True)
+
+ variant = kwargs.get("variant", None)
+ is_main_process = kwargs.get("is_main_process", paddle.distributed.get_rank() == 0)
+
+ assert not os.path.isfile(
+ save_directory
+ ), f"Saving directory ({save_directory}) should be a directory, not a file"
+ os.makedirs(save_directory, exist_ok=True)
+
+ dislora_config_to_save = DisLoRAConfig(**self.dislora_config.to_dict())
+ trainable_state_dict = self.get_trainable_state_dict()
+
+ # save dislora weight
+ dislora_weight_name = _add_variant(DISLORA_WEIGHTS_NAME, variant)
+ weight_filename = os.path.join(save_directory, dislora_weight_name)
+ paddle.save(trainable_state_dict, weight_filename)
+
+ # save dislora config
+ if is_main_process:
+ dislora_config_to_save.save_pretrained(save_directory)
+ if save_model_config:
+ model_config_to_save = copy.deepcopy(self.model.config)
+ if merge_tensor_parallel:
+ model_config_to_save.tensor_parallel_degree = -1
+ model_config_to_save.save_pretrained(save_directory)
+
+ def _find_and_replace_module(self, model, module_name, dislora_config):
+
+ if any(dislora_keyword in module_name.lower() for dislora_keyword in ["dislora", "direc_"]):
+ logger.debug(f"Skipping {module_name} - appears to be a DisLoRA submodule")
+ return
+
+ try:
+ parent_module = model
+ attribute_chain = module_name.split(".")
+ for name in attribute_chain[:-1]:
+ parent_module = getattr(parent_module, name)
+ module = getattr(parent_module, attribute_chain[-1])
+ except AttributeError as e:
+ logger.error(f"Cannot access module {module_name}: {e}")
+ raise ValueError(f"Cannot access target module {module_name}: {e}")
+
+ if isinstance(module, nn.Linear):
+ logger.debug(f"Converting {module_name} from nn.Linear to DisLoRALinear")
+
+ try:
+ dislora_module = DisLoRALinear(
+ in_features=module.weight.shape[0],
+ out_features=module.weight.shape[1],
+ r=dislora_config.r,
+ dislora_alpha=dislora_config.dislora_alpha,
+ dislora_dropout=dislora_config.dislora_dropout,
+ dash_flag=dislora_config.dash_flag,
+ s_tsd=dislora_config.s_tsd,
+ prefer_small_sigma=dislora_config.prefer_small_sigma,
+ merge_weights=dislora_config.merge_weights,
+ bias_attr=False if module.bias is None else None,
+ init_lora_weights=False,
+ )
+
+ dislora_module.weight.set_value(module.weight)
+ if module.bias is not None:
+ dislora_module.bias.set_value(module.bias)
+
+ dislora_module._init_lora_weights()
+
+ setattr(parent_module, attribute_chain[-1], dislora_module)
+ logger.debug(f"Successfully replaced {module_name}")
+
+ except Exception as e:
+ logger.error(f"Failed to create DisLoRALinear for {module_name}: {e}")
+ raise ValueError(f"Failed to create DisLoRALinear for {module_name}: {e}")
+
+ elif isinstance(module, DisLoRALinear):
+ logger.debug(f"Module {module_name} is already a DisLoRALinear, skipping")
+
+ else:
+
+ module_type = type(module).__name__
+ if any(keyword in module_name.lower() for keyword in ["dislora_dropout", "direc_"]):
+ logger.debug(f"Skipping DisLoRA submodule {module_name} ({module_type})")
+ return
+ else:
+
+ error_msg = f"Target module {module_name} is {module_type}, not nn.Linear. DisLoRA can only replace nn.Linear modules."
+ logger.error(f"Cannot replace {module_name}: expected nn.Linear, got {module_type}")
+ raise ValueError(error_msg)
+
+ def _find_and_restore_module(self, module_name):
+ parent_module = self.model
+ attribute_chain = module_name.split(".")
+ for name in attribute_chain[:-1]:
+ parent_module = getattr(parent_module, name)
+ module = getattr(parent_module, attribute_chain[-1])
+ original_model_class = self.restore_layer_map[module.__class__]
+ original_module = original_model_class(in_features=module.weight.shape[0], out_features=module.weight.shape[1])
+ original_module.weight = module.weight
+
+ if isinstance(module, DisLoRALinear):
+ if not module.merged:
+ complete_weight = module.weight + module.get_delta_weight()
+ original_module.weight.set_value(complete_weight)
+ else:
+ original_module.weight.set_value(module.weight)
+ else:
+ original_module.weight.set_value(module.weight)
+
+ if module.bias is not None:
+ original_module.bias.set_value(module.bias)
+
+ setattr(parent_module, attribute_chain[-1], original_module)
+
+ def get_trainable_state_dict(self):
+ """
+ Obtain the required state dictionary to be saved, including:
+ 1. Trainable parameters (stop_gradient = False)
+ 2. Main weight W_prin (although frozen, must be saved)
+ 3. TSD direction parameters (although frozen, must be saved)
+ 4. QAT-related parameters
+ """
+ trainable_state_dict = OrderedDict()
+ for name, weight in self.model.state_dict().items():
+ # Save trainable parameters and QAT parameters
+ if not weight.stop_gradient or "activation_quanter" in name or "weight_quanter" in name:
+ trainable_state_dict[name] = weight
+            # Save the backbone weight W_prin (frozen, but required to reconstruct the layer)
+            elif ".weight" in name and "Direc_" not in name:
+                trainable_state_dict[name] = weight
+                logger.debug(f"Saving backbone weight: {name}")
+            # Save the frozen TSD direction parameters (Direc_Stsd is trainable and captured above)
+ elif any(tsd_param in name for tsd_param in ["Direc_Utsd", "Direc_Vhtsd"]):
+ trainable_state_dict[name] = weight
+ logger.debug(f"Saving TSD parameter: {name}")
+ # Save the bias parameters (if any)
+ elif "bias" in name and "Direc_" not in name:
+ trainable_state_dict[name] = weight
+ logger.debug(f"Saving bias parameter: {name}")
+
+ return trainable_state_dict
+
+ def print_trainable_parameters(self) -> None:
+ freeze_numel = 0
+ trainable_numel = 0
+ for _, weight in self.model.state_dict().items():
+ if weight.stop_gradient:
+ freeze_numel += np.prod(weight.shape)
+ else:
+ trainable_numel += np.prod(weight.shape)
+ logger.debug(
+            f"Frozen parameters: {freeze_numel:.2e} || Trainable parameters: {trainable_numel:.2e} || Total parameters: {freeze_numel + trainable_numel:.2e} || Trainable: {trainable_numel / (freeze_numel + trainable_numel):.2%}"
+ )
+
+ def mark_only_dislora_as_trainable(self) -> None:
+ """
+ Mark only the parameters related to DisLoRA as trainable, while ensuring that the TSD parameters remain in a frozen state.
+ """
+
+ for full_param_name, weight in self.model.state_dict().items():
+
+ is_dislora_layer = any(
+ re.fullmatch(target_module, full_param_name.rsplit(".", 1)[0])
+ for target_module in self.dislora_config.target_modules
+ )
+
+ if is_dislora_layer:
+ param_name = full_param_name.split(".")[-1]
+
+ if param_name == "weight" and "Direc_" not in full_param_name:
+ weight.stop_gradient = True
+ logger.debug(f"Freezing backbone weight: {full_param_name}")
+
+ elif param_name == "bias" and "Direc_" not in full_param_name:
+ if self.dislora_config.trainable_bias in ["dislora", "all"]:
+ weight.stop_gradient = False
+ logger.debug(f"Setting bias as trainable: {full_param_name}")
+ else:
+ weight.stop_gradient = True
+ logger.debug(f"Freezing bias: {full_param_name}")
+
+ elif any(tsd_param in full_param_name for tsd_param in ["Direc_Utsd", "Direc_Vhtsd"]):
+ weight.stop_gradient = True
+ logger.debug(f"Keeping TSD parameter frozen: {full_param_name}")
+
+ elif any(
+ trainable_param in full_param_name
+ for trainable_param in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Stsd"]
+ ):
+ weight.stop_gradient = False
+ logger.debug(f"Setting DisLoRA parameter as trainable: {full_param_name}")
+
+ else:
+ weight.stop_gradient = True
+ logger.debug(f"Freezing other parameter: {full_param_name}")
+
+ else:
+ param_name = full_param_name.split(".")[-1]
+ if self.dislora_config.trainable_bias == "all" and param_name == "bias":
+ weight.stop_gradient = False
+ logger.debug(f"Setting bias as trainable in non-DisLoRA layer: {full_param_name}")
+ else:
+ weight.stop_gradient = True
+ logger.debug(f"Freezing parameter in non-DisLoRA layer: {full_param_name}")
+
+ if self.dislora_config.trainable_modules is not None:
+ for full_param_name, weight in self.model.state_dict().items():
+ if any(
+ re.fullmatch(trainable_module, full_param_name)
+ for trainable_module in self.dislora_config.trainable_modules
+ ):
+
+ if not any(tsd_param in full_param_name for tsd_param in ["Direc_Utsd", "Direc_Vhtsd"]):
+ weight.stop_gradient = False
+ logger.debug(f"Setting additional trainable module parameter: {full_param_name}")
+ else:
+ logger.warning(
+ f"TSD parameter {full_param_name} matched trainable_modules pattern but kept frozen"
+ )
+
+ def get_dislora_model(self, model: Union[PretrainedModel, nn.Layer], dislora_config: DisLoRAConfig):
+ """
+        Iterate over all base model layers and replace the target modules with DisLoRALinear.
+ """
+ if dislora_config.target_modules is None:
+ return model
+ else:
+ target_modules = dislora_config.target_modules
+
+ target_module_names = []
+
+ existing_dislora_paths = set()
+ for module_name, module in model.named_sublayers():
+ if isinstance(module, DisLoRALinear):
+ existing_dislora_paths.add(module_name)
+
+ for target_module in target_modules:
+ for module_name, module in model.named_sublayers():
+
+ if re.fullmatch(target_module, module_name):
+
+ if not isinstance(module, DisLoRALinear):
+
+ is_submodule = any(
+ module_name.startswith(dislora_path + ".") for dislora_path in existing_dislora_paths
+ )
+
+ if not is_submodule:
+ target_module_names.append(module_name)
+ else:
+ logger.debug(f"Skipping {module_name} - it's a submodule of existing DisLoRA module")
+ else:
+ logger.debug(f"Skipping {module_name} - already a DisLoRA module")
+
+ for module_name in target_module_names:
+ try:
+ self._find_and_replace_module(model, module_name, dislora_config)
+ logger.debug(f"Replaced {module_name} with DisLoRALinear")
+ except ValueError as e:
+ raise e
+ except Exception as e:
+
+ logger.warning(f"Failed to replace {module_name}: {e}")
+
+ return model
+
+ def restore_original_model(self):
+ # make sure W and dislora weights are not merged before we restore the original model
+ for layer_name, layer in self.model.named_sublayers():
+ if isinstance(layer, DisLoRALinear):
+ self._find_and_restore_module(layer_name)
+ return self.model
+
+ def __getattr__(self, name: str):
+ """
+ Forward missing attributes to the wrapped module.
+ """
+ try:
+ return super().__getattr__(name) # defer to nn.Layer's logic
+ except AttributeError:
+ return getattr(self.model, name)
+
+ def train(self):
+ self.training = True
+ self.model.training = True
+ for layer in self.model.sublayers():
+ layer.training = True
+ layer.train()
+
+ def eval(self):
+ self.training = False
+ self.model.training = False
+ for layer in self.model.sublayers():
+ layer.training = False
+ layer.eval()
+
+ def disable_dislora(self):
+ """
+ Disable the DisLoRA adapter
+ """
+ for _, layer in self.model.named_sublayers():
+ if isinstance(layer, DisLoRALinear):
+ layer.disable_adapters()
+
+ def enable_dislora(self):
+ """
+ Enable the DisLoRA adapter
+ """
+ for _, layer in self.model.named_sublayers():
+ if isinstance(layer, DisLoRALinear):
+ layer.enable_adapters()
+
+ def merge(self):
+ for _, layer in self.model.named_sublayers():
+ if any(isinstance(layer, dislora_layer) for dislora_layer in AVAILABLE_LAYERS):
+ layer.merge()
+
+ def unmerge(self):
+ for _, layer in self.model.named_sublayers():
+ if any(isinstance(layer, dislora_layer) for dislora_layer in AVAILABLE_LAYERS):
+ layer.unmerge()
+
+ def get_model_config(
+ self,
+ ):
+ return self.model_config.to_dict()
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index b48679dbf26a..cb7ff1454c96 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -82,7 +82,14 @@
default_data_collator,
init_dataloader_comm_group,
)
-from ..peft import LoKrModel, LoRAModel, PrefixModelForCausalLM, ReFTModel, VeRAModel
+from ..peft import (
+ DisLoRAModel,
+ LoKrModel,
+ LoRAModel,
+ PrefixModelForCausalLM,
+ ReFTModel,
+ VeRAModel,
+)
from ..quantization.quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
@@ -107,6 +114,7 @@
from ..transformers.tokenizer_utils import PretrainedTokenizer
from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
from ..utils.env import (
+ DISLORA_WEIGHTS_NAME,
LOKR_WEIGHTS_NAME,
LORA_WEIGHTS_NAME,
MODEL_META_NAME,
@@ -464,6 +472,7 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
or isinstance(self.model, PrefixModelForCausalLM)
or isinstance(self.model, VeRAModel)
or isinstance(self.model, LoKrModel)
+ or isinstance(self.model, DisLoRAModel)
or isinstance(self.model, ReFTModel)
):
if self.args.unified_checkpoint and "skip_save_model_weight" in self.args.unified_checkpoint_config:
@@ -616,6 +625,8 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None):
weights_file = os.path.join(resume_from_checkpoint, VERA_WEIGHTS_NAME)
elif isinstance(self.model, LoKrModel):
weights_file = os.path.join(resume_from_checkpoint, LOKR_WEIGHTS_NAME)
+ elif isinstance(self.model, DisLoRAModel):
+ weights_file = os.path.join(resume_from_checkpoint, DISLORA_WEIGHTS_NAME)
elif isinstance(self.model, ReFTModel):
self.model.from_pretrained(resume_from_checkpoint, self.model.model)
return
@@ -681,6 +692,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
or isinstance(self.model, PrefixModelForCausalLM)
or isinstance(self.model, VeRAModel)
or isinstance(self.model, LoKrModel)
+ or isinstance(self.model, DisLoRAModel)
or isinstance(self.model, ReFTModel)
):
self._load_from_peft_checkpoint(resume_from_checkpoint)
@@ -2996,6 +3008,7 @@ def _save(
or isinstance(self.model, PrefixModelForCausalLM)
or isinstance(self.model, VeRAModel)
or isinstance(self.model, LoKrModel)
+ or isinstance(self.model, DisLoRAModel)
or isinstance(self.model, ReFTModel)
):
self.model.save_pretrained(
diff --git a/paddlenlp/trl/model_config.py b/paddlenlp/trl/model_config.py
index 2e244d211158..66a3be6a38d6 100644
--- a/paddlenlp/trl/model_config.py
+++ b/paddlenlp/trl/model_config.py
@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional
__all__ = ["ModelConfig"]
@@ -90,6 +90,27 @@ class ModelConfig:
)
lokr_dim: int = field(default=8, metadata={"help": "Lora dimension in LoKr dimension for adapter matrix"})
+ # dislora related parameters
+ dislora: bool = field(default=False, metadata={"help": "Whether to use dislora technique"})
+ dislora_path: str = field(default=None, metadata={"help": "Initialize dislora state dict."})
+ dislora_rank: int = field(default=8, metadata={"help": "DisLoRA attention dimension"})
+ dislora_dropout: float = field(default=0.05, metadata={"help": "DisLoRA dropout"})
+ target_modules: Optional[List[str]] = field(
+ default=None,
+ metadata={"help": "Custom target modules for DisLoRA. If None, will use default modules based on model type."},
+ )
+ dash_flag: int = field(
+ default=50, metadata={"help": "The number of warm-up steps before introducing additional low-rank updates"}
+ )
+ s_tsd: int = field(
+ default=8, metadata={"help": "The number of top-k singular vectors dynamically selected after warm-up"}
+ )
+ ortho_lambda: float = field(default=1.0, metadata={"help": "The weight of the orthogonal regularization loss"})
+ prefer_small_sigma: bool = field(
+ default=True,
+ metadata={"help": "Whether to prioritize the smallest singular value in the top-k selection process"},
+ )
+
# prefix tuning related parameters
prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."})
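
For orientation (not part of the patch), here is a minimal sketch of how these new `ModelConfig` options map onto the PEFT wrapper. The `DisLoRAConfig`/`DisLoRAModel` names and fields mirror `tests/peft/test_dislora.py` later in this patch; anything beyond those tests, such as how `run_finetune.py` assembles the config, is an assumption.

```python
# Illustrative sketch only: field names follow tests/peft/test_dislora.py in this patch;
# the wiring from ModelConfig to DisLoRAConfig is assumed, not copied from run_finetune.py.
from paddlenlp.peft.dislora import DisLoRAConfig, DisLoRAModel
from paddlenlp.transformers import AutoModel

dislora_config = DisLoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # ModelConfig.target_modules
    r=8,                                          # ModelConfig.dislora_rank
    dislora_alpha=12,
    base_model_name_or_path="__internal_testing__/tiny-random-bert",
)
model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
model = DisLoRAModel(model, dislora_config)
model.mark_only_dislora_as_trainable()  # freezes everything except the DisLoRA branches
```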
diff --git a/paddlenlp/trl/sft_config.py b/paddlenlp/trl/sft_config.py
index f759bc68a1aa..be8152e4bb90 100644
--- a/paddlenlp/trl/sft_config.py
+++ b/paddlenlp/trl/sft_config.py
@@ -75,6 +75,10 @@ class SFTConfig(TrainingArguments):
"help": "The ratio parameter for grouping in SSA, controlling the number of tokens considered in each group for sparse attention calculation."
},
)
+ dislora_ortho_lambda: float = field(
+ default=0.0,
+ metadata={"help": "Orthogonal regularization weight for DisLoRA. Set to 1 for Pareto optimization."},
+ )
def __post_init__(self):
super().__post_init__()
diff --git a/paddlenlp/trl/sft_trainer.py b/paddlenlp/trl/sft_trainer.py
index 8466bfc7abba..fd1d6ce6be65 100644
--- a/paddlenlp/trl/sft_trainer.py
+++ b/paddlenlp/trl/sft_trainer.py
@@ -419,3 +419,70 @@ def ptq_loop(
self.prediction_step(model=self.model, inputs=inputs, prediction_loss_only=True, ignore_keys=None)
if max_eval_iters > 0 and step >= max_eval_iters - 1:
break
+
+ def _calc_ortho_loss(self, model):
+ """Calculate the orthogonal constraint loss of DisLoRA"""
+ import paddle
+
+ ortho_loss = 0.0
+ den = 0
+
+ for name, param in model.named_parameters():
+ if "Direc_Ur" in name and "weight" in name:
+ u = param
+ iu = paddle.eye(u.shape[0], dtype=u.dtype)
+ u_loss = paddle.norm(u @ u.T - iu, p="fro")
+ ortho_loss += u_loss
+ den += 1
+
+ elif "Direc_Vhr" in name and "weight" in name:
+ vh = param
+ ivh = paddle.eye(vh.shape[1], dtype=vh.dtype)
+ vh_loss = paddle.norm(vh.T @ vh - ivh, p="fro")
+ ortho_loss += vh_loss
+ den += 1
+
+ if den > 0:
+ return ortho_loss / den
+ else:
+ return None
+
+ def compute_loss(self, model, inputs, return_outputs=False):
+ """Override compute_loss to add DisLoRA orthogonal regularization"""
+ import paddle
+
+ result = super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+ if isinstance(result, tuple):
+ loss = result[0]
+ outputs = result[1] if len(result) > 1 else None
+ else:
+ loss = result
+ outputs = None
+
+ if isinstance(loss, tuple):
+ loss = loss[0]
+
+ if hasattr(self.args, "dislora_ortho_lambda") and self.args.dislora_ortho_lambda > 0:
+ ortho_loss = self._calc_ortho_loss(model)
+
+ if ortho_loss is not None and loss is not None:
+
+ if loss.numel() > 1:
+ loss = loss.mean()
+ if ortho_loss.numel() > 1:
+ ortho_loss = ortho_loss.mean()
+
+ if abs(self.args.dislora_ortho_lambda - 1.0) < 1e-6:
+
+ with paddle.no_grad():
+ ratio = ortho_loss / (loss + 1e-8)
+ alpha_task = paddle.exp(-ratio) / (paddle.exp(-ratio) + paddle.exp(-1 / ratio))
+ alpha_ortho = 1.0 - alpha_task
+
+ loss = alpha_task * loss + alpha_ortho * ortho_loss
+ else:
+
+ loss = loss + self.args.dislora_ortho_lambda * ortho_loss
+
+ return (loss, outputs) if return_outputs else loss
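
As a self-contained illustration of the penalty computed in `_calc_ortho_loss` above (a sketch, not part of the patch): the Frobenius norm of `U @ U.T - I` is zero exactly when `U` has orthonormal rows, so the regularizer pushes the `Direc_Ur` / `Direc_Vhr` factors toward orthogonality.

```python
# Standalone illustration of the orthogonality penalty used above.
import paddle

u = paddle.randn([8, 16])                                  # stand-in for a Direc_Ur-style weight
identity = paddle.eye(u.shape[0], dtype=u.dtype)
print(float(paddle.norm(u @ u.T - identity, p="fro")))     # clearly > 0 for a random matrix

q, _ = paddle.linalg.qr(u.T)                               # columns of q are orthonormal
print(float(paddle.norm(q.T @ q - identity, p="fro")))     # ~0 once the factor is orthonormal
```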
diff --git a/paddlenlp/utils/env.py b/paddlenlp/utils/env.py
index 62503e09a39e..0489d6b6d6cc 100644
--- a/paddlenlp/utils/env.py
+++ b/paddlenlp/utils/env.py
@@ -111,6 +111,9 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:
LOKR_WEIGHTS_NAME = "lokr_model_state.pdparams"
LOKR_CONFIG_NAME = "lokr_config.json"
+DISLORA_WEIGHTS_NAME = "dislora_model_state.pdparams"
+DISLORA_CONFIG_NAME = "dislora_config.json"
+
PAST_KEY_VALUES_FILE_NAME = "pre_caches.npy"
PADDLE_WEIGHTS_NAME = "model_state.pdparams"
diff --git a/tests/fixtures/llm/dislora.yaml b/tests/fixtures/llm/dislora.yaml
new file mode 100644
index 000000000000..15500928b97d
--- /dev/null
+++ b/tests/fixtures/llm/dislora.yaml
@@ -0,0 +1,78 @@
+dislora:
+ base:
+ dataset_name_or_path: "./data"
+ per_device_train_batch_size: 1
+ gradient_accumulation_steps: 5
+ per_device_eval_batch_size: 8
+ eval_accumulation_steps: 16
+ num_train_epochs: 1
+ learning_rate: 2e-05
+ lr_scheduler_type: linear
+ warmup_steps: 30
+ logging_steps: 1
+ evaluation_strategy: "no"
+ save_strategy: "steps"
+ save_steps: 500
+ src_length: 256
+ max_length: 256
+ fp16: true
+ fp16_opt_level: "O2"
+ do_train: true
+ do_eval: false
+ disable_tqdm: false
+ load_best_model_at_end: false
+ eval_with_do_generation: false
+ recompute: false
+ save_total_limit: 5
+ sharding: "stage3"
+ zero_padding: false
+ use_flash_attention: false
+ unified_checkpoint: false
+ tensor_parallel_degree: 1
+ pipeline_parallel_degree: 1
+ dislora: true
+ dislora_rank: 8
+ dislora_dropout: 0.05
+
+ s_tsd: 8
+ ortho_lambda: 1.0
+ prefer_small_sigma: true
+
+ default:
+ llama:
+ model_name_or_path: __internal_testing__/tiny-random-llama
+ chatglm:
+ model_name_or_path: __internal_testing__/tiny-fused-chatglm
+ chatglm2:
+ model_name_or_path: __internal_testing__/tiny-fused-chatglm2
+ bloom:
+ model_name_or_path: __internal_testing__/tiny-fused-bloom
+ qwen:
+ model_name_or_path: __internal_testing__/tiny-fused-qwen
+ qwen2:
+ model_name_or_path: __internal_testing__/tiny-random-qwen2
+ qwen2moe:
+ model_name_or_path: __internal_testing__/tiny-random-qwen2moe
+ baichuan:
+ model_name_or_path: __internal_testing__/tiny-fused-baichuan
+
+inference-predict:
+ default:
+ mode: dynamic
+ max_length: 20
+ batch_size: 2
+ decode_strategy: greedy_search
+ dtype: float16
+
+inference-to-static:
+ default:
+ dtype: float16
+ max_length: 20
+
+inference-infer:
+ default:
+ mode: static
+ dtype: float16
+ batch_size: 2
+ decode_strategy: greedy_search
+ max_length: 20
\ No newline at end of file
diff --git a/tests/llm/test_dislora.py b/tests/llm/test_dislora.py
new file mode 100644
index 000000000000..957962f71ab6
--- /dev/null
+++ b/tests/llm/test_dislora.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import os
+import sys
+import unittest
+
+import paddle
+from parameterized import parameterized_class
+
+from tests.testing_utils import argv_context_guard, load_test_config
+
+from .testing_utils import LLMTest
+
+
+@parameterized_class(
+ ["model_dir"],
+ [
+ ["llama"],
+ ["chatglm"],
+ ["chatglm2"],
+ ["bloom"],
+ ["qwen"],
+ ["baichuan"],
+ ],
+)
+class DisLoRATest(LLMTest, unittest.TestCase):
+ config_path: str = "./tests/fixtures/llm/dislora.yaml"
+ model_dir: str = None
+
+ def setUp(self) -> None:
+ LLMTest.setUp(self)
+
+ self.model_codes_dir = os.path.join(self.root_path, self.model_dir)
+ sys.path.insert(0, self.model_codes_dir)
+
+ def tearDown(self) -> None:
+ LLMTest.tearDown(self)
+ sys.path.remove(self.model_codes_dir)
+
+ def test_dislora(self):
+ self.disable_static()
+ paddle.set_default_dtype("float32")
+
+ dislora_config = load_test_config(self.config_path, "dislora", self.model_dir)
+ dislora_config["output_dir"] = self.output_dir
+ dislora_config["dataset_name_or_path"] = self.data_dir
+
+ with argv_context_guard(dislora_config):
+ from run_finetune import main
+
+ main()
+
+ # merge weights
+ merge_dislora_weights_config = {
+ "dislora_path": dislora_config["output_dir"],
+ "merge_dislora_model_path": dislora_config["output_dir"],
+ "device": "gpu",
+ "low_gpu_mem": True,
+ }
+ with argv_context_guard(merge_dislora_weights_config):
+ from tools.merge_dislora_params import merge
+
+ merge()
+
+ # # TODO(wj-Mcat): disable chatglm2 test temporarily
+ # if self.model_dir not in ["qwen", "baichuan", "chatglm2"]:
+ # self.run_predictor({"inference_model": True})
+
+ self.run_predictor({"inference_model": False})
diff --git a/tests/peft/test_dislora.py b/tests/peft/test_dislora.py
new file mode 100644
index 000000000000..fb5e8db8396e
--- /dev/null
+++ b/tests/peft/test_dislora.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import re
+import unittest
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import paddle
+from parameterized import parameterized
+
+from paddlenlp.peft.dislora import DisLoRAConfig, DisLoRALinear, DisLoRAModel
+from paddlenlp.transformers import AutoModel, BertModel
+
+
+class TestDisLoRALayer(unittest.TestCase):
+ def test_r_raise_exception(self):
+ with self.assertRaises(ValueError):
+ DisLoRALinear(in_features=16, out_features=8, r=0, dislora_alpha=8)
+
+ def test_forward(self):
+ # r=8, dislora_alpha=12 (1.5 * 8)
+ dislora_layer = DisLoRALinear(in_features=16, out_features=8, r=8, dislora_dropout=0.1, dislora_alpha=12)
+ x = paddle.randn([2, 4, 16], "float32")
+ output = dislora_layer(x)
+
+ # Check the trainable DisLoRA parameters (related to W_res)
+ self.assertFalse(dislora_layer.Direc_Ur.weight.stop_gradient)
+ self.assertFalse(dislora_layer.Direc_Vhr.weight.stop_gradient)
+ self.assertFalse(dislora_layer.Direc_Sr.stop_gradient)
+ self.assertFalse(dislora_layer.Direc_Stsd.stop_gradient)
+
+ # Check the frozen TSD parameters
+ self.assertTrue(dislora_layer.Direc_Utsd.weight.stop_gradient)
+ self.assertTrue(dislora_layer.Direc_Vhtsd.weight.stop_gradient)
+
+ # Check the frozen main branch weights W_prin
+ self.assertTrue(dislora_layer.weight.stop_gradient)
+
+ # Check the bias parameters (by default, they should be trainable, but this depends on the configuration)
+ if dislora_layer.bias is not None:
+ self.assertFalse(dislora_layer.bias.stop_gradient)
+
+ self.assertEqual(output.shape, [2, 4, 8])
+
+ def test_train_eval(self):
+ x = paddle.randn([2, 4, 16], "float32")
+
+ dislora_layer = DisLoRALinear(in_features=16, out_features=8, r=8, dislora_alpha=12)
+ dislora_layer.train()
+ train_result = dislora_layer(x)
+ train_weight = copy.deepcopy(dislora_layer.weight)
+ dislora_layer.eval()
+ eval_result = dislora_layer(x)
+ eval_weight = dislora_layer.weight
+ self.assertTrue(paddle.allclose(train_result, eval_result))
+ self.assertTrue(paddle.allclose(train_weight, eval_weight))
+
+ def test_save_load(self):
+ with TemporaryDirectory() as tempdir:
+
+ dislora_layer = DisLoRALinear(in_features=16, out_features=8, r=8, dislora_alpha=12)
+ weights_path = os.path.join(tempdir, "model.pdparams")
+ paddle.save(dislora_layer.state_dict(), weights_path)
+
+ new_dislora_layer = DisLoRALinear(in_features=16, out_features=8, r=8, dislora_alpha=12)
+ state_dict = paddle.load(weights_path)
+ new_dislora_layer.set_dict(state_dict)
+ x = paddle.randn([2, 4, 16], "float32")
+ self.assertTrue(paddle.allclose(new_dislora_layer(x), dislora_layer(x)))
+
+ def test_load_regular_linear(self):
+ with TemporaryDirectory() as tempdir:
+ regular_linear = paddle.nn.Linear(in_features=16, out_features=12)
+ weights_path = os.path.join(tempdir, "model.pdparams")
+ paddle.save(regular_linear.state_dict(), weights_path)
+ state_dict = paddle.load(weights_path)
+ # after loading the dense weights and re-running SVD init, outputs should match the regular linear
+
+ dislora_layer_r8 = DisLoRALinear(
+ in_features=16, out_features=12, r=8, dislora_alpha=12, init_lora_weights=False
+ )
+
+ dislora_layer_r10 = DisLoRALinear(
+ in_features=16, out_features=12, r=10, dislora_alpha=15, init_lora_weights=False
+ )
+
+ # Load regular linear weights first
+ filtered_state_dict = {k: v for k, v in state_dict.items() if k in ["weight", "bias"]}
+ dislora_layer_r8.set_dict(filtered_state_dict)
+ dislora_layer_r10.set_dict(filtered_state_dict)
+
+ # Then perform SVD initialization
+ dislora_layer_r8._init_lora_weights()
+ dislora_layer_r10._init_lora_weights()
+
+ x = paddle.randn([2, 4, 16], "float32")
+
+ diff_r8 = paddle.abs(dislora_layer_r8(x) - regular_linear(x))
+ print(f"R8 - Max diff: {paddle.max(diff_r8).item():.6e}, Mean diff: {paddle.mean(diff_r8).item():.6e}")
+ self.assertTrue(paddle.allclose(dislora_layer_r8(x), regular_linear(x), atol=2e-3))
+ # r=10 should also reconstruct the regular linear output
+ self.assertTrue(paddle.allclose(dislora_layer_r10(x), regular_linear(x), atol=2e-3))
+
+
+class TestDisLoRAModel(unittest.TestCase):
+ def test_dislora_model_restore(self):
+
+ dislora_config = DisLoRAConfig(
+ target_modules=[".*q_proj.*", ".*v_proj.*"],
+ r=8,
+ dislora_alpha=12,
+ base_model_name_or_path="__internal_testing__/tiny-random-bert",
+ )
+ model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
+ input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20]))
+ model.eval()
+ original_results_1 = model(input_ids)
+ dislora_model = DisLoRAModel(model, dislora_config)
+ restored_model = dislora_model.restore_original_model()
+ restored_model.eval()
+ original_results_2 = restored_model(input_ids)
+ self.assertIsNotNone(original_results_1)
+ self.assertIsNotNone(original_results_2)
+ self.assertIsInstance(restored_model, BertModel)
+ self.assertTrue(paddle.allclose(original_results_1[0], original_results_2[0]))
+
+ @parameterized.expand([(None,), ("all",), ("dislora",)])
+ def test_dislora_model_constructor(self, bias):
+
+ dislora_config = DisLoRAConfig(
+ target_modules=[".*q_proj.*", ".*v_proj.*"],
+ r=8,
+ dislora_alpha=12,
+ trainable_bias=bias,
+ base_model_name_or_path="__internal_testing__/tiny-random-bert",
+ )
+ model = AutoModel.from_pretrained(
+ "__internal_testing__/tiny-random-bert", hidden_dropout_prob=0, attention_probs_dropout_prob=0
+ )
+ dislora_model = DisLoRAModel(model, dislora_config)
+ dislora_model.mark_only_dislora_as_trainable()
+ for name, weight in dislora_model.state_dict().items():
+ if any([re.fullmatch(target_module, name) for target_module in dislora_config.target_modules]):
+ if any(
+ [dislora_param in name for dislora_param in ["Direc_Ur", "Direc_Sr", "Direc_Vhr", "Direc_Stsd"]]
+ ):
+ self.assertFalse(weight.stop_gradient)
+ elif any([tsd_param in name for tsd_param in ["Direc_Utsd", "Direc_Vhtsd"]]):
+ self.assertTrue(weight.stop_gradient)
+ elif "bias" in name and bias in ["dislora", "all"]:
+ self.assertFalse(weight.stop_gradient)
+ else:
+ self.assertTrue(weight.stop_gradient)
+ else:
+ if "bias" in name and bias == "all":
+ self.assertFalse(weight.stop_gradient)
+ else:
+ self.assertTrue(weight.stop_gradient)
+
+ input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20]))
+ dislora_model.train()
+ train_forward_results = dislora_model(input_ids)
+ self.assertIsNotNone(train_forward_results)
+ dislora_model.eval()
+ eval_forward_results = dislora_model(input_ids)
+ self.assertIsNotNone(eval_forward_results)
+ self.assertTrue(paddle.allclose(train_forward_results[0], eval_forward_results[0]))
+
+ def test_dislora_model_save_load(self):
+ with TemporaryDirectory() as tempdir:
+ input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20]))
+
+ dislora_config = DisLoRAConfig(
+ target_modules=[".*q_proj.*", ".*v_proj.*"],
+ r=8,
+ dislora_alpha=12,
+ base_model_name_or_path="__internal_testing__/tiny-random-bert",
+ )
+ model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
+ dislora_model = DisLoRAModel(model, dislora_config)
+ dislora_model.eval()
+ original_results = dislora_model(input_ids)
+ dislora_model.save_pretrained(tempdir)
+
+ loaded_dislora_model = DisLoRAModel.from_pretrained(model, tempdir)
+ loaded_dislora_model.eval()
+ loaded_results = loaded_dislora_model(input_ids)
+ self.assertTrue(paddle.allclose(original_results[0], loaded_results[0]))
+
+ config_loaded_dislora_model = DisLoRAModel.from_pretrained(model, tempdir, dislora_config=dislora_config)
+ config_loaded_dislora_model.eval()
+ config_loaded_results = config_loaded_dislora_model(input_ids)
+ self.assertTrue(paddle.allclose(original_results[0], config_loaded_results[0]))
+
+ def test_dislora_module_raise_exception(self):
+
+ dislora_config = DisLoRAConfig(
+ target_modules=[".*norm1.*"],
+ r=8,
+ dislora_alpha=12,
+ base_model_name_or_path="__internal_testing__/tiny-random-bert",
+ )
+ model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
+ with self.assertRaises(ValueError):
+ DisLoRAModel(model, dislora_config)
+
+
+class TestDisLoRAConfig(unittest.TestCase):
+ def test_save_load(self):
+ with TemporaryDirectory() as tempdir:
+ # Set r and dislora_alpha explicitly
+ dislora_config = DisLoRAConfig(target_modules=["test"], r=8, dislora_alpha=12)
+ dislora_config.save_pretrained(tempdir)
+ loaded_dislora_config = DisLoRAConfig.from_pretrained(tempdir)
+ self.assertEqual(dislora_config.r, loaded_dislora_config.r)
+ self.assertEqual(dislora_config.dislora_alpha, loaded_dislora_config.dislora_alpha)
+ self.assertEqual(dislora_config.dash_flag, loaded_dislora_config.dash_flag)
+ self.assertEqual(dislora_config.s_tsd, loaded_dislora_config.s_tsd)
From dcf306dee3676a90bbf2f514f1a7c176f751002e Mon Sep 17 00:00:00 2001
From: Pioneer-wxh <2274246074@qq.com>
Date: Fri, 17 Oct 2025 22:02:34 +0800
Subject: [PATCH 2/4] add DisLoRATrainer
---
llm/run_finetune.py | 35 +++++++-----
paddlenlp/trl/__init__.py | 1 +
paddlenlp/trl/dislora_trainer.py | 94 ++++++++++++++++++++++++++++++++
paddlenlp/trl/sft_trainer.py | 67 -----------------------
4 files changed, 116 insertions(+), 81 deletions(-)
create mode 100644 paddlenlp/trl/dislora_trainer.py
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index 37afa9e4528f..8c664befeaf1 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -70,7 +70,7 @@
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.longlora import replace_llama_attn, set_group_size
-from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
+from paddlenlp.trl import DataConfig, DisLoRATrainer, ModelConfig, SFTConfig, SFTTrainer
from paddlenlp.trl.llm_utils import (
ZeroPaddingIterDatasetCallback,
compute_metrics,
@@ -458,19 +458,26 @@ def compute_metrics_do_generation(eval_preds):
if model_args.dislora and hasattr(model_args, "ortho_lambda"):
training_args.dislora_ortho_lambda = model_args.ortho_lambda
- trainer = SFTTrainer(
- model=model,
- args=training_args,
- train_dataset=train_ds,
- eval_dataset=dev_ds,
- tokenizer=tokenizer,
- compute_metrics=metrics,
- data_collator=data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn),
- do_generation=data_args.eval_with_do_generation,
- callbacks=[ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None,
- gen_args=gen_args,
- data_args=data_args,
- )
+ trainer_kwargs = {
+ "model": model,
+ "args": training_args,
+ "train_dataset": train_ds,
+ "eval_dataset": dev_ds,
+ "tokenizer": tokenizer,
+ "compute_metrics": metrics,
+ "data_collator": data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn),
+ "do_generation": data_args.eval_with_do_generation,
+ "callbacks": [ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None,
+ "gen_args": gen_args,
+ "data_args": data_args,
+ }
+
+ if model_args.dislora:
+ logger.info("Using DisLoRATrainer for training.")
+ trainer = DisLoRATrainer(**trainer_kwargs)
+ else:
+ trainer = SFTTrainer(**trainer_kwargs)
+
trainable_parameters = [
p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)
]
diff --git a/paddlenlp/trl/__init__.py b/paddlenlp/trl/__init__.py
index 258c16a53bb2..a38eaf50d140 100644
--- a/paddlenlp/trl/__init__.py
+++ b/paddlenlp/trl/__init__.py
@@ -14,6 +14,7 @@
from ..transformers.dpo_criterion import AutoDPOCriterion, DPOCriterion
from ..transformers.kto_criterion import KTOCriterion
+from .dislora_trainer import *
from .dpo_auto_trainer import DPOAutoTrainer
from .dpo_trainer import DPOTrainer
from .embedding_trainer import EmbeddingTrainer
diff --git a/paddlenlp/trl/dislora_trainer.py b/paddlenlp/trl/dislora_trainer.py
new file mode 100644
index 000000000000..7a70e9bed125
--- /dev/null
+++ b/paddlenlp/trl/dislora_trainer.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# dislora_trainer.py
+
+import paddle
+
+from .sft_trainer import SFTTrainer
+
+
+class DisLoRATrainer(SFTTrainer):
+ """
+ A specialized SFTTrainer that incorporates DisLoRA's orthogonal constraint loss.
+
+ This trainer extends the base SFTTrainer by overriding the compute_loss method
+ to add an orthogonal regularization term, which is a key component of the DisLoRA
+ method.
+ """
+
+ def _calc_ortho_loss(self, model):
+ """Calculate the orthogonal constraint loss of DisLoRA"""
+
+ ortho_loss = 0.0
+ den = 0
+
+ for name, param in model.named_parameters():
+ if "Direc_Ur" in name and "weight" in name:
+ u = param
+ iu = paddle.eye(u.shape[0], dtype=u.dtype)
+ u_loss = paddle.norm(u @ u.T - iu, p="fro")
+ ortho_loss += u_loss
+ den += 1
+
+ elif "Direc_Vhr" in name and "weight" in name:
+ vh = param
+ ivh = paddle.eye(vh.shape[1], dtype=vh.dtype)
+ vh_loss = paddle.norm(vh.T @ vh - ivh, p="fro")
+ ortho_loss += vh_loss
+ den += 1
+
+ if den > 0:
+ return ortho_loss / den
+ else:
+ return None
+
+ def compute_loss(self, model, inputs, return_outputs=False):
+ """Override compute_loss to add DisLoRA orthogonal regularization"""
+
+ result = super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+ if isinstance(result, tuple):
+ loss = result[0]
+ outputs = result[1] if len(result) > 1 else None
+ else:
+ loss = result
+ outputs = None
+
+ if isinstance(loss, tuple):
+ loss = loss[0]
+
+ if hasattr(self.args, "dislora_ortho_lambda") and self.args.dislora_ortho_lambda > 0:
+ ortho_loss = self._calc_ortho_loss(model)
+
+ if ortho_loss is not None and loss is not None:
+
+ if loss.numel() > 1:
+ loss = loss.mean()
+ if ortho_loss.numel() > 1:
+ ortho_loss = ortho_loss.mean()
+
+ if abs(self.args.dislora_ortho_lambda - 1.0) < 1e-6:
+
+ with paddle.no_grad():
+ ratio = ortho_loss / (loss + 1e-8)
+ alpha_task = paddle.exp(-ratio) / (paddle.exp(-ratio) + paddle.exp(-1 / ratio))
+ alpha_ortho = 1.0 - alpha_task
+
+ loss = alpha_task * loss + alpha_ortho * ortho_loss
+ else:
+
+ loss = loss + self.args.dislora_ortho_lambda * ortho_loss
+
+ return (loss, outputs) if return_outputs else loss
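
To make the `dislora_ortho_lambda == 1` branch above easier to follow, here is a plain-Python sketch of the adaptive weighting (an illustration with scalar stand-ins for the mean-reduced losses, not code from the patch): the weights form a two-way softmax over `-ratio` and `-1/ratio`, so whichever loss is currently larger relative to the other receives the larger weight.

```python
# Plain-Python illustration of the Pareto-style weighting in the lambda == 1 branch above.
import math

def pareto_weights(task_loss: float, ortho_loss: float, eps: float = 1e-8):
    """Return (alpha_task, alpha_ortho) following the formula in DisLoRATrainer.compute_loss."""
    ratio = ortho_loss / (task_loss + eps)
    alpha_task = math.exp(-ratio) / (math.exp(-ratio) + math.exp(-1.0 / ratio))
    return alpha_task, 1.0 - alpha_task

a_task, a_ortho = pareto_weights(task_loss=2.0, ortho_loss=0.5)
print(round(a_task, 3), round(a_ortho, 3))   # ~0.977 / ~0.023: the larger task loss dominates
combined = a_task * 2.0 + a_ortho * 0.5      # the weighted sum that gets backpropagated
```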
diff --git a/paddlenlp/trl/sft_trainer.py b/paddlenlp/trl/sft_trainer.py
index fd1d6ce6be65..8466bfc7abba 100644
--- a/paddlenlp/trl/sft_trainer.py
+++ b/paddlenlp/trl/sft_trainer.py
@@ -419,70 +419,3 @@ def ptq_loop(
self.prediction_step(model=self.model, inputs=inputs, prediction_loss_only=True, ignore_keys=None)
if max_eval_iters > 0 and step >= max_eval_iters - 1:
break
-
- def _calc_ortho_loss(self, model):
- """Calculate the orthogonal constraint loss of DisLoRA"""
- import paddle
-
- ortho_loss = 0.0
- den = 0
-
- for name, param in model.named_parameters():
- if "Direc_Ur" in name and "weight" in name:
- u = param
- iu = paddle.eye(u.shape[0], dtype=u.dtype)
- u_loss = paddle.norm(u @ u.T - iu, p="fro")
- ortho_loss += u_loss
- den += 1
-
- elif "Direc_Vhr" in name and "weight" in name:
- vh = param
- ivh = paddle.eye(vh.shape[1], dtype=vh.dtype)
- vh_loss = paddle.norm(vh.T @ vh - ivh, p="fro")
- ortho_loss += vh_loss
- den += 1
-
- if den > 0:
- return ortho_loss / den
- else:
- return None
-
- def compute_loss(self, model, inputs, return_outputs=False):
- """Override compute_loss to add DisLoRA orthogonal regularization"""
- import paddle
-
- result = super().compute_loss(model, inputs, return_outputs=return_outputs)
-
- if isinstance(result, tuple):
- loss = result[0]
- outputs = result[1] if len(result) > 1 else None
- else:
- loss = result
- outputs = None
-
- if isinstance(loss, tuple):
- loss = loss[0]
-
- if hasattr(self.args, "dislora_ortho_lambda") and self.args.dislora_ortho_lambda > 0:
- ortho_loss = self._calc_ortho_loss(model)
-
- if ortho_loss is not None and loss is not None:
-
- if loss.numel() > 1:
- loss = loss.mean()
- if ortho_loss.numel() > 1:
- ortho_loss = ortho_loss.mean()
-
- if abs(self.args.dislora_ortho_lambda - 1.0) < 1e-6:
-
- with paddle.no_grad():
- ratio = ortho_loss / (loss + 1e-8)
- alpha_task = paddle.exp(-ratio) / (paddle.exp(-ratio) + paddle.exp(-1 / ratio))
- alpha_ortho = 1.0 - alpha_task
-
- loss = alpha_task * loss + alpha_ortho * ortho_loss
- else:
-
- loss = loss + self.args.dislora_ortho_lambda * ortho_loss
-
- return (loss, outputs) if return_outputs else loss
From 97ee1c29c6201f230fab6e73d05a8f3f4ff14867 Mon Sep 17 00:00:00 2001
From: Pioneer-wxh <2274246074@qq.com>
Date: Tue, 21 Oct 2025 17:57:55 +0800
Subject: [PATCH 3/4] "readme file changes"
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.gitignore | 3 +--
docs/zh/llm/alignment/ppo/README.md | 1 -
llm/docs/finetune.md | 30 ++++++++++++++++++++++++++++-
3 files changed, 30 insertions(+), 4 deletions(-)
delete mode 120000 docs/zh/llm/alignment/ppo/README.md
diff --git a/.gitignore b/.gitignore
index 47510447a842..5f29895be914 100644
--- a/.gitignore
+++ b/.gitignore
@@ -140,5 +140,4 @@ autogen/
#fp8
ops/csrc/fp8/deep_gemm/include/cutlass
ops/csrc/fp8/deep_gemm/include/cute
-.ccls-cache
-llm/log
+.ccls-cache
\ No newline at end of file
diff --git a/docs/zh/llm/alignment/ppo/README.md b/docs/zh/llm/alignment/ppo/README.md
deleted file mode 120000
index 6547f8485c37..000000000000
--- a/docs/zh/llm/alignment/ppo/README.md
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../llm/alignment/ppo/README.md
\ No newline at end of file
diff --git a/llm/docs/finetune.md b/llm/docs/finetune.md
index 5cc07822f1b7..94a92776ae3b 100644
--- a/llm/docs/finetune.md
+++ b/llm/docs/finetune.md
@@ -177,7 +177,33 @@ python merge_lokr_params.py \
- `device`: 运行环境,默认为 gpu。
-#### 3.4.4 ReFT
+#### 3.4.5 DisLoRA
+```
+# Single-GPU DisLoRA
+python run_finetune.py ./config/llama/dislora_argument.json
+
+# Multi-GPU DisLoRA (tensor model parallelism is not supported yet)
+python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/dislora_argument.json
+```
+To make subsequent **compression** and **static-graph inference** easier, we provide a DisLoRA parameter merging script that merges the DisLoRA parameters into the backbone model and saves the corresponding weights.
+```
+python merge_dislora_params.py \
+ --model_name_or_path ./base_model \
+ --dislora_path ./checkpoints/dislora_ckpts \
+ --merge_dislora_model_path ./checkpoints/dislora_merge \
+ --device "gpu" \
+ --low_gpu_mem True
+```
+
+