Conversation

Member

@joecummings joecummings commented Oct 10, 2025

Context

  1. We need to document where in vLLM our current code comes from to make debugging easier
  2. We want to make sure we aren't missing any critical vLLM setup features
  3. General cleanup / docstrings

(See specific changes & comments in the PR files themselves)

How do you know I didn't break things?

  1. python tests/integration_tests/test_vllm_policy_correctness.py

Results:

Warning: setting HYPERACTOR_CODEC_MAX_FRAME_LENGTH since this needs to be set to enable large RPC calls via Monarch
INFO 10-13 11:09:38 [__init__.py:235] Automatically detected platform cuda.
INFO 10-13 11:09:47 [config.py:1604] Using max model len 512
INFO 10-13 11:09:48 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 10-13 11:09:49 [core.py:572] Waiting for init message from front-end.
INFO 10-13 11:09:49 [core.py:71] Initializing a V1 LLM engine (v0.10.1.dev0+g6d8d0a24c.d20250930) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=512, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":0,"local_cache_dir":null}
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 10-13 11:09:52 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
WARNING 10-13 11:09:52 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 10-13 11:09:52 [gpu_model_runner.py:1843] Starting to load model facebook/opt-125m...
INFO 10-13 11:09:52 [gpu_model_runner.py:1875] Loading model from scratch...
INFO 10-13 11:09:52 [cuda.py:290] Using Flash Attention backend on V1 engine.
INFO 10-13 11:09:52 [weight_utils.py:296] Using model weights format ['*.bin']
INFO 10-13 11:09:53 [default_loader.py:262] Loading weights took 0.17 seconds
INFO 10-13 11:09:53 [gpu_model_runner.py:1892] Model loading took 0.2393 GiB and 0.565655 seconds
INFO 10-13 11:09:55 [gpu_worker.py:255] Available KV cache memory: 8.93 GiB
INFO 10-13 11:09:55 [kv_cache_utils.py:833] GPU KV cache size: 260,048 tokens
INFO 10-13 11:09:55 [kv_cache_utils.py:837] Maximum concurrency for 512 tokens per request: 507.91x
INFO 10-13 11:09:55 [core.py:193] init engine (profile, create kv cache, warmup model) took 1.49 seconds
INFO 10-13 11:09:55 [loggers.py:141] Engine 000: vllm cache_config_info with initialization after num_gpu_blocks is: 16253
Spawning Service for Policy
Launcher not provided, remote allocations will not work.
INFO 10-13 11:10:02 [config.py:1604] Using max model len 512
INFO 10-13 11:10:02 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
[0] INFO 10-13 11:10:04 [__init__.py:235] Automatically detected platform cuda.
[0] INFO 10-13 11:10:04 [__init__.py:235] Automatically detected platform cuda.
[0] WARNING 10-13 11:10:06 [multiproc_worker_utils.py:307] Reducing Torch parallelism from 368 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] INFO 10-13 11:10:09 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
[0] WARNING 10-13 11:10:09 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
[0] INFO 10-13 11:10:09 [gpu_model_runner.py:1843] Starting to load model facebook/opt-125m...
[0] INFO 10-13 11:10:09 [gpu_model_runner.py:1875] Loading model from scratch...
[0] INFO 10-13 11:10:09 [cuda.py:290] Using Flash Attention backend on V1 engine.
[0] INFO 10-13 11:10:10 [weight_utils.py:296] Using model weights format ['*.bin']
[0] INFO 10-13 11:10:10 [default_loader.py:262] Loading weights took 0.19 seconds
[0] INFO 10-13 11:10:11 [gpu_model_runner.py:1892] Model loading took 0.2389 GiB and 0.609802 seconds
[0] INFO 10-13 11:10:18 [config.py:1604] Using max model len 512
[0] INFO 10-13 11:10:19 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
[0] INFO 10-13 11:10:20 [gpu_worker.py:255] Available KV cache memory: 7.31 GiB
[0] INFO 10-13 11:10:20 [kv_cache_utils.py:833] GPU KV cache size: 212,976 tokens
[0] INFO 10-13 11:10:20 [kv_cache_utils.py:837] Maximum concurrency for 512 tokens per request: 415.97x
Models ready. Generating outputs...

INFO 10-13 11:10:21 [async_llm.py:269] Added request 1.
[0] [Policy-0/1] 2025-10-13 11:10:22 WARNING Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized. This happens when you try to use `record_metric` before calling `init_backends`. To disable this warning, please call in your main file:
[0] `mlogger = await get_or_create_metric_logger()`
[0] `await mlogger.init_backends.call_one(logging_config)`
[0] or set env variable `FORGE_DISABLE_METRICS=True`
INFO 10-13 11:10:23 [async_llm.py:269] Added request 2.
INFO 10-13 11:10:23 [async_llm.py:269] Added request 3.
INFO 10-13 11:10:23 [async_llm.py:269] Added request 4.
INFO 10-13 11:10:24 [async_llm.py:269] Added request 5.
✅ Outputs are the same!
Health loop stopped gracefully.
INFO 10-13 11:10:28 [config.py:1604] Using max model len 512
INFO 10-13 11:10:29 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 10-13 11:10:29 [core.py:572] Waiting for init message from front-end.
INFO 10-13 11:10:29 [core.py:71] Initializing a V1 LLM engine (v0.10.1.dev0+g6d8d0a24c.d20250930) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=512, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":0,"local_cache_dir":null}
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 10-13 11:10:35 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
WARNING 10-13 11:10:35 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 10-13 11:10:35 [gpu_model_runner.py:1843] Starting to load model facebook/opt-125m...
INFO 10-13 11:10:36 [gpu_model_runner.py:1875] Loading model from scratch...
INFO 10-13 11:10:36 [cuda.py:290] Using Flash Attention backend on V1 engine.
INFO 10-13 11:10:36 [weight_utils.py:296] Using model weights format ['*.bin']
INFO 10-13 11:10:36 [default_loader.py:262] Loading weights took 0.17 seconds
INFO 10-13 11:10:37 [gpu_model_runner.py:1892] Model loading took 0.2393 GiB and 0.737609 seconds
INFO 10-13 11:10:39 [gpu_worker.py:255] Available KV cache memory: 8.93 GiB
INFO 10-13 11:10:39 [kv_cache_utils.py:833] GPU KV cache size: 260,048 tokens
INFO 10-13 11:10:39 [kv_cache_utils.py:837] Maximum concurrency for 512 tokens per request: 507.91x
INFO 10-13 11:10:39 [core.py:193] init engine (profile, create kv cache, warmup model) took 1.90 seconds
INFO 10-13 11:10:39 [loggers.py:141] Engine 000: vllm cache_config_info with initialization after num_gpu_blocks is: 16253
Spawning Service for Policy
INFO 10-13 11:10:46 [config.py:1604] Using max model len 512
INFO 10-13 11:10:46 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
[0] INFO 10-13 11:10:48 [__init__.py:235] Automatically detected platform cuda.
[0] INFO 10-13 11:10:48 [__init__.py:235] Automatically detected platform cuda.
[0] WARNING 10-13 11:10:50 [multiproc_worker_utils.py:307] Reducing Torch parallelism from 368 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[0] INFO 10-13 11:10:53 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
[0] WARNING 10-13 11:10:53 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
[0] INFO 10-13 11:10:53 [gpu_model_runner.py:1843] Starting to load model facebook/opt-125m...
[0] INFO 10-13 11:10:53 [gpu_model_runner.py:1875] Loading model from scratch...
[0] INFO 10-13 11:10:54 [cuda.py:290] Using Flash Attention backend on V1 engine.
[0] INFO 10-13 11:10:54 [weight_utils.py:296] Using model weights format ['*.bin']
[0] INFO 10-13 11:10:54 [default_loader.py:262] Loading weights took 0.19 seconds
[0] INFO 10-13 11:10:55 [gpu_model_runner.py:1892] Model loading took 0.2389 GiB and 0.631622 seconds
[0] INFO 10-13 11:11:02 [config.py:1604] Using max model len 512
[0] INFO 10-13 11:11:03 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
[0] INFO 10-13 11:11:04 [gpu_worker.py:255] Available KV cache memory: 7.31 GiB
[0] INFO 10-13 11:11:05 [kv_cache_utils.py:833] GPU KV cache size: 212,976 tokens
[0] INFO 10-13 11:11:05 [kv_cache_utils.py:837] Maximum concurrency for 512 tokens per request: 415.97x
Models ready. Starting KV cache test...
INFO 10-13 11:11:05 [async_llm.py:269] Added request first_16.
[0] [Policy-0/1] 2025-10-13 11:11:06 WARNING Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized. This happens when you try to use `record_metric` before calling `init_backends`. To disable this warning, please call in your main file:
[0] `mlogger = await get_or_create_metric_logger()`
[0] `await mlogger.init_backends.call_one(logging_config)`
[0] or set env variable `FORGE_DISABLE_METRICS=True`
INFO 10-13 11:11:07 [async_llm.py:269] Added request second_16_use_first_block.
INFO 10-13 11:11:07 [async_llm.py:269] Added request use_both_blocks.
INFO 10-13 11:11:08 [block_pool.py:321] Successfully reset prefix cache
[0] INFO 10-13 11:11:08 [block_pool.py:321] Successfully reset prefix cache
INFO 10-13 11:11:08 [async_llm.py:269] Added request use_no_blocks_bc_cache_cleared.

✅ Prefix cache usage is the same!
Health loop stopped gracefully.
  2. python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

Results: https://wandb.ai/jcummings/grpo-training/runs/kt1m006u?nw=nwuserjcummings

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 10, 2025
@felipemello1 felipemello1 self-assigned this Oct 11, 2025
@joecummings joecummings marked this pull request as ready for review October 13, 2025 18:16
trace_headers=None,
priority=priority,
data_parallel_rank=None,
data_parallel_rank=None, # We do not support DP
Contributor

Is it possible someone tries to set DP through the vLLM args? In other words, is there any possibility this fails silently because this is a comment instead of an error?

Member Author

Haha there's a lot of that b/c we choose not to have a shim layer between what we support and the span of config options offered through vLLM.

I can try to call out the egregious cases like this one though.

Contributor

ok fair haha

I think this isn't as important for today; maybe we can cover ourselves by just documenting that this design choice broadly implies some things won't work / that this is experimental...

Contributor

we choose not to have a shim layer between what we support and the span of config options

This is the right move for simplifying things; I'd be shocked if we don't end up talking about adding a shim back by EOY haha

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be easy to add an assert in the post init
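
For illustration, a minimal sketch of that guard, assuming the engine options land on a dataclass with a `__post_init__` (the class and field names here are hypothetical, not the actual Policy config):

from dataclasses import dataclass

@dataclass
class EngineOptions:  # hypothetical stand-in for the real config dataclass
    tensor_parallel_size: int = 1
    data_parallel_size: int = 1  # vLLM option we deliberately do not support

    def __post_init__(self):
        # Fail loudly instead of silently ignoring an unsupported DP setting.
        assert self.data_parallel_size == 1, (
            "Data parallelism is not supported; remove data_parallel_size "
            "from the vLLM args."
        )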

Contributor

@Jack-Khuu Jack-Khuu left a comment

If grpo.main works, ship it

joecummings and others added 2 commits October 13, 2025 15:11
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
@joecummings joecummings merged commit 327f5f1 into meta-pytorch:main Oct 13, 2025
5 checks passed
@joecummings joecummings deleted the remaining-policy-cleanup branch October 13, 2025 19:12
Contributor

@pbontrager pbontrager left a comment

Left a number of context and args comments, but they should all be relatively small.

TORCHSTORE_USE_RDMA.get_value() == 0
) # torchstore currently only accepts 0 or 1
# Gets set up by setup
# Remaining variables are initialized in self.setup()
Contributor

This shouldn't be necessary. Everything up here is a config-like object: dataclasses, direct values, or callables. You shouldn't have to specify that variables are created in setup, as that's the pattern everywhere. Also, I'd be fine with removing lora_request if we don't support it now.

Member Author

Yeah fair, just didn't know if you had left this in here for readability at a glance.
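
For context, a rough sketch of the pattern being described, with hypothetical names (config-like values as dataclass fields, runtime state created only in `setup()`):

from dataclasses import dataclass, field
from typing import Any

@dataclass
class PolicySketch:  # hypothetical; not the real Policy actor
    # Config-like fields only: dataclasses, direct values, or callables.
    model: str = "facebook/opt-125m"
    sampling_overrides: dict[str, Any] = field(default_factory=dict)

    async def setup(self):
        # Runtime state (processor, request counter, worker handles, ...)
        # lives here rather than as fields above.
        self.request_id = 0
        self.processor = None  # placeholder for the vLLM input processor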

prompt_str, request = self.processor.process_inputs(
request_id=request_id,
prompt=prompt_dict,
prompt={"prompt": prompt},
Contributor

This is probably fine, but it takes us a bit further from vLLM.
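
For reference, a sketch of the two vLLM prompt forms in play here; the typed-dict names (TextPrompt/TokensPrompt) come from vLLM's inputs API and should be checked against the pinned vLLM version:

# The change above builds the TextPrompt-style dict inline instead of passing a
# prebuilt prompt_dict; both of the following are valid vLLM prompt inputs.
text_prompt = {"prompt": "Tell me a joke"}          # TextPrompt form
tokens_prompt = {"prompt_token_ids": [1, 2, 3, 4]}  # TokensPrompt form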

trace_headers=None,
priority=priority,
data_parallel_rank=None,
data_parallel_rank=None, # We do not support DP
Contributor

I think it'd be easy to add an assert in the post init


@dataclass
class PolicyWorker(ForgeActor):
"""Mirrors a vLLM GPUWorker
Contributor

We could make this match the parent worker so that we can just pipe the args/kwargs through in launch and then build the vLLM config here as well.
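
A rough sketch of that shape (illustrative only; the real PolicyWorker and launch path differ, and the EngineArgs usage assumes vLLM's public API):

from dataclasses import dataclass, field
from typing import Any

from vllm.engine.arg_utils import EngineArgs

@dataclass
class PolicyWorkerSketch:  # hypothetical; mirrors the idea, not the real class
    # Accept the same raw kwargs the parent worker takes...
    vllm_args: dict[str, Any] = field(default_factory=dict)

    async def setup(self):
        # ...and only turn them into a vLLM config here, so launch() can pipe
        # args/kwargs through without knowing about vLLM at all.
        engine_args = EngineArgs(model="facebook/opt-125m", **self.vllm_args)
        self.vllm_config = engine_args.create_engine_config()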
