
Request for details and assistance on PPO Experiments with SFT+PPO training #16

@roshansridhar

Hello Developers,

Firstly, I would like to thank you for the excellent work on this repository and for sharing the plots in other issues. I'm currently using your library to train a model with SFT+PPO, and I've successfully replicated the SFT experiment per the results shared in ContextualAI/HALOs/issues/13.

However, I'm seeing negligible improvement from the PPO stage of training. Could you share the details of your PPO experiments? I noted in a previous comment that the preference-tuning runs were significantly longer, so I set n_epochs to 3 for PPO in my experiments. Is this adjustment in line with what was done in your experiments? Additionally, could you elaborate on how and when you decide which checkpoint to use for downstream tasks, especially in the PPO, DPO, and KTO settings?

Link to my plots: l7b_ppo_0419 (screenshot of the training curves from 2024-04-22 attached)

Here are the commands I used for my experiments:

  • SFT Training Command
python train.py loss=sft model=llama7b datasets=[shp,hh,oasst] exp_name=l7b_sft_0416 mode=train ++cache_dir=/data/models wandb.project=l7b_sft_0416
  • PPO Training (3 Epochs)
# Updated n_epochs in config.yaml to 3
python train.py loss=ppo model=llama7b datasets=[shp,hh,oasst] exp_name=l7b_ppo_0419 mode=train ++cache_dir=./data/models ++model.load_from=l7b_sft_0416/LATEST/policy.pt wandb.project=l7b_ppo_0419
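
For reference, I believe the same change could be passed as a Hydra override rather than editing the file (this assumes n_epochs is a top-level key in config.yaml, which is where I changed it):

# Hypothetical alternative to editing config.yaml: override n_epochs on the command line
python train.py loss=ppo model=llama7b datasets=[shp,hh,oasst] exp_name=l7b_ppo_0419 mode=train ++cache_dir=./data/models ++model.load_from=l7b_sft_0416/LATEST/policy.pt ++n_epochs=3 wandb.project=l7b_ppo_0419

Keeping the override on the command line would make each run reproducible without tracking config edits.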

Additional Queries:

  • When running SFT training, the train and validation losses are computed from splits of the training data. If I use the same datasets for PPO, how can I ensure that I am not inadvertently retraining on the SFT train split? Furthermore, when using stratified datasets for both SFT and preference tuning, do you recommend holding out different data points for each stage, and is this considered best practice? (A rough sketch of what I mean follows this list.)
  • Based on your plots and the results shared, am I correct in understanding that you used a batch size of 32 and ran about 200k steps of SFT training, which equates to training on approximately 6.4 million datapoints (200k * 32), DPO on roughly 9.6 million datapoints (300k * 32), and KTO on about 17.6 million datapoints?
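
To make the first question above concrete, here is the kind of disjoint holdout I have in mind. This is only a minimal sketch in plain Python, not HALOs code; I'm assuming each example exposes its prompt as a string.

# Minimal sketch (not HALOs' data loading): route prompts into disjoint SFT and PPO pools
# via a stable hash of the prompt text, so neither stage trains on the other's examples.
import hashlib

def assign_stage(prompt: str, ppo_fraction: float = 0.5) -> str:
    """Deterministically map a prompt to 'sft' or 'ppo' based on its hash."""
    digest = int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16)
    return "ppo" if (digest % 1000) / 1000.0 < ppo_fraction else "sft"

# Toy stand-ins for entries from shp/hh/oasst; each is assumed to carry a prompt string.
examples = [{"prompt": "How do I sort a list in Python?"},
            {"prompt": "Explain the difference between DPO and PPO."}]

sft_pool = [ex for ex in examples if assign_stage(ex["prompt"]) == "sft"]
ppo_pool = [ex for ex in examples if assign_stage(ex["prompt"]) == "ppo"]

Hashing on the prompt rather than shuffling keeps the assignment stable across runs and across the SFT and PPO stages.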

Thank you for any guidance or insights you can provide.
