74 changes: 74 additions & 0 deletions llm/alignment/rl/gsm8k_processor.py
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8K dataset into JSONL files with "src"/"tgt" fields for RL training
"""

import argparse
import os
import re

import datasets


def extract_solution(solution_str):
    # GSM8K reference answers end with "#### <number>"; capture the number and strip thousands separators.
    solution = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
    assert solution is not None, f"No '#### <answer>' marker found in: {solution_str}"
    return solution.group(1).replace(",", "")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="./gsm8k")

args = parser.parse_args()

data_source = "openai/gsm8k"

dataset = datasets.load_dataset(data_source, "main")

train_dataset = dataset["train"]
test_dataset = dataset["test"]

instruction_following = 'Let\'s think step by step and output the final answer after "####".'

# build the chat-formatted prompt ("src") and the extracted final answer ("tgt") for each example
def make_map_fn(split):
def process_fn(example, idx):
question_raw = "<|im_start|>user\n" + example.pop("question")

system_raw = (
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
)
question = system_raw + question_raw + " " + instruction_following + "<|im_end|>\n<|im_start|>assistant\n"

answer_raw = example.pop("answer")
solution = extract_solution(answer_raw)
data = {
"src": question,
"tgt": solution,
}
return data

return process_fn

train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)

local_dir = args.local_dir

train_dataset.to_json(os.path.join(local_dir, "train.jsonl"), orient="records", lines=True)
test_dataset.to_json(os.path.join(local_dir, "test.jsonl"), orient="records", lines=True)
94 changes: 50 additions & 44 deletions llm/alignment/rl/run_rl.py
@@ -42,6 +42,7 @@
from paddlenlp.transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoTokenizer,
PretrainedConfig,
)
@@ -59,7 +60,7 @@ def process_args(model_args: ModelArgument, data_args: DataArgument, training_ar
if model_args.reward_server is None:
raise ValueError("Please specify reward_server when use_rm_server is true.")
logger.info(f"Use reward server: {model_args.reward_server} for training.")
if training_args.rl_algorithm == "ppo" and model_args.critic_model_name_or_path is None:
if training_args.rl_algorithm in ["ppo", "vapo"] and model_args.critic_model_name_or_path is None:
raise ValueError("Please specify critic_model_name_or_path when use_rm_server is true.")
else:
if model_args.reward_model_name_or_path is None:
@@ -134,7 +135,6 @@ def create_actor_models(
)
if not training_args.autotuner_benchmark:
reference_model.set_state_dict(actor_model.state_dict())

actor_tokenizer = AutoTokenizer.from_pretrained(
model_args.actor_model_name_or_path,
model_max_length=data_args.max_length,
@@ -210,46 +210,43 @@ def create_critic_models(
data_args: DataArgument,
training_args: TrainingArguments,
common_config: Dict,
reward_model,
):
with timers_scope_runtimer("Critic model loading time"):
reward_model_config = reward_model.config
if model_args.critic_model_name_or_path is None:
model_args.critic_model_name_or_path = model_args.reward_model_name_or_path
critic_model = AutoModelForScore.from_config(
reward_model_config,
dtype=training_args.model_dtype,
score_type="critic",
do_normalize=False,
clip_range_value=training_args.clip_range_value,
**common_config,
critic_model_config = AutoConfig.from_pretrained(
model_args.critic_model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=training_args.model_dtype,
recompute=training_args.critic_recompute,
recompute_granularity=model_args.critic_recompute_granularity,
recompute_use_reentrant=training_args.recompute_use_reentrant,
**common_config,
)
LlmMetaConfig.set_llm_config(critic_model_config, training_args)

critic_model_config.max_position_embeddings = data_args.max_length
critic_model_config.use_sparse_head_and_loss_fn = False
critic_model_config.num_labels = 1
critic_model_config.classifier_dropout = 0.0
critic_model_config.hidden_dropout = 0.0
logger.info(f"Loading Critic model with config:\n\t{critic_model_config}\n")

if not training_args.autotuner_benchmark:
critic_model = AutoModelForTokenClassification.from_pretrained(
model_args.critic_model_name_or_path,
config=critic_model_config,
)
if not training_args.autotuner_benchmark:
critic_model.set_state_dict(reward_model.state_dict())
else:
if not training_args.autotuner_benchmark:
critic_model = AutoModelForScore.from_pretrained(
model_args.critic_model_name_or_path,
config=reward_model_config,
score_type="critic",
do_normalize=False,
clip_range_value=training_args.clip_range_value,
**common_config,
)
else:
critic_model = AutoModelForScore.from_config(
reward_model_config,
score_type="critic",
do_normalize=False,
clip_range_value=training_args.clip_range_value,
**common_config,
)
critic_model = AutoModelForTokenClassification.from_config(
critic_model_config,
)

critic_tokenizer = AutoTokenizer.from_pretrained(
model_args.critic_model_name_or_path,
model_max_length=data_args.max_length,
padding_side="left",
tokenizer_alpha=model_args.reward_critic_tokenizer_alpha,
tokenizer_alpha=model_args.critic_tokenizer_alpha,
use_fast=True,
)
if critic_tokenizer.pad_token_id is None:
@@ -261,16 +258,16 @@ def create_critic_models(
if training_args.eval_mode == "single":
config.tensor_parallel_degree = -1
config.tensor_parallel_rank = 0
with timers_scope_runtimer("Reward critic eval model loading time"):
critic_eval_model = AutoModelForScore.from_config(config)
with timers_scope_runtimer("Critic eval model loading time"):
critic_eval_model = AutoModelForTokenClassification.from_config(config)
else:
critic_eval_model = None

return critic_model, critic_eval_model, critic_tokenizer


def create_rl_dataset(data_args, training_args, tokenizer):
requires_label = True if training_args.use_rm_server else False
requires_label = True if training_args.use_rm_server or training_args.use_rule_reward else False
train_ds = RLHFDataset(
dataset_name_or_path=data_args.train_datasets,
tokenizer=tokenizer,
@@ -333,15 +330,16 @@ def main():
actor_model, actor_eval_model, reference_model, actor_tokenizer = create_actor_models(
model_args, data_args, training_args, common_config, reshard_controller
)

if not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
if training_args.use_rule_reward:
reward_model, reward_tokenizer = None, actor_tokenizer
elif not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
reward_model, reward_tokenizer = create_reward_models(model_args, data_args, training_args, common_config)
else:
reward_model, reward_tokenizer = model_args.reward_server, actor_tokenizer

if training_args.rl_algorithm == "ppo":
if training_args.rl_algorithm in ["ppo", "vapo"]:
critic_model, critic_eval_model, critic_tokenizer = create_critic_models(
model_args, data_args, training_args, common_config, reward_model
model_args, data_args, training_args, common_config
)
else:
critic_model, critic_eval_model, critic_tokenizer = None, None, None
@@ -354,16 +352,24 @@ def main():
offload_tensor_to_cpu((actor_eval_model, "freeze_model"))
offload_tensor_to_cpu((reference_model, "freeze_model"))

if training_args.rl_algorithm == "ppo":
offload_tensor_to_cpu((reward_model, "freeze_model"))
if training_args.rl_algorithm in ["ppo", "vapo"]:
if not training_args.use_rm_server and not training_args.use_rule_reward:
offload_tensor_to_cpu((reward_model, "freeze_model"))
if critic_eval_model is not None:
offload_tensor_to_cpu((critic_eval_model, "freeze_model"))

# NOTE(gongenlei): release memory_reserved_size to equal to memory_allocated_size
paddle.device.cuda.empty_cache()

def compute_metrics(eval_preds):
accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
"""
If "use_rm_server" is TRUE, the score ranges from -3 to 3, with 3 being the only correct score (format + result).
If using the "Regularized Matching Function (use_rule_reward=True)" (currently only implemented for the gsm8k dataset), the score ranges from 0 to 1.
"""
if training_args.use_rule_reward:
accuracy = (eval_preds.predictions == 1).astype("float32").mean().item()
else:
accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
return {"accuracy": accuracy}

try:
Expand All @@ -389,7 +395,7 @@ def compute_metrics(eval_preds):
data_collator=partial(
collate_fn,
pad_token_id=actor_tokenizer.pad_token_id,
requires_label=True if training_args.use_rm_server else False,
requires_label=True if training_args.use_rm_server or training_args.use_rule_reward else False,
max_prompt_len=data_args.max_prompt_len if training_args.balance_batch else None,
), # NOTE: enforce prompt padding to max_prompt_len when using balance_batch
compute_metrics=compute_metrics, # TODO: only used for grpo (kk datasets)
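A note on the create_critic_models change above: loading the critic with AutoModelForTokenClassification and num_labels=1 gives the model a per-token scalar head, which is what the PPO/VAPO value function needs. A conceptual sketch of that head in plain paddle (illustrative only, not the repo's actual wiring):

import paddle

# Stand-in for the transformer's last hidden states: [batch, seq_len, hidden_size].
hidden_states = paddle.randn([2, 8, 16])

# A token-classification head with num_labels=1 is a Linear(hidden_size, 1);
# squeezing the last axis yields one value estimate V(s_t) per token.
value_head = paddle.nn.Linear(16, 1)
values = value_head(hidden_states).squeeze(-1)
print(values.shape)  # [2, 8]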
131 changes: 131 additions & 0 deletions llm/config/qwen/ppo_argument.yaml
@@ -0,0 +1,131 @@
# RL algorithms
rl_algorithm: "ppo" # The reinforcement learning algorithm used, supported: "ppo", "grpo", "reinforce_plus_plus"

# models
actor_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the actor model
reward_model_name_or_path: "" # The name or path of the reward model
critic_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the critic model
use_rm_server: false # Whether to use the reward model server
reward_server: "http://127.0.0.1:8731" # The address of the reward model server
use_rule_reward: true # Whether to use the rule-based reward for the GSM8K dataset; if true, use_rm_server must be false

# logging
logging_dir: ppo-logs # Directory for logging
logging_steps: 1 # Number of steps between logging
output_dir: "qwen2.5-1.5b-gsm8k-ppo/checkpoints" # Directory for output ckpts
report_to: "visualdl" # Supported reporting options: "all", "wandb", "tensorboard", "visualdl"(default), "none"
wandb_http_proxy: "http://agent.baidu.com:8188" # HTTP proxy for wandb
run_name: "qwen2.5-1.5b-gsm8k-ppo" # Name of the run

# data
train_datasets: "gsm8k/train.jsonl" # Path to the training dataset
eval_datasets: "gsm8k/test.jsonl" # Path to the evaluation dataset
prompt_key: "src" # Key for the prompt in the dataset
response_key: "tgt" # Key for the response in the dataset
dataloader_drop_last: true # Whether to drop the last incomplete batch in the DataLoader
balance_batch: true # Whether to balance batch size across dataset_world_size
use_remove_padding: true # Whether to remove padding tokens in the input

# distributed training args
tensor_parallel_degree: 2 # Degree of tensor parallelism
sequence_parallel: true # Whether to enable sequence parallelism
sharding_parallel_degree: -1 # Degree of sharding parallelism
sharding: "stage1" # Sharding strategy, e.g., "stage1" or "stage2"
sharding_parallel_config: "enable_release_grads" # Configuration for sharding parallelism
pipeline_parallel_degree: 1 # Degree of pipeline parallelism
virtual_pp_degree: 1 # Degree of virtual pipeline parallelism

# rollout args
max_prompt_len: 1024 # Maximum length of the prompt, exceeding which will be automatically truncated
max_dec_len: 512 # Maximum length of the response
min_dec_len: 32 # Minimum length of the response
top_p: 1.0 # Top-p sampling parameter
temperature: 1.0 # Temperature parameter for sampling
repetition_penalty: 1.0 # Repetition penalty parameter
rollout_max_num_seqs: 1024 # The maximum number of sequences that can be processed in a single inference
rollout_quant_type: "" # Quantization type, e.g., "weight_only_int8"

# training args
do_train: true # Whether to perform training
seed: 42 # Random seed for reproducibility
global_batch_size: 256 # Global batch size for training (rollouts = rollout_n * global_batch_size)
global_gen_batch_size: -1 # Global generation batch size for dynamic sampling
global_mini_batch_size: 64 # Mini-batch size for training, default = (global_batch_size * rollout_n * update_iters) // dataset_world_size
rollout_n: 1 # Number of rollouts, set rollout_n = 1 for 'ppo'
update_iters: 1 # Number of training iterations for rollout samples
per_device_logprob_batch_size: 4 # Log probability batch size per device
per_device_reward_batch_size: 2 # Reward batch size per device
per_device_value_batch_size: 2 # Value batch size per device
per_device_train_batch_size: 2 # Training micro batch size per device
# gradient_accumulation_steps: 4 # Gradient accumulation steps (auto-calculated if not set)
num_train_epochs: 15 # Number of training epochs
max_length: 2048 # Maximum length for training, should be larger than max_prompt_len + max_dec_len
adam_beta1: 0.9 # AdamW optimizer beta1
adam_beta2: 0.999 # AdamW optimizer beta2
adam_epsilon: 1e-8 # AdamW optimizer epsilon
max_grad_norm: 1.0 # Maximum gradient norm for clipping
max_steps: -1 # Maximum number of training steps
save_steps: 300 # Number of steps between model saves
save_strategy: "steps" # Strategy for saving models
ignore_save_lr_and_optim: true # Whether to skip saving the learning rate scheduler and optimizer state
disable_tqdm: true # Whether to disable tqdm progress bar

# actor training args
learning_rate: 1e-6 # Learning rate for training
min_learning_rate: 5e-7 # Minimum learning rate
lr_scheduler_type: "cosine" # Learning rate scheduler type
weight_decay: 1e-2 # Weight decay for the AdamW optimizer
warmup_ratio: 0.2 # Ratio of total training steps used for learning rate warmup

# critic training args
critic_learning_rate: 1e-5 # Learning rate for critic model
critic_min_learning_rate: 5e-6 # Minimum learning rate for critic model
critic_lr_scheduler_type: "cosine" # Learning rate scheduler type for critic model
critic_weight_decay: 1e-2 # Weight decay for the AdamW optimizer of critic model
critic_warmup_ratio: 0.2 # Ratio of total training steps used for learning rate warmup of the critic model

# RL args
kl_coeff: 0.0 # KL coefficient
kl_loss_coeff: 0.001 # KL loss coefficient
pg_loss_coeff: 1.0 # Policy gradient loss coefficient
entropy_coeff: 0.001 # Entropy coefficient
clip_range_ratio: 0.2 # The clipping range for the ratio between the old and new policy (PPO algorithm)
clip_range_ratio_low: 0.2 # Lower bound of the clipping range for the policy ratio
clip_range_ratio_high: 0.2 # Upper bound of the clipping range for the policy ratio
clip_range_score: 10.0 # The clipping range for the output of the score model. The reward is clipped into [-clip_range_score, clip_range_score].
enable_overlong_reward_buffer: false # Whether to enable overlong reward buffer
overlong_reward_buffer: 256 # The length of the overlong reward buffer
overlong_penalty_factor: 1.0 # The penalty factor for overlong reward buffer
clip_range_value: 0.5 # The clipping range for the output of the value model. The value is clipped into [-clip_range_value, clip_range_value].
normalize_reward: false # Whether to normalize reward
normalize_advantage: false # Whether to normalize advantage
dynamic_sampling: false # Whether to use dynamic sampling, introduced in the DAPO algorithm (https://arxiv.org/abs/2503.14476)
max_gen_batches: 2 # Maximum number of generation batches for dynamic sampling
use_fp32_compute: true # Whether to use fp32 to compute log probabilities, rewards, advantages, and losses

# eval args
do_eval: true # Whether to perform evaluation
per_device_eval_batch_size: 1319 # Evaluation batch size per device
evaluation_strategy: "steps" # Evaluation strategy, e.g., "steps"
eval_steps: 32 # Number of steps between evaluations

# device memory optimization args
use_flash_attention: true # Whether to use fused attention operations
use_fused_rms_norm: true # Whether to use fused RMS norm operations, which requires installing fused_ln from slm/model_zoo/gpt-3/external_ops
use_fused_rope: false # Whether to use fused rope operations
use_fused_head_and_loss_fn: true # Whether to use fused head and loss function
use_fused_linear: true # Whether to use fused linear operations (appears to be an unused parameter)
recompute: false # Whether to enable gradient checkpointing for memory optimization
recompute_use_reentrant: false # Whether to use reentrant recompute
recompute_granularity: "full" # Granularity of recompute
bf16: true # Whether to use mixed precision with bfloat16
fp16_opt_level: "O2" # Optimization level for fp16 and bf16 training
amp_master_grad: false # Whether to use float32 weight gradients for master weights in amp opt level='O2'
amp_custom_black_list: ["reduce_sum", "softmax_with_cross_entropy", "c_softmax_with_cross_entropy", "elementwise_div", "sin", "cos"] # Custom black list for amp
amp_custom_white_list: ["lookup_table", "lookup_table_v2", "flash_attn", "matmul", "matmul_v2", "fused_gemm_epilogue"] # Custom white list for amp
offload_level: "freeze_model" # Level of model offloading to pinned memory, supported values: freeze_model, train_model, optimizer
release_grads: true # Whether to release gradients
offload_optim: false # Whether to offload optimizer to pinned memory

# benchmark args
skip_profile_timer: false # Whether to skip profiling time
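Two quick sanity checks implied by the comments in this config (a sketch: dataset_world_size is not set in this YAML, it comes from the distributed launch, so the value below is an assumption):

# Default mini-batch size from the global_mini_batch_size comment above.
global_batch_size, rollout_n, update_iters = 256, 1, 1
dataset_world_size = 4  # assumed: e.g. 8 GPUs with tensor_parallel_degree=2 -> data-parallel degree 4
print((global_batch_size * rollout_n * update_iters) // dataset_world_size)  # 64, matches global_mini_batch_size

# max_length must cover prompt plus response.
max_prompt_len, max_dec_len, max_length = 1024, 512, 2048
assert max_prompt_len + max_dec_len <= max_length  # 1536 <= 2048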