Sharing some recent findings and discussions with @youkaichao.
I actually don’t think online mode will help too much here. The main problem is that some responses’ generations are very long, so toward the end you may be generating only 8 out of 8000 prompts, and as long as those stragglers are still generating, online mode won’t help much.
@youkaichao gave a nice suggestion of doing “chunked generation”, basically generating 1k tokens at a time. See the snippet here:
```python
import ray
from vllm import LLM, SamplingParams


def chunked_generate_with_engines(prompts: list[list[int]], sampling_params: SamplingParams, max_chunk_size: int, vllm_engines: list[LLM]):
    # `vllm_engines` are Ray actor handles wrapping vLLM engines, hence `.generate.remote(...)` + `ray.get` below.
    max_new_tokens = sampling_params.max_tokens
    # Same sampling settings, but each call generates at most `max_chunk_size` tokens.
    chunked_sampling_params = SamplingParams(
        max_tokens=max_chunk_size,
        temperature=sampling_params.temperature,
        top_p=sampling_params.top_p,
        top_k=sampling_params.top_k,
        repetition_penalty=sampling_params.repetition_penalty,
    )
    prompt_and_responses = [prompt.copy() for prompt in prompts]
    idxs = list(range(len(prompt_and_responses)))
    dones = [0] * len(prompt_and_responses)
    finish_reasons = ["length"] * len(prompt_and_responses)
    max_iterations = max_new_tokens // max_chunk_size  # don't do ceil div here because we don't want to over-generate
    for _ in range(max_iterations):
        if all(dones):
            break
        # Only re-submit the prompts that haven't hit a stop token yet.
        not_done_idxs = [i for i in idxs if dones[i] == 0]
        cur_prompt_and_responses = [prompt_and_responses[i] for i in not_done_idxs]
        samples_per_engine = (len(cur_prompt_and_responses) + len(vllm_engines) - 1) // len(vllm_engines)
        print(f"🔥 {samples_per_engine=}")
        split_prompt_and_responses = [
            cur_prompt_and_responses[i : i + samples_per_engine]
            for i in range(0, len(cur_prompt_and_responses), samples_per_engine)
        ]
        futures = [
            vllm_engine.generate.remote(sampling_params=chunked_sampling_params, prompt_token_ids=queries, use_tqdm=True)
            for vllm_engine, queries in zip(vllm_engines, split_prompt_and_responses)
        ]
        all_outputs = ray.get(futures)
        for i, outputs in enumerate(all_outputs):
            for j, output in enumerate(outputs):
                seq_idx = not_done_idxs[i * samples_per_engine + j]
                out = output.outputs[0]  # we assume num_samples_per_prompt_rollout == 1
                # Append this chunk's tokens; the next chunk continues from prompt + partial response.
                prompt_and_responses[seq_idx].extend(list(out.token_ids))
                if out.finish_reason == "stop":
                    dones[seq_idx] = 1
                    finish_reasons[seq_idx] = out.finish_reason
    response_ids = [prompt_and_response[len(prompt):] for prompt, prompt_and_response in zip(prompts, prompt_and_responses)]
    return response_ids, finish_reasons
```
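For context, here is a hedged sketch of the call site; the engine/prompt construction is elided with placeholders because it depends on your setup, and the specific sampling values are arbitrary examples, not the ones we actually use:

```python
from vllm import SamplingParams

# Hypothetical call site matching the setup described below:
# 20 Ray-actor engines, 8192 tokenized prompts, 1k-token chunks.
sampling_params = SamplingParams(
    max_tokens=8192,        # total generation budget per prompt
    temperature=1.0,        # sampling values here are arbitrary examples
    top_p=1.0,
    top_k=-1,
    repetition_penalty=1.0,
)
vllm_engines = ...        # list of 20 Ray actor handles wrapping vLLM engines
prompt_token_ids = ...    # list[list[int]] of 8192 tokenized prompts
response_ids, finish_reasons = chunked_generate_with_engines(
    prompt_token_ids, sampling_params, max_chunk_size=1024, vllm_engines=vllm_engines
)
```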
Say I have 20 vLLM workers, a batch size of 8192 prompts, and a chunk size of 1k tokens per generation call; the output looks like this:
```
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=410
2025-03-12T18:26:01.971Z [Main Thread] 📦 Getting packed sequences from thread: 0.08 seconds
2025-03-12T18:26:01.971Z Number of training examples per device: B=26, packed sequence fraction of original sequences: 0.0382080078125
2025-03-12T18:26:01.971Z (PolicyTrainerRayProcess pid=2845, ip=10.95.0.157) Inference Calculation: 13.08 seconds
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=20
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=3
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=1
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=1
2025-03-12T18:26:01.971Z 🔥 samples_per_engine=1
```
After the first 2k tokens, only 3 / 410 ≈ 0.7% of the prompts per engine are still generating; that long tail is what fundamentally causes the slowdown.
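To make the arithmetic concrete, here is a tiny sketch of how the per-engine load collapses across chunks; the remaining-prompt counts are made-up values chosen only to be consistent with the samples_per_engine numbers logged above:

```python
import math

num_engines = 20
# Hypothetical counts of prompts still generating at the start of each 1k-token
# chunk, chosen only to be consistent with the logged samples_per_engine values.
remaining_per_chunk = [8192, 400, 50, 12, 5, 2]
for chunk_idx, remaining in enumerate(remaining_per_chunk):
    # Same ceil-div as in chunked_generate_with_engines.
    samples_per_engine = math.ceil(remaining / num_engines)
    print(f"chunk {chunk_idx}: {remaining} prompts still generating -> {samples_per_engine=}")
```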
You can apply techniques like early stopping, essentially cutting off at samples_per_engine=3, but it has undesirable side effects and makes actual training less predictable.
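For reference, here is a minimal sketch of what that early stopping could look like inside the chunk loop above, assuming you are willing to cut off the stragglers and leave them with their default finish_reason of "length"; the threshold is an arbitrary example, not a recommendation:

```python
# Hypothetical early-stopping variant of the chunk loop above: once fewer than
# `early_stop_fraction` of the original prompts are still generating, stop
# issuing further chunks. The 1% threshold is an arbitrary example.
early_stop_fraction = 0.01

for _ in range(max_iterations):
    not_done_idxs = [i for i in idxs if dones[i] == 0]
    if not not_done_idxs:
        break
    if len(not_done_idxs) < early_stop_fraction * len(prompt_and_responses):
        break  # cut off the long tail instead of spending more chunks on it
    ...  # same chunked-generation body as in chunked_generate_with_engines
```

This is exactly where the undesirable side effect comes from: the remaining responses get silently truncated mid-generation.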