14 changes: 7 additions & 7 deletions README.md
@@ -8,9 +8,9 @@ Check out our GRM series below, which are evaluated on [reward-bench](https://hu



| Model | Average | Chat | Chat Hard | Safety | Reasoning |
| Model | Average | Chat | Chat Hard | Safety | Reasoning |
|:-------------------------:|:-------------:|:---------:|:---------:|:--------:|:-----------:|
|[GRM_Llama3.1_8B_rewardmodel-ft](https://huggingface.co/Ray2333/GRM_Llama3.1_8B_rewardmodel-ft)**(8B)**| 92.6|95.0 |87.7|91.4|96.4|
|[GRM_Llama3.1_8B_rewardmodel-ft](https://huggingface.co/Ray2333/GRM_Llama3.1_8B_rewardmodel-ft)**(8B)**| 92.6|95.0 |87.7|91.4|96.4|
|[GRM-Llama3-8B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-Llama3-8B-rewardmodel-ft)**(8B)**|91.5|95.5|86.2|90.8|93.6|
|[GRM-Llama3.2-3B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-Llama3.2-3B-rewardmodel-ft)**(3B)**|90.9|91.6|84.9|92.7|94.6|
| [GRM-gemma2-2B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-gemma2-2B-rewardmodel-ft) **(2B)**| 88.4 | 93.0 | 77.2 | 92.2 | 91.2 |
@@ -22,8 +22,8 @@ Check out our GRM series below, which are evaluated on [reward-bench](https://hu
|[GRM-Gemma-2B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-Gemma-2B-rewardmodel-ft) **(2B)**| 84.7 | 89.4 | 75.2 | 85.5 | 88.8 |
| [GRM-Gemma2-2B-sftreg](https://huggingface.co/Ray2333/GRM-Gemma2-2B-sftreg)**(2B)** | 81.0 | 97.2 | 59.6 | 86.9 | 80.3 |
| openai/gpt-4o-2024-05-13 | 84.6| 96.6 | 70.4 | 86.5 | 84.9 |
| [GRM-Gemma-2B-sftreg](https://huggingface.co/Ray2333/GRM-Gemma-2B-sftreg)**(2B)** | 75.3 | 95.5 | 48.7 | 80.0 | 76.8 |
| [Gemma-2B-rewardmodel-baseline](https://huggingface.co/Ray2333/Gemma-2B-rewardmodel-baseline)**(2B)** | 73.7 | 94.1 | 46.1 | 79.6 | 75.0 |
| [GRM-Gemma-2B-sftreg](https://huggingface.co/Ray2333/GRM-Gemma-2B-sftreg)**(2B)** | 75.3 | 95.5 | 48.7 | 80.0 | 76.8 |
| [Gemma-2B-rewardmodel-baseline](https://huggingface.co/Ray2333/Gemma-2B-rewardmodel-baseline)**(2B)** | 73.7 | 94.1 | 46.1 | 79.6 | 75.0 |



@@ -35,7 +35,7 @@ We also evaluated the GRM series using [PPE](https://github.com/lmarena/PPE/tree
|[GRM-llama3-8B-sftreg](https://huggingface.co/Ray2333/GRM-llama3-8B-sftreg)**(8B)**| 62.7 | 66.6 | 60.4| 55.6| 70.9| 59.5 | 63.4|
|[GRM-Llama3-8B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-Llama3-8B-rewardmodel-ft)**(8B)**| 61.4 | 64.2 | 59.6 | 56.2 | 72.3 | 53.3 | 62.5 |
|[GRM-llama3.2-3B-sftreg](https://huggingface.co/Ray2333/GRM-llama3.2-3B-sftreg)**(3B)**| 61.3 |63.9 |58.7 | 55.6| 74.7| 53.1 | 62.0 |
| ArmoRM-Llama3-8B-v0.1 | 61.2 | 66.5 | 58.4 | 57.0 | 70.7 | 54.2 | 60.6|
| ArmoRM-Llama3-8B-v0.1 | 61.2 | 66.5 | 58.4 | 57.0 | 70.7 | 54.2 | 60.6|
|Skywork-Reward-Llama-3.1-8B | 61.0 | 64.3 | 61.5 | 56.5 | 69.7 | 51.6 | 62.4|
|Nemotron-4-340B-Reward | 60.4| 69.7 | 62.7 | 56.6 | 65.1 | 49.2 | 59.3 |
|[GRM-Llama3.2-3B-rewardmodel-ft](https://huggingface.co/Ray2333/GRM-Llama3.2-3B-rewardmodel-ft)**(3B)**| 59.2 | 62.2 | 57.4 | 56.1 | 72.4 | 46.2 | 60.8 |
@@ -44,7 +44,7 @@ We also evaluated the GRM series using [PPE](https://github.com/lmarena/PPE/tree



## Usage
## Usage
First, set your Hugging Face access token as an environment variable.
```
export HF_TOKEN='your HF access token'
@@ -118,6 +118,6 @@ sh train_ppo_ensemble.sh
```
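Beyond the training scripts above, a minimal scoring sketch for the `rewardmodel-ft` checkpoints listed in the README table: this is not part of the diff, and it assumes the checkpoints expose a standard single-logit `AutoModelForSequenceClassification` head plus a chat template (the checkpoint id comes from the table; the example messages are illustrative only).

```python
# Hedged sketch: score one prompt/response pair with a GRM reward model.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "Ray2333/GRM-Llama3.2-3B-rewardmodel-ft"  # one of the checkpoints listed above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, num_labels=1, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [
    {"role": "user", "content": "Explain KL regularization in RLHF in one sentence."},
    {"role": "assistant", "content": "It penalizes the policy for drifting too far from the reference model."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    reward = model(**inputs).logits[0, 0].item()  # higher means more preferred
print(reward)
```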

## Acknowledgment
This repo is built upon [transformers](https://github.com/huggingface/transformers) and [trl](https://github.com/huggingface/trl), and also draws inspiration from [RLHFlow](https://github.com/RLHFlow/RLHF-Reward-Modeling).
This repo is built upon [transformers](https://github.com/huggingface/transformers) and [trl](https://github.com/huggingface/trl), and also draws inspiration from [RLHFlow](https://github.com/RLHFlow/RLHF-Reward-Modeling).


14 changes: 7 additions & 7 deletions reward_models/grm_reward_trainer.py
@@ -51,18 +51,18 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
label_rejected_paded = torch.tensor(feature["label_rejected"].tolist() + [self.label_pad_token_id] * (paded_length - len(feature["label_rejected"])) , dtype=torch.int64)
label_paded.extend([label_chosen_paded.view(1, -1), label_rejected_paded.view(1, -1)])
label_paded = torch.concatenate(label_paded, dim=0)

batch = {
"input_ids": batch["input_ids"],
"attention_mask": batch["attention_mask"],
"return_loss": True,
"label": label_paded,
"label": label_paded,
}
return batch



class GRMRewardTrainer(RewardTrainer):
class GRMRewardTrainer(RewardTrainer):
def __init__(self, **kwargs):
self.reference_free = kwargs.pop('reference_free', True)
self.reference_model = kwargs.pop('reference_model', None)
@@ -77,7 +77,7 @@ def __init__(self, **kwargs):


def get_batch_logps(
self,
self,
logits: torch.FloatTensor,
labels: torch.LongTensor,
average_log_prob: bool = False,
@@ -104,10 +104,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
if not self.reference_free:
with torch.no_grad():
ref_logits = self.reference_model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])[0]

bsz = rewards.size(0)
jidx = torch.arange(0, bsz, 2) # chosen_ids
kidx = jidx + 1 # rejected_ids
kidx = jidx + 1 # rejected_ids
reward_loss = -nn.functional.logsigmoid(rewards[jidx] - rewards[kidx]).mean()

## text-generation regularization
@@ -121,7 +121,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
else:
dpo_loss = -F.logsigmoid(self.beta * (pi_logratios)).mean()
else:
pi_logratios = logps[jidx] - logps[kidx]
pi_logratios = logps[jidx] - logps[kidx]
if self.reference_free or self.sft_only:
ref_logratios = torch.tensor(0.0)
else:
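For reference, the chosen/rejected interleaving and pairwise reward loss touched in this hunk can be reproduced in isolation; this is a toy sketch, not the trainer itself, and the reward values are made up for illustration.

```python
# Toy sketch of the jidx/kidx indexing and Bradley-Terry loss from compute_loss above.
import torch
import torch.nn.functional as F

rewards = torch.tensor([1.2, 0.3, -0.1, 0.4])  # [chosen_0, rejected_0, chosen_1, rejected_1]
bsz = rewards.size(0)
jidx = torch.arange(0, bsz, 2)   # chosen indices: 0, 2
kidx = jidx + 1                  # rejected indices: 1, 3
reward_loss = -F.logsigmoid(rewards[jidx] - rewards[kidx]).mean()
print(reward_loss)               # pairwise preference loss over the two pairs
```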
14 changes: 7 additions & 7 deletions reward_models/grm_utils.py
@@ -57,7 +57,7 @@ def __init__(self, config, **kwargs):
for i in range(num_layers):
module_lis.extend([nn.Linear(input_neurons, num_neurons), nn.ReLU()])
input_neurons = num_neurons

module_lis.append(nn.Linear(num_neurons, num_output))
self.summary = nn.Sequential(*module_lis)
self.flatten = nn.Flatten()
@@ -137,7 +137,7 @@ def forward(
last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
elif not hasattr(self.v_head.summary, 'weight') and (last_hidden_state.device != self.v_head.summary[0].weight.device):
last_hidden_state = last_hidden_state.to(self.v_head.summary[0].weight.device)

# use the last token value as reward
last_index = attention_mask.sum(dim=-1) - 1
value = self.v_head(last_hidden_state).squeeze(-1)[torch.arange(len(last_hidden_state)), last_index]
@@ -164,7 +164,7 @@ def push_to_hub(self, *args, **kwargs):
setattr(self.pretrained_model, "v_head", self.v_head)
return self.pretrained_model.push_to_hub(*args, **kwargs)



def post_init(self, state_dict):
r"""
@@ -203,7 +203,7 @@ def set_device_hook(module, input, outputs):
self.register_forward_hook(set_device_hook)

self.is_sequential_parallel = True

@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
if not isinstance(auto_class, str):
@@ -234,7 +234,7 @@ def load_model_withhead(model_name, peft_name, tokenizer, device, \

if 'Mistral' not in model_name:
model_config['attn_implementation'] = "flash_attention_2"

model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name, **model_config)
model.pretrained_model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
@@ -255,7 +255,7 @@ def load_model_withhead(model_name, peft_name, tokenizer, device, \
loaded_state_dict = torch.load(os.path.join(peft_name, "pytorch_model.bin"))
missing, unexpected = model.base_model.model.pretrained_model.load_state_dict(loaded_state_dict, strict=False)
missing, unexpected = model.base_model.model.load_state_dict(loaded_state_dict, strict=False)

if hasattr(model, 'merge_and_unload'):
model = model.merge_and_unload()
return model
@@ -266,7 +266,7 @@ def model_withhead_forward(model, input_ids, attention_mask, device, forward_typ
elif forward_type == 'dpo':
res = model(input_ids.to(device), attention_mask=attention_mask.to(device))
if len(res) == 3:
logits, _, _ = res
logits, _, _ = res
else:
logits = res.logits
if logits.shape[:-1] != labels.shape:
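The reward in `forward` above is read off at the last non-padding position of each sequence; a standalone sketch of that gather follows (shapes and values are illustrative assumptions).

```python
# Standalone illustration of the last-non-pad-token gather used in forward().
import torch

values = torch.randn(2, 5)                    # v_head output after squeeze(-1): (batch, seq_len)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
last_index = attention_mask.sum(dim=-1) - 1   # index of the last real token per sequence
reward = values[torch.arange(values.size(0)), last_index]
print(reward.shape)                           # torch.Size([2]): one scalar reward per sequence
```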
24 changes: 12 additions & 12 deletions reward_models/load_datasets.py
@@ -7,7 +7,7 @@
# for vanilla chosen/rejected style datasets, such as hendrydong/preference_700K
def build_dataset(data_path, tokenizer, split='train', size=None, model_name=''):
ds = load_dataset(data_path, split=split)

if size is not None:
ds = ds.select(range(0, size))

@@ -40,7 +40,7 @@ def formatting_func(example):
"input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0],
}

ds = ds.map(formatting_func, batched=False, num_proc=10)
ds = ds.map(formatting_func, batched=False, num_proc=10)
remove_columns = []
for col in ds.column_names:
if 'input' not in col and 'attention' not in col and 'label' not in col:
@@ -57,15 +57,15 @@ def build_dataset_UF(data_path, tokenizer, split='train', size=None, mode='', mo
ds = load_dataset(data_path, 'all', split=split)
except:
ds = load_dataset(data_path, split=split)

# drop pairs where both responses received the same rating
ds = ds.filter(lambda example: example['conv_A_rating'] != example['conv_B_rating'], num_proc=30)

if len(mode):
if mode == '40k' or mode == '40K':
ds = ds.select(range(0, len(ds), 20))
ds = ds.select(range(0, len(ds), 20))
elif mode == '400k' or mode == '400K':
ds = ds.select(range(0, len(ds), 2))
ds = ds.select(range(0, len(ds), 2))

if size is not None:
ds = ds.select(range(0, size))
@@ -80,11 +80,11 @@ def formatting_func(example):
chosen_messages = example['conv_B']
rejected_messages = example['conv_A']
margin = example['conv_B_rating'] - example['conv_A_rating']

if 'summarize' in example['source']:
chosen_messages[0]['content'] = 'Generate one-sentence summary for the following post: ' + chosen_messages[0]['content'].strip()
rejected_messages[0]['content'] = 'Generate one-sentence summary for the following post: ' + rejected_messages[0]['content'].strip()

prompt_plus_chosen_response = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
prompt_plus_rejected_response = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
@@ -108,9 +108,9 @@ def formatting_func(example):
return {
"input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0],
"input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0],
"margin": margin,
"margin": margin,
}


ds = ds.map(formatting_func, batched=False, num_proc=10)
# ds = ds.filter(lambda x: len(x["input_ids_chosen"]) <= script_args.max_length and len(x["input_ids_rejected"]) <= script_args.max_length, num_proc=30)
@@ -161,22 +161,22 @@ def formatting_func(example):
"input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0],
}

ds = ds.map(formatting_func, batched=False, num_proc=10)
ds = ds.map(formatting_func, batched=False, num_proc=10)
ds.set_format(type="torch")
return ds


def load_train_eval_dataset(data_path, tokenizer, size=None, mode='', model_name=''):
if 'Unified' in data_path:
# mode is only used for loading training data
train_dataset = build_dataset_UF(data_path, tokenizer, split='train', size=size, mode=mode, model_name=model_name)
train_dataset = build_dataset_UF(data_path, tokenizer, split='train', size=size, mode=mode, model_name=model_name)
eval_dataset = build_dataset_UF(data_path, tokenizer, split='val', model_name=model_name)
elif 'Skywork' in data_path:
dataset = build_dataset_SK(data_path, tokenizer, split='train', size=size, model_name=model_name)
dataset_split = dataset.train_test_split(test_size=0.005)
train_dataset, eval_dataset = dataset_split['train'], dataset_split['test']
else:
dataset = build_dataset(data_path, tokenizer, split='train', size=size, model_name=model_name)
dataset = build_dataset(data_path, tokenizer, split='train', size=size, model_name=model_name)
dataset_split = dataset.train_test_split(test_size=0.01)
train_dataset, eval_dataset = dataset_split['train'], dataset_split['test']
return train_dataset, eval_dataset
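The dataset builders above all follow the same pattern: render the chosen and rejected conversations with the tokenizer's chat template, tokenize both, and keep paired `*_chosen` / `*_rejected` tensors. A trimmed, self-contained sketch (the tokenizer id comes from the README table; the tokenization kwargs and the `chosen` / `rejected` column names are assumptions for illustration):

```python
# Hedged sketch of the chosen/rejected formatting pattern in build_dataset_*.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Ray2333/GRM-Llama3.2-3B-rewardmodel-ft")
kwargs = {"padding": False, "truncation": True, "max_length": 1024, "return_tensors": "pt"}

def formatting_func(example):
    chosen_text = tokenizer.apply_chat_template(example["chosen"], tokenize=False)
    rejected_text = tokenizer.apply_chat_template(example["rejected"], tokenize=False)
    tokens_chosen = tokenizer.encode_plus(chosen_text, **kwargs)
    tokens_rejected = tokenizer.encode_plus(rejected_text, **kwargs)
    return {
        "input_ids_chosen": tokens_chosen["input_ids"][0],
        "attention_mask_chosen": tokens_chosen["attention_mask"][0],
        "input_ids_rejected": tokens_rejected["input_ids"][0],
        "attention_mask_rejected": tokens_rejected["attention_mask"][0],
    }
```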
4 changes: 2 additions & 2 deletions reward_models/reward_trainer.py
@@ -68,13 +68,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
rewards_k = rewards[kidx]

if self.loss_type == 'bt':
loss = - nn.functional.logsigmoid(rewards_j - rewards_k).mean()
loss = - nn.functional.logsigmoid(rewards_j - rewards_k).mean()
elif self.loss_type == 'pos_reg':
loss = - nn.functional.logsigmoid(rewards_j - rewards_k).mean() - self.weight_ratio * nn.functional.logsigmoid(rewards_j.mean())
elif self.loss_type == 'margin':
loss = -nn.functional.logsigmoid(rewards_j - rewards_k - torch.tensor(inputs["margin"], device=inputs["margin"][0].device).view(-1,1)).mean()
elif self.loss_type == 'labelsmooth':
loss = - (1-self.weight_ratio) * nn.functional.logsigmoid(rewards_j - rewards_k).mean() - self.weight_ratio * nn.functional.logsigmoid(rewards_k - rewards_j).mean()
loss = - (1-self.weight_ratio) * nn.functional.logsigmoid(rewards_j - rewards_k).mean() - self.weight_ratio * nn.functional.logsigmoid(rewards_k - rewards_j).mean()
else:
raise NotImplementedError

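The `labelsmooth` branch above is a label-smoothed variant of the Bradley-Terry objective; a self-contained numeric sketch (the `weight_ratio` value and rewards are illustrative):

```python
# Numeric sketch of the label-smoothed pairwise loss from compute_loss above.
import torch
import torch.nn.functional as F

weight_ratio = 0.1                      # smoothing weight; set via script arguments in the repo
rewards_j = torch.tensor([1.5, 0.2])    # chosen rewards
rewards_k = torch.tensor([0.7, 0.9])    # rejected rewards

loss = (-(1 - weight_ratio) * F.logsigmoid(rewards_j - rewards_k).mean()
        - weight_ratio * F.logsigmoid(rewards_k - rewards_j).mean())
print(loss)
```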
14 changes: 7 additions & 7 deletions reward_models/run_grm_reward_train.py
@@ -23,13 +23,13 @@
@dataclass
class ScriptArguments:
# training args
per_device_train_batch_size: Optional[int] = field(default=1)
per_device_train_batch_size: Optional[int] = field(default=1)
gradient_accumulation_steps: Optional[int] = field(default=16)
learning_rate: Optional[float] = field(default=1e-5)
num_train_epochs: Optional[int] = field(default=2, metadata={"help": "The number of training epochs for the reward model."})
optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer to use."})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "The lr scheduler"},)
max_length: Optional[int] = field(default=1024)
max_length: Optional[int] = field(default=1024)
gradient_checkpointing: Optional[bool] = field(default=True)
bf16: Optional[bool] = field(default=True)
attn_implementation: Optional[str] = field(default="flash_attention_2")
@@ -64,9 +64,9 @@ class ScriptArguments:
reference_free: Optional[bool] = field(default=True)
sft_only: Optional[bool] = field(default=True)
no_logsigmoid_sft: Optional[bool] = field(default=False)







parser = HfArgumentParser(ScriptArguments)
@@ -77,7 +77,7 @@ class ScriptArguments:
else:
output_name = f"{script_args.log_dir}/{model_name_split}_{script_args.wandb_name}_len{script_args.max_length}_fulltrain_{script_args.learning_rate}_data{script_args.dataset.split('/')[-1]}"

device = Accelerator().local_process_index
device = Accelerator().local_process_index

training_args = TrainingArguments(
output_dir=os.path.join(output_name, 'logs'),
@@ -90,7 +90,7 @@ class ScriptArguments:
save_strategy=script_args.save_strategy,
save_steps=script_args.save_steps,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
gradient_checkpointing=script_args.gradient_checkpointing,
bf16=script_args.bf16,
logging_strategy="steps",
logging_steps=10,
@@ -136,7 +136,7 @@ class ScriptArguments:


model = AutoModelForCausalLMWithValueHead.from_pretrained(
script_args.base_model, device_map=device,
script_args.base_model, device_map=device,
torch_dtype=torch.bfloat16,
**model_params,
)
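The training entry points configure everything through `HfArgumentParser` over the `ScriptArguments` dataclass shown above, so every field becomes a `--flag`; a trimmed sketch with a subset of the fields (the flag values are examples, not recommended settings):

```python
# Sketch of how ScriptArguments fields become CLI flags via HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class ScriptArguments:
    per_device_train_batch_size: Optional[int] = field(default=1)
    learning_rate: Optional[float] = field(default=1e-5)
    max_length: Optional[int] = field(default=1024)

parser = HfArgumentParser(ScriptArguments)
(script_args,) = parser.parse_args_into_dataclasses(
    args=["--learning_rate", "2e-6", "--max_length", "3000"]
)
print(script_args.learning_rate, script_args.max_length)  # 2e-06 3000
```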
16 changes: 8 additions & 8 deletions reward_models/run_reward_models_train.py
@@ -21,13 +21,13 @@
@dataclass
class ScriptArguments:
# training args
per_device_train_batch_size: Optional[int] = field(default=1)
per_device_train_batch_size: Optional[int] = field(default=1)
gradient_accumulation_steps: Optional[int] = field(default=16)
learning_rate: Optional[float] = field(default=1e-5)
num_train_epochs: Optional[int] = field(default=2, metadata={"help": "The number of training epochs for the reward model."})
optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer to use."})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "The lr scheduler"},)
max_length: Optional[int] = field(default=1024)
max_length: Optional[int] = field(default=1024)
gradient_checkpointing: Optional[bool] = field(default=True)
bf16: Optional[bool] = field(default=True)
attn_implementation: Optional[str] = field(default="flash_attention_2")
@@ -56,7 +56,7 @@ class ScriptArguments:
save_strategy: Optional[str] = field(default="epoch")
save_steps: Optional[int] = field(default=1000)
debug: Optional[bool] = field(default=False, metadata={'help': 'if debug=True, only train with 100 samples'})



parser = HfArgumentParser(ScriptArguments)
@@ -67,7 +67,7 @@ class ScriptArguments:
else:
output_name = f"{script_args.log_dir}/{model_name_split}_{script_args.wandb_name}_len{script_args.max_length}_fulltrain_{script_args.learning_rate}_data{script_args.dataset.split('/')[-1]}"

device = Accelerator().local_process_index
device = Accelerator().local_process_index

training_args = TrainingArguments(
output_dir=os.path.join(output_name, 'logs'),
@@ -80,7 +80,7 @@ class ScriptArguments:
save_strategy=script_args.save_strategy,
save_steps=script_args.save_steps,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
gradient_checkpointing=script_args.gradient_checkpointing,
bf16=script_args.bf16,
logging_strategy="steps",
logging_steps=10,
@@ -117,17 +117,17 @@ class ScriptArguments:
model_params = {}

model = AutoModelForSequenceClassification.from_pretrained(
script_args.base_model, num_labels=1, device_map=device,
script_args.base_model, num_labels=1, device_map=device,
torch_dtype=torch.bfloat16,
**model_params
)

if script_args.freeze_pretrained:
# for frozen baseline
mlp_layer = nn.Sequential(
nn.Linear(model.config.hidden_size, 1024, dtype=torch.bfloat16),
nn.Linear(model.config.hidden_size, 1024, dtype=torch.bfloat16),
nn.ReLU(),
nn.Linear(1024, 1, dtype=torch.bfloat16)
nn.Linear(1024, 1, dtype=torch.bfloat16)
)
mlp_layer.to(device)
# Replace the classifier with the MLP
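For the `freeze_pretrained` baseline above, the backbone stays frozen and only the new MLP head is trained; a hedged sketch of the surrounding pattern (the checkpoint id stands in for `script_args.base_model`, and the `score` attribute name for the scalar head is an assumption that holds for Llama-style sequence-classification models):

```python
# Hedged sketch of the frozen-baseline head replacement shown above.
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "Ray2333/GRM-Llama3.2-3B-rewardmodel-ft",  # stands in for script_args.base_model
    num_labels=1, torch_dtype=torch.bfloat16,
)
for p in model.parameters():            # freeze the pretrained backbone
    p.requires_grad = False

mlp_head = nn.Sequential(
    nn.Linear(model.config.hidden_size, 1024, dtype=torch.bfloat16),
    nn.ReLU(),
    nn.Linear(1024, 1, dtype=torch.bfloat16),
)
model.score = mlp_head                  # assumed head attribute; differs per architecture
for p in model.score.parameters():      # only the new head receives gradients
    p.requires_grad = True
```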
2 changes: 1 addition & 1 deletion reward_models/utils.py
@@ -34,7 +34,7 @@ def compute_metrics(eval_pred):
def grm_compute_metrics(eval_pred):
rewards = eval_pred.label_ids
reward_accuracy = (rewards[:, 0] > rewards[:, 1]).mean()

predictions = eval_pred.predictions
accuracy = (predictions[:, 0] > predictions[:, 1]).mean()
return {