FSoft-AI4Code/Visual-Reasonser

RL Infra for Agentic Training

This repository contains code for training reinforcement-learning agents that perform agentic thinking-code chains. The system supports multimodal inputs (text and images) and Qwen-VL-2.5.

Features

  • Multi-turn agent interactions with code execution
  • Parallel batched code execution for improved performance
  • Support for image generation and processing in code cells
  • Agentic PPO, GRPO with info masking for policy optimization
  • Distributed training with model parallelism
  • Multi-modal Context Parallelism via Ulysses
  • SGLang and VLLM backends for parallel rollout
  • Dynamic Difficulty Sampler for Curriculum Learning
  • Multi-node training with Infiniband
  • Turn-aware credit assignment (a simple masked-mean loss modification that assigns turn-level rewards; see the sketch after this list)
  • Async RL (see https://arxiv.org/abs/2503.18929)
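As a rough illustration of the turn-aware credit assignment above, the sketch below broadcasts per-turn rewards to the tokens of each turn and takes a masked mean over response tokens only. It is a simplified REINFORCE-style term, not the exact clipped PPO objective used in the code, and the tensor names (turn_ids, response_mask) are illustrative.

import torch

def broadcast_turn_rewards(turn_rewards, turn_ids):
    """Expand per-turn rewards (batch, n_turns) to per-token rewards (batch, seq_len).

    turn_ids is a long tensor mapping each token to its turn index, with -1 for
    prompt/observation tokens that should receive no reward.
    """
    token_rewards = torch.zeros_like(turn_ids, dtype=turn_rewards.dtype)
    valid = turn_ids >= 0
    token_rewards[valid] = turn_rewards.gather(1, turn_ids.clamp(min=0))[valid]
    return token_rewards

def turn_aware_masked_mean_loss(log_probs, advantages, response_mask):
    """Masked mean over response tokens; advantages already carry turn-level credit."""
    pg_loss = -log_probs * advantages            # per-token policy-gradient term
    masked = pg_loss * response_mask             # ignore prompt / observation tokens
    return masked.sum() / response_mask.sum().clamp(min=1)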

Installation

Environment

# Create and activate conda environment
conda create -n vrl python=3.9
conda activate vrl

# Install package in development mode
pip install -e .

# Install dependencies
pip install -r requirements.txt

Quick start

Run RL training with Qwen-VL-2.5-7B:

conda activate vrl
bash train_ppo.sh

Note: install vllm>=0.7.3 and the latest transformers.

Data preparation: ARC (uses 20k samples and the public ARC eval set)

git clone https://github.com/huyphan168/ARC-Fork
cd ARC-Fork && python dataset_converter.py --output_dir '$YOUR_VRL_PATH/data/arc'

Instruction

The design of Verl depends on two things:

  1. Worker group processes managed by Ray (actor/rollout for sampling, critic, reward model, or reference policy). Each worker group can be offloaded to CPU or interleaved with the others at each step of the PPO loop to avoid OOM.

  2. The data flow between these processes is defined by the DataProto protocol and handled over NCCL. A DataProto has two parts: non_tensor (e.g. image objects) and tensor (input_ids, attention_mask, etc.). The non_tensor part is a numpy array of objects whose length equals the batch size.
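As a minimal sketch of assembling a DataProto (assuming verl exposes DataProto.from_dict as in recent versions; the field contents here are illustrative):

import numpy as np
import torch
from verl import DataProto

# Tensor part: token-level fields, all with a leading batch dimension.
tensors = {
    "input_ids": torch.randint(0, 32000, (4, 512)),
    "attention_mask": torch.ones(4, 512, dtype=torch.long),
}

# Non-tensor part: arbitrary Python objects (e.g. PIL images), stored as a
# numpy object array whose length equals the batch size.
non_tensors = {
    "multi_modal_data": np.array([{"image": []} for _ in range(4)], dtype=object),
}

batch = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)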

Structure

The dataset is prepared in the __getitem__ method of verl/utils/dataset/rl_dataset.py.
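Roughly, __getitem__ builds the chat prompt, runs it through the processor, and returns the tensors plus the raw multi-modal data. The sketch below uses approximate field names as an illustration, not the exact implementation.

def __getitem__(self, idx):
    row = self.dataframe[idx]
    # Render the chat prompt as text; the raw images are kept for the rollout engine.
    prompt = self.tokenizer.apply_chat_template(
        row["prompt"], tokenize=False, add_generation_prompt=True
    )
    model_inputs = self.processor(text=[prompt], images=row["images"], return_tensors="pt")
    return {
        "input_ids": model_inputs["input_ids"][0],
        "attention_mask": model_inputs["attention_mask"][0],
        "multi_modal_data": {"image": row["images"]},   # ends up in the non_tensor part
    }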

The main loop of PPO or GRPO training lives in verl/trainer/ppo/ray_trainer.py, in the fit method of the RayPPOTrainer class.

  • Dataset preparation: verl/utils/dataset/rl_dataset.py
  • PPO implementation: verl/trainer/ppo/ray_trainer.py
  • Reward calculation: verl/utils/reward_score/geo3k.py (rule-based scoring; a sketch follows this list)
  • Actor training: verl/workers/fsdp_workers.py
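As an illustration of the rule-based reward style, here is a hypothetical scorer in the spirit of geo3k.py; the extraction regex and scoring values are assumptions, not the actual implementation.

import re

def compute_score(solution_str: str, ground_truth: str) -> float:
    """Return 1.0 if the final \\boxed{...} answer matches the ground truth, else 0.0."""
    match = re.search(r"\\boxed\{([^}]*)\}", solution_str)
    if match is None:
        return 0.0
    answer = match.group(1).strip()
    return 1.0 if answer == ground_truth.strip() else 0.0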

IMPORTANT NOTE:

Each of these worker processes is just a wrapper around well-known components, for example:

  • verl/workers/actor is a wrapper around FSDP and a basic PyTorch training loop for an LLM
  • verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py is a wrapper around the VLLM engine (version >= 0.7.3)

VLLM model weight synchronization is the step where the VLLM engine reloads the actor's model weights after each actor training round.
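Conceptually, the sync gathers a full (unsharded) actor state dict from FSDP and hands the named tensors to the rollout engine. The sketch below is only a schematic: the real code goes through verl's sharding manager, and the rollout_model handle shown here is hypothetical (vLLM model classes do expose load_weights over (name, tensor) pairs).

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def sync_actor_weights_to_rollout(fsdp_actor, rollout_model):
    """Gather the full actor state dict and push it into the inference engine."""
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_actor, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = fsdp_actor.state_dict()
    # vLLM models accept an iterable of (parameter_name, tensor) pairs.
    rollout_model.load_weights(state_dict.items())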

PPO part

def fit(self):
    """
    The training loop of PPO.
    The driver process only needs to call the compute functions of the worker groups through RPC to construct the PPO dataflow.
    The light-weight advantage computation is done on the driver process.
    """

Generally, a PPO loop contains the following steps:

  1. Sample trajectories via generation_manager.run_llm_loop.
  2. Compute log probs of the trajectory: self.actor_rollout_wg.compute_log_prob(final_gen_batch_output).
  3. Compute log probs under the reference policy: ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch).
  4. Compute values if we're using PPO: values = self.critic_wg.compute_values(batch).
  5. Compute rewards: reward_tensor = self.rm_wg.compute_rm_score(batch); here we use a hard-coded rule-based reward, as in verl/utils/reward_score/geo3k.py.
  6. Compute advantages: batch = compute_advantage(batch, ...).
  7. Train the critic model: critic_output = self.critic_wg.update_critic(batch).
  8. Train the actor and sync model weights with VLLM: actor_output = self.actor_rollout_wg.update_actor(batch). See def update_actor(self, data: DataProto) in verl/workers/fsdp_workers.py.
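Putting the steps together, a condensed and simplified version of the fit loop looks roughly like this. The worker-group calls mirror the list above; KL penalties, response masking, validation, and metric logging are omitted, and names such as self.generation_manager and self.reward_fn follow this repository's conventions rather than upstream verl exactly.

def fit(self):
    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            batch = DataProto.from_single_dict(batch_dict)

            # 1. multi-turn rollout: sampling interleaved with code execution
            gen_output = self.generation_manager.run_llm_loop(batch)
            batch = batch.union(gen_output)

            # 2-3. log probs under the current actor and the reference policy
            batch = batch.union(self.actor_rollout_wg.compute_log_prob(batch))
            batch = batch.union(self.ref_policy_wg.compute_ref_log_prob(batch))

            # 4. values (PPO only; skipped for GRPO)
            if self.use_critic:
                batch = batch.union(self.critic_wg.compute_values(batch))

            # 5. rewards: rule-based score or reward model
            batch.batch["token_level_scores"] = self.reward_fn(batch)

            # 6. advantages, computed cheaply on the driver process
            batch = compute_advantage(batch, adv_estimator=self.config.algorithm.adv_estimator)

            # 7-8. critic update, then actor update + weight sync with the rollout engine
            if self.use_critic:
                critic_output = self.critic_wg.update_critic(batch)
            actor_output = self.actor_rollout_wg.update_actor(batch)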

So the heaviest parts are sampling and actor training. Actor training has two strategies, FSDP or Megatron, with FSDP as the first choice. Actor training also uses DeepSpeed-Ulysses for context parallelism; I think we don't need to change much here, since it chunks the token embeddings rather than the data side (to be confirmed, since I haven't used it).
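For intuition only: Ulysses-style sequence parallelism splits hidden states along the sequence (token) dimension across ranks and uses all-to-all exchanges around attention so that each rank attends over the full sequence for a subset of heads. The toy sketch below shows just the chunking step, with made-up names; the batch dimension is untouched, which is why the data side needs no changes.

import torch

def ulysses_chunk(hidden_states: torch.Tensor, sp_rank: int, sp_size: int) -> torch.Tensor:
    """Slice a (batch, seq_len, hidden) tensor along the sequence dimension for this rank."""
    seq_len = hidden_states.size(1)
    assert seq_len % sp_size == 0, "sequences are padded to a multiple of sp_size"
    chunk = seq_len // sp_size
    return hidden_states[:, sp_rank * chunk:(sp_rank + 1) * chunk, :]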

Agent part

  • self.actor_rollout_wg is basically a VLLM engine that does sampling and rollout. LLMGenerationManager is the LLM loop (the agent) that uses actor_rollout_wg for sampling and manages multi-step sampling followed by code execution.

  • The _update_rolling_state method of LLMGenerationManager is where we update the current sequence with the response and observation.

  • The batch_execute_code method of LLMGenerationManager executes all the code the agent produces in batches, supporting image outputs as well.

  • The non_tensor part in our domain is:

non_tensors={'multi_modal_data': new_multi_modal_data, 'multi_modal_inputs': new_multi_modal_inputs}
  • Basically, the agent part mostly updates the current DataProto with newly added observations and handles multi-modal inputs and images; a schematic of the loop follows.
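The method names in the sketch below are the ones mentioned above; the stopping condition and configuration fields are simplified placeholders, not the actual signatures.

def run_llm_loop(self, gen_batch):
    rollings = gen_batch
    for step in range(self.config.max_turns):
        # sample the next thinking/code segment for every active trajectory
        gen_output = self.actor_rollout_wg.generate_sequences(rollings)

        # run all emitted code cells in parallel; observations may include images
        observations, images = self.batch_execute_code(gen_output)

        # append responses, observations, and any new images to the running context
        rollings = self._update_rolling_state(rollings, gen_output, observations, images)

        if self._all_trajectories_finished(gen_output):
            break
    return rollings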

Arch part

The final part is the model architecture, defined in verl/models/transformers/qwen2_vl.py; here I also extend Qwen-VL into a token-level classifier to serve as the critic model.
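As a hedged sketch of that extension (not the actual class in qwen2_vl.py), the idea is to reuse the VL backbone and replace the LM head with a scalar value head applied at every token position; depending on the transformers version, the vision inputs may need to be routed through Qwen2VLForConditionalGeneration instead of the bare backbone.

import torch.nn as nn
from transformers import Qwen2VLModel

class Qwen2VLTokenCritic(nn.Module):
    """Token-level value model: one scalar per position, used as the critic."""

    def __init__(self, config):
        super().__init__()
        self.backbone = Qwen2VLModel(config)
        self.value_head = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, input_ids, attention_mask=None, **multimodal_kwargs):
        hidden = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **multimodal_kwargs,            # pixel_values, image_grid_thw, ... if accepted
        ).last_hidden_state                 # (batch, seq_len, hidden)
        return self.value_head(hidden).squeeze(-1)   # (batch, seq_len) token values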
