Conversation

@sglucas sglucas commented Jul 23, 2025

Closes #6367

📌 Checklist before creating the PR

  • I have created an issue for this PR for traceability
  • The title follows the standard format: [doc/gemini/tensor/...]: A concise description
  • I have added relevant tags if possible for us to better distinguish different PRs
  • I have installed pre-commit: pip install pre-commit && pre-commit install

🚨 Issue number

Link this PR to your issue with words like fixed to automatically close the linked issue upon merge

e.g. fixed #1234, closed #1234, resolved #1234

📝 What does this PR do?

Summarize your work here.
If you have any plots/diagrams/screenshots/tables, please attach them here.

Add more training models (LLaMA3, Qwen3) and RLHF algorithms (REINFORCE++, RLOO).
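As context for the RLOO algorithm this PR adds, the standard leave-one-out baseline can be sketched as follows (a minimal sketch of the textbook formulation, not the PR's actual code; shapes are assumptions):

```python
import torch

def rloo_advantages(reward: torch.Tensor) -> torch.Tensor:
    """Leave-one-out advantages: each sample's baseline is the mean reward
    of the *other* generations for the same prompt (standard RLOO).

    reward: [minibatch_size, num_generations]
    """
    n = reward.size(1)
    # Mean of the other n-1 samples: (sum - r_i) / (n - 1)
    baseline = (reward.sum(dim=1, keepdim=True) - reward) / (n - 1)
    return reward - baseline
```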

💥 Checklist before requesting a review

  • I have linked my PR to an issue (instruction)
  • My issue clearly describes the problem/feature/proposal, with diagrams/charts/table/code if possible
  • I have performed a self-review of my code
  • I have added thorough tests.
  • I have added docstrings for all the functions/methods I implemented

⭐️ Do you enjoy contributing to Colossal-AI?

  • 🌝 Yes, I do.
  • 🌚 No, I don't.

Tell us more if you don't enjoy contributing to Colossal-AI.

@sglucas sglucas requested a review from a team as a code owner July 23, 2025 02:04
@sglucas sglucas changed the base branch from main to grpo-latest July 23, 2025 02:05
# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

Contributor
It may be better to move the common calculations outside of the if statements for conciseness.

# [minibatch_size x num_generations]
advantages = (reward - reward_mean).unsqueeze(dim=-1)

advantages_mean = advantages.mean(dim=0)
Contributor
Isn't advantages_mean always 0, since the advantages are already zero-centered in the previous step?
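The reviewer's point can be checked with a toy example (shapes are assumed from the diff; this is an illustration, not the PR's code):

```python
import torch

# Toy rewards for one group of generations
reward = torch.tensor([1.0, 2.0, 3.0, 4.0])
reward_mean = reward.mean()
advantages = (reward - reward_mean).unsqueeze(dim=-1)  # already zero-centered

# The mean over the group is 0 by construction, so subtracting
# advantages_mean again is a no-op.
advantages_mean = advantages.mean(dim=0)
```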

advantages_std = advantages.std(dim=0)

advantages = (advantages - advantages_mean) / (advantages_std + 1e-4)

Contributor
Maybe consider double-checking the REINFORCE++ baseline advantage calculation. In REINFORCE++, each sample's advantage is calculated by subtracting the mean reward of all generations in the global batch, not the per-prompt mean.

Member
For REINFORCE++, we should calculate the normalized advantage using the batch-level mean and std.
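A minimal sketch of the batch-level normalization the reviewers describe (function name and shapes are illustrative assumptions, not the PR's actual code):

```python
import torch

def normalize_advantages_batch_level(reward: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Zero-center rewards with the global-batch mean, then divide by the
    global-batch std, as suggested for REINFORCE++ (rather than per-prompt stats).

    reward: flat [minibatch_size * num_generations] reward tensor.
    """
    advantages = reward - reward.mean()              # subtract batch-level mean
    advantages = advantages / (reward.std() + eps)   # divide by batch-level std
    return advantages.unsqueeze(dim=-1)              # [N, 1] for per-token broadcast
```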

@@ -0,0 +1,2 @@
4.51.0: qwen2.5 + grpo, qwen3 + grpo, cannot: llama2, llama3.2
4.47.0:
Contributor
Remove the test log file.

Member
Please remove this file.

@@ -227,13 +227,13 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock

inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
Contributor
Why is flash attention not supported?

generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

if args.backend == "transformers":
inference_model_config.update(
dict(
use_flash_attention_2=True,
use_flash_attention_2=False,
Contributor
same here

Contributor
Probably also consider forcing num_generations to 1 for REINFORCE++.
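One way to enforce the suggestion above is a small guard at argument-parsing time (a sketch; the argument names are illustrative assumptions, not the PR's actual CLI flags):

```python
def validate_num_generations(algo: str, num_generations: int) -> int:
    """REINFORCE++ uses a batch-level baseline rather than a per-prompt group
    baseline, so multiple generations per prompt are unnecessary; force 1."""
    if algo.lower() == "reinforce++" and num_generations != 1:
        print(f"Warning: forcing num_generations from {num_generations} to 1 for REINFORCE++")
        return 1
    return num_generations
```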

@TongLi3701 TongLi3701 changed the base branch from grpo-latest to main August 21, 2025 06:55
@TongLi3701 TongLi3701 changed the base branch from main to grpo-latest August 21, 2025 06:55
Member

@TongLi3701 TongLi3701 left a comment
Thanks, we left some comments.

@sglucas sglucas closed this by deleting the head repository Aug 25, 2025
Linked issue: [FEATURE]: Add more training models and RLHF algorithms