RL Training with vLLM Rollout: How to Mitigate Load Imbalance from Variable Response Lengths

We’re using vLLM for RL training with 32xA100 GPUs and a batch size of ~8000.

  1. Load Imbalance in SPMD Mode:
    Prompts are split evenly across workers, but response length variance causes severe load imbalance. Would switching to online mode fix this by dynamically balancing workloads?
  2. Performance Comparison:
    At this scale, how much worse is the throughput of online mode compared to offline/SPMD mode?
  3. Optimization Tips:
    Are there best practices to reduce overhead (e.g., scheduling, communication) in online mode for large batches?

To use online mode for RL training, you would need to expose a /collective_rpc endpoint.

Since this has important security implications, we would need to add a vllm_rpc_key and use that key to sign and verify the messages received from arbitrary /collective_rpc requests.
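For illustration, here is a minimal sketch (not actual vLLM code) of how such a vllm_rpc_key could be used to HMAC-sign and verify /collective_rpc request bodies; the helper names and key handling are assumptions:

import hashlib
import hmac
import json

VLLM_RPC_KEY = b"change-me"  # shared secret distributed only to trusted trainer processes

def sign_payload(payload: dict) -> str:
    # The trainer signs the canonically-serialized RPC body with the shared key.
    body = json.dumps(payload, sort_keys=True).encode()
    return hmac.new(VLLM_RPC_KEY, body, hashlib.sha256).hexdigest()

def verify_payload(payload: dict, signature: str) -> bool:
    # The server recomputes the signature and rejects any request that does not match.
    return hmac.compare_digest(sign_payload(payload), signature)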

For a reference on how to add such an endpoint, see [core] add sleep and wake up endpoint and v1 support by youkaichao · Pull Request #12987 · vllm-project/vllm · GitHub.

With this support (and the /sleep and /wake_up endpoints), you should be able to use vLLM online mode for RL training and compare it against offline mode.
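Putting this together, here is a rough sketch of one RL iteration against an online vLLM server. The /collective_rpc route, its body, and the signature header are assumptions (they would come from the patch described above); /sleep and /wake_up follow PR #12987, though exact paths and parameters may differ across vLLM versions. The server address and model name are placeholders, and sign_payload is the helper from the previous sketch.

import requests

SERVER = "http://localhost:8000"  # assumed vLLM OpenAI-compatible server

def rollout_and_update(prompts: list[str]) -> list[str]:
    # 1) Collect rollouts through the standard online-serving API.
    completions = []
    for prompt in prompts:
        resp = requests.post(
            f"{SERVER}/v1/completions",
            json={"model": "my-model", "prompt": prompt, "max_tokens": 1024},
        ).json()
        completions.append(resp["choices"][0]["text"])

    # 2) Put the server to sleep to free GPU memory while the trainer updates weights.
    requests.post(f"{SERVER}/sleep")

    # ... run the policy update on the trainer GPUs here ...

    # 3) Wake the server back up and push new weights through the signed RPC endpoint.
    requests.post(f"{SERVER}/wake_up")
    payload = {"method": "update_weights", "args": []}  # illustrative RPC body
    requests.post(
        f"{SERVER}/collective_rpc",
        json=payload,
        headers={"X-vLLM-RPC-Signature": sign_payload(payload)},
    )
    return completions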


Sharing some recent findings and discussions with @youkaichao.

I actually don’t think online mode will help much here. The main problem is that some responses are very long, so toward the end of a batch you may be generating only 8 out of 8000 prompts; as long as those stragglers are still generating, online mode won’t help much.

@youkaichao gave a nice suggestion of doing “chunked generation”, basically generating 1k tokens at a time. See the snippet here:

import ray
from vllm import LLM, SamplingParams


def chunked_generate_with_engines(prompts: list[list[int]], sampling_params: SamplingParams, max_chunk_size: int, vllm_engines: list[LLM]):
    # Each entry in vllm_engines is a Ray actor handle wrapping an LLM, so generate is called via .remote().
    max_new_tokens = sampling_params.max_tokens
    chunked_sampling_params = SamplingParams(
        max_tokens=max_chunk_size,
        temperature=sampling_params.temperature,
        top_p=sampling_params.top_p,
        top_k=sampling_params.top_k,
        repetition_penalty=sampling_params.repetition_penalty,
    )
    # Each sequence starts as its prompt and grows by at most one chunk of tokens per iteration.
    prompt_and_responses = [prompt.copy() for prompt in prompts]
    idxs = list(range(len(prompt_and_responses)))
    dones = [0] * len(prompt_and_responses)
    finish_reasons = ["length"] * len(prompt_and_responses)
    max_iterations = max_new_tokens // max_chunk_size  # don't do ceil div here because we don't want to over-generate
    for _ in range(max_iterations):
        if all(dones):
            break
        # Only the sequences that haven't hit a stop token yet are fed back in.
        not_done_idxs = [i for i in idxs if dones[i] == 0]
        cur_prompt_and_responses = [prompt_and_responses[i] for i in not_done_idxs]
        samples_per_engine = (len(cur_prompt_and_responses) + len(vllm_engines) - 1) // len(vllm_engines)
        print(f"🔥 {samples_per_engine=}")
        split_prompt_and_responses = [
            cur_prompt_and_responses[i : i + samples_per_engine]
            for i in range(0, len(cur_prompt_and_responses), samples_per_engine)
        ]
        futures = [
            vllm_engine.generate.remote(sampling_params=chunked_sampling_params, prompt_token_ids=queries, use_tqdm=True)
            for vllm_engine, queries in zip(vllm_engines, split_prompt_and_responses)
        ]
        all_outputs = ray.get(futures)
        for i, outputs in enumerate(all_outputs):
            for j, output in enumerate(outputs):
                seq_idx = not_done_idxs[i * samples_per_engine + j]
                out = output.outputs[0]  # we assume num_samples_per_prompt_rollout == 1
                prompt_and_responses[seq_idx].extend(list(out.token_ids))
                if out.finish_reason == "stop":
                    dones[seq_idx] = 1
                    finish_reasons[seq_idx] = out.finish_reason

    # Strip the original prompt tokens so only the generated response ids are returned.
    response_ids = [prompt_and_response[len(prompt):] for prompt, prompt_and_response in zip(prompts, prompt_and_responses)]
    return response_ids, finish_reasons
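
For context, a rough sketch of how this snippet could be wired up; the Ray actor wrapper, engine count, model name, and token-id prompts are illustrative, and the prompt_token_ids keyword mirrors the snippet above (newer vLLM versions may expect TokensPrompt inputs instead):

import ray
from vllm import LLM, SamplingParams

@ray.remote(num_gpus=1)
class LLMEngineActor:
    """Thin Ray actor around vllm.LLM so .generate can be called via .remote()."""
    def __init__(self, model: str):
        self.llm = LLM(model=model)

    def generate(self, sampling_params, prompt_token_ids, use_tqdm=True):
        return self.llm.generate(prompt_token_ids=prompt_token_ids,
                                 sampling_params=sampling_params,
                                 use_tqdm=use_tqdm)

ray.init()
vllm_engines = [LLMEngineActor.remote("Qwen/Qwen2.5-7B-Instruct") for _ in range(20)]
sampling_params = SamplingParams(max_tokens=8192, temperature=1.0)
prompts = [[1, 2, 3], [4, 5, 6]]  # token-id prompts, illustrative only
response_ids, finish_reasons = chunked_generate_with_engines(
    prompts, sampling_params, max_chunk_size=1024, vllm_engines=vllm_engines)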

Say I have 20 vLLM workers, a batch size of 8192, and a chunk size of 1k tokens per generation call; the output basically looks like this:

2025-03-12T18:26:01.971Z 🔥 samples_per_engine=410
2025-03-12T18:26:01.971Z [Main Thread] 📦 Getting packed sequences from thread: 0.08 seconds
2025-03-12T18:26:01.971Z Number of training examples per device: B=26, packed sequence fraction of original sequences: 0.0382080078125
2025-03-12T18:26:01.971Z (PolicyTrainerRayProcess pid=2845, ip=10.95.0.157) Inference Calculation: 13.08 seconds
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=20
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=3
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=1
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=1
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=1

After the first 2k tokens, only 3 / 410 ≈ 0.007 of the prompts are still generating. That is what fundamentally causes the slowdown.
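To make the effect concrete, here is a small back-of-the-envelope script using the per-chunk samples_per_engine counts from the log above:

# Active sequences per engine at the start of each 1k-token chunk (from the log above).
samples_per_engine_per_chunk = [410, 20, 3, 1, 1, 1]
full_batch_per_engine = 410  # ceil(8192 prompts / 20 engines)

for chunk, active in enumerate(samples_per_engine_per_chunk):
    frac = active / full_batch_per_engine
    print(f"chunk {chunk}: {active:>3} active sequences per engine ({frac:.1%} of the batch)")
# After 2k tokens fewer than 1% of the sequences are still decoding, so each later
# chunk runs the engines at a near-empty batch and overall throughput collapses.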

You can apply techniques like early stopping, essentially cutting off generation once you reach something like samples_per_engine=3, but that has undesirable side effects and makes actual training less predictable.
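For reference, a minimal sketch of what such an early-stop check could look like inside the chunked loop above; the min_active_fraction threshold is a hypothetical knob:

# Hypothetical early-stop check, run at the top of each chunked iteration.
MIN_ACTIVE_FRACTION = 0.01  # e.g. give up once fewer than 1% of prompts are still generating

def should_early_stop(dones: list[int], min_active_fraction: float = MIN_ACTIVE_FRACTION) -> bool:
    active = sum(1 for d in dones if d == 0)
    return active / len(dones) < min_active_fraction

# Sequences cut off this way keep finish_reason == "length": the longest responses
# end up truncated, which is the undesirable side effect noted above.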


Thanks a lot! I’ll give it a try

That’s a great idea, but I noticed that simple chunked generation did not significantly improve tokens per second. Does that value represent the throughput during the generation phase?
Does the conclusion here mean that simple chunked generation does not improve overall performance, while using early stopping to discard long data would lead to degraded results?
Is there any other optimization potential here, such as using streaming generation?