Explain the relationship between SamplingParams and SamplingMetadata
It seems likely that SamplingParams defines the user-facing parameters for text generation (e.g., temperature, top_p, top_k, max_tokens), while SamplingMetadata is an internal structure that stores the processed, tensorized, and batched form of these parameters for use during model inference. SamplingMetadata is constructed from one or more SamplingParams objects and includes additional information needed for efficient sampling, such as token indices and generator states. See SamplingParams source and SamplingMetadata API.
Would you like a step-by-step breakdown of how SamplingParams are converted into SamplingMetadata and used during inference?
yes, I want to know. thanks!
Certainly! Here’s a step-by-step breakdown:
- User Specifies SamplingParams: When you call vLLM's generation APIs (e.g., LLM.generate), you provide a SamplingParams object. This object contains user-defined generation settings such as temperature, top_p, top_k, max_tokens, seed, etc. (SamplingParams source)
- Batching and Preprocessing: If you submit a batch of prompts, vLLM collects the corresponding SamplingParams for each prompt. These parameters are then preprocessed and converted into tensors or other efficient representations suitable for GPU computation. (See the example below.)
- Construction of SamplingMetadata: The engine constructs a SamplingMetadata object (or a batch of them) from the processed SamplingParams. SamplingMetadata includes not only the tensorized sampling parameters but also additional information such as token indices, prompt lengths, and generator states for reproducibility. This structure is optimized for use within the model's sampling loop (SamplingMetadata API).
- Model Inference: During inference, the model receives the SamplingMetadata object. The sampler uses the metadata to apply the correct sampling strategy (e.g., temperature scaling, top-k filtering) and to track which tokens to sample at each step.
- Result Generation: The outputs are generated according to the logic encoded in SamplingMetadata, ensuring that each prompt in the batch is sampled according to its original SamplingParams.
This separation allows vLLM to efficiently handle batched requests and advanced sampling features while keeping the user API simple.
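As a concrete illustration of the batching step, generate accepts either a single SamplingParams shared by all prompts or one per prompt; in both cases the engine folds them into SamplingMetadata internally (a sketch with placeholder model and prompts):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model

prompts = [
    "Write a haiku about the sea.",
    "Explain KV caching in one sentence.",
]
# One SamplingParams per prompt; vLLM batches the requests and builds the
# corresponding SamplingMetadata for the whole batch internally.
per_prompt_params = [
    SamplingParams(temperature=0.2, max_tokens=32),
    SamplingParams(temperature=1.0, top_p=0.9, max_tokens=64, seed=42),
]

for out in llm.generate(prompts, per_prompt_params):
    print(out.outputs[0].text)
```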
OK, I understand.
Now I have a requirement: I want to add a custom parameter act_func in sampling_params; if act_func is set to sigmod, I hope that the logprobs returned in sampler.py are the results of the original logits after sigmod.
Which code should I modify? Please give a specific implementation
You need to add an act_func parameter to SamplingParams and process the logits in sampler.py according to that parameter. If act_func is "sigmod" (this should really be "sigmoid"), apply torch.sigmoid to the logits before computing the logprobs. The steps are:

- Modify SamplingParams: add an act_func parameter to the SamplingParams class and make sure it is propagated through the sampling pipeline.
- Modify sampler.py: in sampler.py, before computing the logprobs in compute_logprobs, check whether sampling_metadata carries an act_func equal to "sigmod" (correcting the spelling to "sigmoid" is recommended); if so, apply torch.sigmoid to the logits, otherwise keep the original logic. Example code:
```python
def compute_logprobs(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor:
    if hasattr(sampling_metadata, "act_func") and sampling_metadata.act_func == "sigmod":
        logits = torch.sigmoid(logits)
    return logits.log_softmax(dim=-1, dtype=torch.float32)
```
You also need to make sure that SamplingMetadata can receive and carry the act_func parameter.
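Putting the two changes together, here is a minimal, self-contained sketch of the idea; the classes below are simplified stand-ins for vLLM's real SamplingParams and SamplingMetadata, not the actual implementations:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplingParams:  # simplified stand-in for vllm.SamplingParams
    temperature: float = 1.0
    act_func: Optional[str] = None  # new custom field, e.g. "sigmoid"


@dataclass
class SamplingMetadata:  # simplified stand-in for vLLM's SamplingMetadata
    act_func: Optional[str] = None

    @classmethod
    def from_params(cls, params: SamplingParams) -> "SamplingMetadata":
        # Propagate the custom field when the metadata is built from the params.
        return cls(act_func=params.act_func)


def compute_logprobs(logits: torch.Tensor, metadata: SamplingMetadata) -> torch.Tensor:
    # Apply the custom activation before the usual log_softmax.
    if metadata.act_func == "sigmoid":
        logits = torch.sigmoid(logits)
    return logits.log_softmax(dim=-1, dtype=torch.float32)


metadata = SamplingMetadata.from_params(SamplingParams(act_func="sigmoid"))
print(compute_logprobs(torch.randn(2, 8), metadata).shape)  # torch.Size([2, 8])
```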
In vLLM V0, the logits are computed mainly during the model's forward pass, specifically in the ModelRunner / Sampler-related code under the model_executor module. The logits are usually produced directly by the model's forward method and then used in sampler.py for the subsequent logprobs computation; the compute_logprobs function applies log_softmax to these logits to obtain the logprobs.
To locate this, look at the logprobs computation in sampler.py; the logits are generally passed straight from model inference into the sampling layer for processing. See the related code reference.
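As a tiny, framework-free illustration of the logits → logprobs step mentioned above (plain PyTorch, not vLLM code):

```python
import torch

# Fake logits for a batch of 2 sequences over a 5-token vocabulary.
logits = torch.randn(2, 5)

# The core of compute_logprobs: log_softmax over the vocabulary dimension.
logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)

# Exponentiating recovers proper probabilities that sum to 1 per row.
print(logprobs.exp().sum(dim=-1))  # tensor([1.0000, 1.0000])
```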
In VLLM V0, in which part of the code is logits calculated?
In vLLM V0, the logits are computed mainly in the model's compute_logits method. This method usually lives in the individual model files (e.g., glm4.py, ovis.py) and is called after the model's forward pass. For example, compute_logits receives the hidden_states and computes the logits through the lm_head, after which they are passed to the sampler. See the glm4.py code example.
The overall flow is: the model forward pass produces hidden_states → compute_logits produces the logits → the logits are passed to the sampler (e.g., sampler.py) for further processing. See the ovis.py code example.
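Conceptually, compute_logits is just a projection of the final hidden states through the LM head onto the vocabulary. A simplified sketch of that idea (vLLM's real implementation goes through a LogitsProcessor and handles details such as tensor parallelism):

```python
import torch

hidden_size, vocab_size = 1024, 32000

# Stand-ins for the model's final hidden states and the LM head weight.
hidden_states = torch.randn(4, hidden_size)            # [num_tokens, hidden_size]
lm_head_weight = torch.randn(vocab_size, hidden_size)  # [vocab_size, hidden_size]

# The essence of compute_logits: project hidden states onto the vocabulary.
logits = hidden_states @ lm_head_weight.t()            # [num_tokens, vocab_size]
print(logits.shape)  # torch.Size([4, 32000])
```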
Where is compute_logits called? For the Qwen3 family of models, what does the flow look like after compute_logits? Please give detailed code.
In vLLM V0, taking the Qwen3 models as an example, compute_logits is called as part of the model inference flow:
- Model forward pass: the model's forward method is called first and produces the hidden_states.
- Call compute_logits: the compute_logits method of the model class (e.g., in qwen3.py) is then called with hidden_states and sampling_metadata as arguments to compute the logits.
- Sampling and generation: the logits are passed to the sampler, which applies softmax, sampling, and so on, and finally generates the next token.
Key code snippet (Qwen3 as an example):
qwen3.py
```python
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
    return hidden_states

def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
    return logits
```
Flow summary:
- forward produces the hidden_states
- compute_logits computes the logits
- the logits are passed to the sampler for further processing
Where is the code that passes the logits to the sampler?
In vLLM V0, the code that passes the logits to the sampler sits between the model executor (e.g., ModelRunner) and the sampling layer (sampler.py). Taking Qwen3 as an example, the logits returned by compute_logits are handed to the sampler during model execution, which is reflected in code like the following:
In the _run_model method of model_executor/driver_worker.py, the model forward pass is followed by a call to compute_logits, and the logits are then passed to the sampler (e.g., sampler.forward):
```python
hidden_states = model.forward(...)
logits = model.compute_logits(hidden_states, sampling_metadata)
sampler_output = sampler.forward(logits, sampling_metadata)
```
The sampler implementation is in sampler.py; its forward method receives the logits and performs the sampling.
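For intuition, a heavily simplified sketch of what a sampler's forward does with those logits (temperature scaling, softmax, then drawing one token per sequence); the real vLLM Sampler additionally handles penalties, top-k/top-p filtering, logprob bookkeeping, and more:

```python
import torch


def simple_sampler_forward(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Temperature-scale, convert to probabilities, then sample one token per row.
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


logits = torch.randn(2, 32000)  # [num_seqs, vocab_size]
print(simple_sampler_forward(logits, temperature=0.6))  # e.g. tensor([ 1734, 29012])
```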
```python
inference_params = {
    "temperature": 0.6,
    "max_tokens": 8192,
    "top_k": -1,
    "top_p": 1,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "logprobs": 60,
    "extra_args": {"activation_func": "sigmoid"},
    # "repetition_penalty": repetition_penalty,
}
```
How can I access the extra_args.activation_func parameter in sampler.py?
You can pass the activation_func parameter into the sampling flow through the extra_args field of SamplingParams. In sampler.py, the sampling-related methods receive sampling_metadata or sampling_params, so you can access it like this:
```python
# For example, inside the sampler's forward or compute_logprobs method:
activation_func = None
if hasattr(sampling_metadata, "extra_args") and sampling_metadata.extra_args:
    activation_func = sampling_metadata.extra_args.get("activation_func")

# Or, if a sampling_params object is available directly:
# activation_func = sampling_params.extra_args.get("activation_func")
```
You can then apply custom handling in the sampling logic based on the activation_func parameter. For detailed field descriptions, see the SamplingParams source.
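Tying the pieces together, a short end-to-end sketch: the request passes activation_func via extra_args, and a sampler-side function applies the activation before log_softmax. This assumes SamplingParams exposes the extra_args field as described above and that the value is forwarded onto the metadata object; the sampler-side names are illustrative:

```python
import torch
from vllm import SamplingParams

# Request side: forward the custom flag through extra_args.
params = SamplingParams(
    temperature=0.6,
    max_tokens=8192,
    logprobs=60,
    extra_args={"activation_func": "sigmoid"},
)


# Sampler side (illustrative): apply the requested activation before log_softmax.
def compute_logprobs(logits: torch.Tensor, sampling_metadata) -> torch.Tensor:
    extra_args = getattr(sampling_metadata, "extra_args", None) or {}
    if extra_args.get("activation_func") == "sigmoid":
        logits = torch.sigmoid(logits)
    return logits.log_softmax(dim=-1, dtype=torch.float32)
```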