This repository contains code for training reinforcement learning agents that perform agentic thinking-code chains. The system supports multimodal inputs (text and images) and the Qwen-VL-2.5 model family. Key features:
- Multi-turn agent interactions with code execution
- Parallel batched code execution for improved performance
- Support for image generation and processing in code cells
- Agentic PPO and GRPO with info masking for policy optimization
- Distributed training with model parallelism
- Multi-modal Context Parallelism via Ulysses
- SGLang and VLLM backends for parallel rollout
- Dynamic Difficulty Sampler for Curriculum Learning
- Multi-node training with Infiniband
- Turn-aware credit assignment (a simple masked-mean loss modification that assigns turn-level rewards; see the sketch after this list)
- Async RL (see https://arxiv.org/abs/2503.18929)
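For the turn-aware credit assignment item above, here is a minimal sketch of the masked-mean modification. The tensor names (`turn_rewards`, `turn_ids`, `response_mask`) are illustrative placeholders, not the repository's actual variable names:

```python
import torch

def turn_level_token_rewards(turn_rewards, turn_ids, response_mask):
    """Broadcast each turn's scalar reward to every response token of that turn.

    turn_rewards:  (B, num_turns) reward assigned to each turn of each trajectory
    turn_ids:      (B, T) integer turn index of every token in the rolled-out sequence
    response_mask: (B, T) 1 for model-generated tokens, 0 for prompt/observation tokens
    """
    per_token = torch.gather(turn_rewards, dim=1, index=turn_ids)  # (B, T)
    return per_token * response_mask

def masked_mean(values, mask):
    """The usual masked mean: average only over masked-in (response) tokens."""
    return (values * mask).sum() / mask.sum().clamp(min=1)
```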
```bash
# Create and activate conda environment
conda create -n vrl python=3.9
conda activate vrl

# Install package in development mode
pip install -e .

# Install dependencies
pip install -r requirements.txt
```
Run RL training with Qwen-VL-2.5-7B:
```bash
conda activate vrl
bash train_ppo.sh
```
Note: you should install vllm>=0.7.3 and the latest transformers.
Data preparation for ARC (uses 20k samples and the public ARC eval set):
```bash
git clone https://github.com/huyphan168/ARC-Fork
cd ARC-Fork && python dataset_converter.py --output_dir "$YOUR_VRL_PATH/data/arc"
```
The design of Verl depends on two things:

- Worker group processes handled by Ray; each group can be an actor (sampling), critic, reward model, or reference policy. Each worker group can be offloaded to CPU or interleaved with the others at each step of the PPO loop to avoid OOM.
- The data flow between these processes is defined by the `DataProto` protocol and handled by NCCL. Basically, `DataProto` has two parts: `non_tensor` (like image objects) and `tensor` (`input_ids`, `attention_mask`, etc.). The `non_tensor` part is a numpy array of objects whose length equals the batch size.
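For illustration, a batch might be packed into a `DataProto` roughly like this. This sketch assumes verl's `DataProto.from_dict` helper and the `from verl import DataProto` import; the exact import path and signature can differ between verl versions:

```python
import numpy as np
import torch
from verl import DataProto

batch_size, seq_len = 4, 128

# tensor part: regular stacked tensors, one row per sample
tensors = {
    "input_ids": torch.zeros(batch_size, seq_len, dtype=torch.long),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
}

# non_tensor part: a numpy object array of length batch_size,
# e.g. one dict of image objects / processor inputs per sample
non_tensors = {
    "multi_modal_data": np.array([{"image": []} for _ in range(batch_size)], dtype=object),
}

batch = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)
```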
The dataset is prepared in the `__getitem__` method in `verl/utils/dataset/rl_dataset.py`.

The main loop of PPO or GRPO training is in `verl/trainer/ppo/ray_trainer.py`, in the `fit` method of the `RayPPOTrainer` class.
- Dataset preparation: `verl/utils/dataset/rl_dataset.py`
- PPO implementation: `verl/trainer/ppo/ray_trainer.py`
- Reward calculation: `verl/utils/reward_score/geo3k.py`
- Actor training: `verl/workers/fsdp_workers.py`
IMPORTANT NOTE:
Each of these worker processes is just a wrapper around known components. For example, `verl/workers/actor` is a wrapper around FSDP and a basic PyTorch training loop for an LLM, and `verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py` is a wrapper around the VLLM engine (version >= 0.7.3).

VLLM model weight synchronization is the step where VLLM syncs the model weights after each round of actor training.
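Conceptually, that sync step gathers the freshly updated actor weights and hands them to the rollout engine before the next sampling round. The sketch below only illustrates the idea; `load_weights` here is a hypothetical stand-in, not verl's or VLLM's actual API:

```python
def sync_actor_weights_to_rollout(actor_module, rollout_engine):
    """Illustrative only: after the actor update, push the new policy weights into
    the inference engine so the next rollout samples from the updated policy.
    (With FSDP the state dict would have to be gathered under a full-state-dict
    context first; that detail is omitted here.)
    """
    state_dict = actor_module.state_dict()
    # `load_weights` is a hypothetical method on the rollout side, not a specific VLLM API.
    rollout_engine.load_weights(state_dict.items())
```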
```python
def fit(self):
    """
    The training loop of PPO.
    The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
    The light-weight advantage computation is done on the driver process.
    """
```
Generally a PPO loop contains the following steps:

- Sample trajectories via `generation_manager.run_llm_loop`.
- Compute the log prob of the trajectory: `self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)`.
- Compute the log prob of the trajectory under the reference policy: `ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)`.
- Compute values if we're using PPO: `values = self.critic_wg.compute_values(batch)`.
- Compute the reward: `reward_tensor = self.rm_wg.compute_rm_score(batch)`; here we use a hardcoded reward like in `verl/utils/reward_score/geo3k.py`.
- Compute the advantage: `batch = compute_advantage(batch, ...)`.
- Train the critic model: `critic_output = self.critic_wg.update_critic(batch)`.
- Train the actor and sync model weights with VLLM: `actor_output = self.actor_rollout_wg.update_actor(batch)`. You can refer to `def update_actor(self, data: DataProto):` in `verl/workers/fsdp_workers.py`. A simplified sketch of the whole loop follows this list.
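Putting those steps together, the driver side of one `RayPPOTrainer.fit` iteration looks roughly like the sketch below. It is simplified: metrics, offloading, KL handling, and the exact `compute_advantage` arguments are omitted, and `DataProto.from_single_dict` / `DataProto.union` are assumed verl helpers:

```python
for batch_dict in self.train_dataloader:
    batch = DataProto.from_single_dict(batch_dict)

    # 1. Multi-turn rollout: think -> emit code -> execute -> observe, repeated.
    final_gen_batch_output = generation_manager.run_llm_loop(batch)
    batch = batch.union(final_gen_batch_output)

    # 2. Log probs under the current policy and the frozen reference policy.
    batch = batch.union(self.actor_rollout_wg.compute_log_prob(batch))
    batch = batch.union(self.ref_policy_wg.compute_ref_log_prob(batch))

    # 3. Values (PPO only), reward, and advantages.
    if self.use_critic:
        batch = batch.union(self.critic_wg.compute_values(batch))
    batch = batch.union(self.rm_wg.compute_rm_score(batch))
    batch = compute_advantage(batch, ...)  # GAE for PPO, group-relative for GRPO

    # 4. Update the critic (PPO only), then the actor; the actor weights are
    #    synced to VLLM right after the actor update.
    if self.use_critic:
        critic_output = self.critic_wg.update_critic(batch)
    actor_output = self.actor_rollout_wg.update_actor(batch)
```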
So the heaviest parts are sampling and actor training. Actor training has two strategies: FSDP or Megatron; FSDP is the first choice. Actor training also uses DeepSpeed Ulysses for context parallelism. I think we don't need to change much here since it chunks on token embeddings rather than on the data side (need to confirm since I haven't used this).
- `self.actor_rollout_wg` is basically a VLLM engine used for sampling and rollout. `LLMGenerationManager` is basically the LLM loop (the agent): it uses `actor_rollout_wg` to do sampling and manages the multiple steps of sampling and then executing code.
- The `_update_rolling_state` method of `LLMGenerationManager` is where we update the current sequence with the response and observation.
- The `batch_execute_code` method of `LLMGenerationManager` executes in batches all the code that the agent produces, supporting image output as well.
- The `non_tensor` part of our domain is: `non_tensors={'multi_modal_data': new_multi_modal_data, 'multi_modal_inputs': new_multi_modal_inputs}`
- Basically, the agent part mostly updates the current `DataProto` with newly added observations and handles multi-modal inputs and images. A toy sketch of this loop follows this list.
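As a toy, single-trajectory picture of that loop (the real code works on batched `DataProto`s inside `LLMGenerationManager`; `generate_fn`, `execute_fn`, the Python code-fence pattern, and the `<observation>` tag are assumptions of this sketch, not the repository's actual format):

```python
import re

CODE_CELL_RE = re.compile(r"```python\n(.*?)```", re.DOTALL)

def agent_loop_sketch(generate_fn, execute_fn, prompt, max_turns=4):
    """Toy single-trajectory version of the think-code-observe loop.

    generate_fn: callable(str) -> str, stands in for the VLLM rollout
    execute_fn:  callable(list[str]) -> str, stands in for batch_execute_code
    """
    context = prompt
    for _ in range(max_turns):
        response = generate_fn(context)
        context += response
        code_cells = CODE_CELL_RE.findall(response)
        if not code_cells:
            break  # the model answered without emitting a code cell
        observation = execute_fn(code_cells)  # may include rendered images in the real system
        context += f"\n<observation>\n{observation}\n</observation>\n"
    return context
```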
The final part is the model architecture, defined in `verl/models/transformers/qwen2_vl.py`; here I also extend Qwen-VL to act as a token classifier for the critic model.
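To make "token classifier for the critic model" concrete: the idea is a scalar value head on top of the backbone's last hidden states, giving one value estimate per token. The sketch below is a generic illustration with a placeholder backbone, not the actual code in `qwen2_vl.py`:

```python
import torch
import torch.nn as nn

class TokenValueCritic(nn.Module):
    """Wrap a (multimodal) causal LM backbone with a per-token scalar value head."""

    def __init__(self, backbone, hidden_size):
        super().__init__()
        self.backbone = backbone              # e.g. a VLM returning hidden states
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask=None, **mm_inputs):
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **mm_inputs,                      # pixel_values, image grid info, etc.
        )
        hidden = out.hidden_states[-1]        # (B, T, H) last-layer hidden states
        values = self.value_head(hidden).squeeze(-1)  # (B, T) one value per token
        return values
```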