Hey guys! Big fan of vLLM
I notice that the logprobs returned by vLLM can be quite far off from the ones calculated by transformers. Logits also very often just round to -inf even when they're perfectly within the dynamic range of bf16 or even fp16. Is this a known issue? Is there any work to reduce this effect at the moment?
These logprobs are used by RL losses like GRPO to scale the reward signal, and we observe that convergence is quite sensitive to them.
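For context, here is a minimal sketch (hypothetical numbers, not our actual training code) of how a GRPO-style loss consumes these logprobs; because the importance ratio exponentiates the gap between the trainer's and the sampler's logprobs, small per-token errors directly perturb the reward scaling:
import torch

# Hypothetical per-token logprobs for one completion: scored by the sampling
# engine (vLLM) and re-scored by the trainer (HF), differing by a small error.
logprobs_sampler = torch.tensor([-0.60, -1.25, -2.10, -0.05])
logprobs_trainer = logprobs_sampler + torch.tensor([0.07, -0.05, 0.10, 0.02])

# GRPO-style importance ratio: exp(logp_trainer - logp_sampler).
# With perfectly matching logprobs this would be exactly 1.0 everywhere.
ratio = torch.exp(logprobs_trainer - logprobs_sampler)
print(ratio)  # roughly [1.07, 0.95, 1.11, 1.02]

# The ratio multiplies the per-token advantage, so a ~0.1 logprob error
# already shifts the loss contribution of that token by ~10%.
advantage = 1.0
per_token_loss = -ratio * advantage
print(per_token_loss)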
Thanks for the report!
Can you share the code for how you are calculating “HF Logprobs”? Are these the logits or the log-probabilities? I'd be interested in reproducing this on my side.
The script requires 2 GPUs because I'm loading the vLLM engine on GPU 0 and HF on GPU 1:
# VLLM Side
import torch
from vllm import LLM, SamplingParams
TEMPERATURE = 0.7
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", enforce_eager=True)
prompts = [
    "One of the most important things in life is to",
    "The best way to learn is to",
]
outputs = llm.generate(
    prompts,
    sampling_params=SamplingParams(
        max_tokens=100,
        temperature=TEMPERATURE,
        logprobs=2,
    ),
)
save_stuff = []
for output in outputs:
    print(len(output.outputs[0].token_ids), len(output.outputs[0].logprobs))
    for token, logprob in zip(output.outputs[0].token_ids, output.outputs[0].logprobs):
        print(token, logprob)
    save_stuff.append(
        {
            "input_ids": output.prompt_token_ids,
            "output_ids": output.outputs[0].token_ids,
            "logprobs": output.outputs[0].logprobs,
        }
    )
# HF Side
torch.cuda.set_device(1)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
for output in save_stuff:
    token_ids = torch.tensor([*output["input_ids"], *output["output_ids"]], device="cuda").unsqueeze(0)
    print(token_ids.shape, token_ids)
    with torch.inference_mode():
        model_outputs = model(token_ids)
    print(model_outputs[0].shape)
    real_logprobs = F.log_softmax(model_outputs[0] / TEMPERATURE, dim=-1)
    print(real_logprobs.shape)
    for i in range(len(output["logprobs"])):
        print("===", output["output_ids"][i], "===")
        for key in output["logprobs"][i]:
            print(key, output["logprobs"][i][key], "HF logprobs:", real_logprobs[0, i - 1 + len(output["input_ids"])][key].item())
Small update from my side. It's probably some numerical issue, either from the algebraic rewrite of “prefill + decode” vs “prefill only” or from the kernels used by the model implementations. The relative error is < 1e-2 in fp32.
It does seem like the log_softmax calculation used by vLLM is more unstable, though. I put a hook on the logits_processor to get the logits and applied F.log_softmax myself,
and the relative error is consistently lower than the one returned by vLLM.
=== 537 ===
537 Logprob(logprob=-0.6684995889663696, rank=1, decoded_token=' not') HF logprobs: -0.6015625 Hook logprobs: -0.6015625
Prob: 0.54795478749992, VLLM: 0.5124769272585249, Hook: 0.54795478749992
=== 2704 ===
=== 13 ===
=== 10696 ===
=== 358 ===
=== 1265 ===
=== 1744 ===
=== 911 ===
Relative logprob errors
VLLM: max=0.38462623073334895, mean=0.05062537796546106, stdev=0.05760658469739193, median=0.02781234832067509, min=0.0
Hook: max=0.38461559143521973, mean=0.04105125529860964, stdev=0.0626278072035839, median=0.013333333333475555, min=0.0
Absolute prob errors
VLLM: max=0.08872799262859904, mean=0.005587153872456963, stdev=0.011046970011292146, median=0.00012709689606727972, min=0.0
Hook: max=0.06310779300433156, mean=0.00431247809215178, stdev=0.0111055722111331, median=1.0116216796347255e-05, min=0.0
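For what it's worth, here is a standalone illustration of why the order of operations around log_softmax matters in low precision (a generic sketch, not a claim about what vLLM's sampler kernels actually do): a logprob that is perfectly representable becomes -inf if the probability is materialized first and underflows.
import torch
import torch.nn.functional as F

# A logit 100 below the max: its logprob (~ -100) is representable in bf16,
# but the corresponding probability (~3.7e-44) underflows bf16.
logits = torch.tensor([0.0, -100.0], dtype=torch.bfloat16)

fused = F.log_softmax(logits, dim=-1)                  # stays in log space
via_probs = torch.log(torch.softmax(logits, dim=-1))   # prob rounds to 0 first

print(fused)      # finite: roughly [0., -100.]
print(via_probs)  # [0., -inf] -- log(0) after the underflow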
Updated script
# VLLM Side
import torch
from vllm import LLM, SamplingParams
import math
TEMPERATURE = 0.7
DTYPE = torch.bfloat16
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", dtype=DTYPE, enforce_eager=True)
model = llm.llm_engine.model_executor.driver_worker.model_runner.model
saved_outputs = []
def logits_processor_hook(module, input, output):
    assert isinstance(output, torch.Tensor)
    saved_outputs.append(output.clone())
model.logits_processor.register_forward_hook(logits_processor_hook)
prompts = [
    "One of the most important things in life is to",
    "The answer to 1 + 1 is",
]
outputs = llm.generate(
    prompts,
    sampling_params=SamplingParams(
        max_tokens=512,
        temperature=TEMPERATURE,
        logprobs=2,
    ),
)
save_stuff = []
for output in outputs:
    assert len(output.outputs[0].token_ids) == len(output.outputs[0].logprobs)
    # for token, logprob in zip(output.outputs[0].token_ids, output.outputs[0].logprobs):
    #     print(token, logprob)
    save_stuff.append(
        {
            "input_ids": output.prompt_token_ids,
            "output_ids": output.outputs[0].token_ids,
            "logprobs": output.outputs[0].logprobs,
        }
    )
# HF Side
torch.cuda.set_device(1)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", torch_dtype=DTYPE, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
seq_id = 0
vllm_errs = []
hook_errs = []
vllm_prob_errs = []
hook_prob_errs = []
for output in save_stuff:
    token_ids = torch.tensor([*output["input_ids"], *output["output_ids"]], device="cuda").unsqueeze(0)
    print(token_ids.shape)
    with torch.inference_mode():
        model_outputs = model(token_ids)
    print(model_outputs[0].shape)
    real_logprobs = F.log_softmax(model_outputs[0] / TEMPERATURE, dim=-1)
    print(real_logprobs.shape)
    for i in range(len(output["logprobs"])):
        print("===", output["output_ids"][i], "===")
        hook_logprobs = F.log_softmax(saved_outputs[i][seq_id] / TEMPERATURE, dim=-1)
        for key in output["logprobs"][i]:
            _real_logprobs = real_logprobs[0, i - 1 + len(output["input_ids"])]
            vllm_rel_err = abs((output["logprobs"][i][key].logprob - _real_logprobs[key].item()) / (_real_logprobs[key].item() + 1e-10))
            hook_rel_err = abs((hook_logprobs[key].item() - _real_logprobs[key].item()) / (_real_logprobs[key].item() + 1e-10))
            vllm_errs.append(vllm_rel_err)
            hook_errs.append(hook_rel_err)
            vllm_prob = math.exp(output["logprobs"][i][key].logprob)
            hook_prob = math.exp(hook_logprobs[key].item())
            real_prob = math.exp(_real_logprobs[key].item())
            vllm_prob_err = abs(vllm_prob - real_prob)
            hook_prob_err = abs(hook_prob - real_prob)
            vllm_prob_errs.append(vllm_prob_err)
            hook_prob_errs.append(hook_prob_err)
            if (vllm_rel_err > 0.1 or hook_rel_err > 0.1) and real_prob < 0.9:
                print(
                    key, output["logprobs"][i][key],
                    "HF logprobs:", real_logprobs[0, i - 1 + len(output["input_ids"])][key].item(),
                    "Hook logprobs:", hook_logprobs[key].item(),
                )
                print(f"Prob: {real_prob}, VLLM: {vllm_prob}, Hook: {hook_prob}")
    seq_id += 1
from statistics import mean, stdev, median
print("Relative logprob errors")
print(f"VLLM: max={max(vllm_errs)}, mean={mean(vllm_errs)}, stdev={stdev(vllm_errs)}, median={median(vllm_errs)}, min={min(vllm_errs)}")
print(f"Hook: max={max(hook_errs)}, mean={mean(hook_errs)}, stdev={stdev(hook_errs)}, median={median(hook_errs)}, min={min(hook_errs)}")
print("Absolute prob errors")
print(f"VLLM: max={max(vllm_prob_errs)}, mean={mean(vllm_prob_errs)}, stdev={stdev(vllm_prob_errs)}, median={median(vllm_prob_errs)}, min={min(vllm_prob_errs)}")
print(f"Hook: max={max(hook_prob_errs)}, mean={mean(hook_prob_errs)}, stdev={stdev(hook_prob_errs)}, median={median(hook_prob_errs)}, min={min(hook_prob_errs)}")
Thank you for your detailed exploration — this is an incredibly helpful report.
I think we can evaluate it to make sure there is not a performance footgun with F.log_softmax,
and make the switch if you find it more stable. I will dig in a bit this afternoon.
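If it helps with the performance check, here is a rough micro-benchmark sketch (the batch size and vocab size are assumptions, not taken from vLLM) comparing log_softmax on bf16 logits against upcasting to fp32 first:
import torch
import torch.nn.functional as F

# Assumed shape: a decode batch of 64 sequences over a ~152k vocab.
logits = torch.randn(64, 151936, device="cuda", dtype=torch.bfloat16)

def bench(fn, iters=100):
    # Warm up, then time with CUDA events.
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

print("bf16 log_softmax:", bench(lambda: F.log_softmax(logits, dim=-1)), "ms")
print("fp32 log_softmax:", bench(lambda: F.log_softmax(logits.float(), dim=-1)), "ms")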
According to my experiment, it's not caused by the softmax function or the activation layers.
I guess the main differences come from error accumulated in the matrix-multiplication implementations.
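One quick way to sanity-check that hypothesis in isolation (a standalone sketch; the shapes are made up, not taken from the model) is to compute the same matmul with two different reduction orders and compare both against an fp64 reference:
import torch

torch.manual_seed(0)
# Made-up shapes, just to show that the reduction order of a matmul matters.
x = torch.randn(32, 4096, dtype=torch.float32)
w = torch.randn(4096, 4096, dtype=torch.float32)

ref = x.double() @ w.double()  # fp64 reference

out_a = x @ w  # whatever reduction order the fp32 kernel picks
# Mathematically identical, but accumulated in a different order:
# a sum of partial products over chunks of the K dimension.
out_b = sum(x[:, k:k + 512] @ w[k:k + 512, :] for k in range(0, 4096, 512))

print("A vs fp64 max abs err:", (out_a.double() - ref).abs().max().item())
print("B vs fp64 max abs err:", (out_b.double() - ref).abs().max().item())
print("A vs B   max abs err:", (out_a - out_b).abs().max().item())  # typically nonzero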
@robertshaw
Could @Jackmin801's issue also be related to these? (Sorry, I can't put more than 2 links at this moment, so I will use issue numbers instead.)
Open Issues:
- [Bug]: The accuracy of multiple cards and single card is inconsistent · Issue #13801 · vllm-project/vllm
- [Bug]: prefix-caching: inconsistent completions · Issue #5543 · vllm-project/vllm
- [Bug]: topk=1 and temperature=0 cause different output in vllm · Issue #5404 · vllm-project/vllm
- [Bug]: Inconsistent Responses with VLLM When Batch Size > 1 even temperature = 0 · Issue #5898 · vllm-project/vllm
- [Bug]: Outputs are different at separate runs · Issue #10074 · vllm-project/vllm
Closed Issues:
- [Bug]: Concurrent requests messing up GREEDY responses · Issue #5607 · vllm-project/vllm