Explain the relationship between SamplingParams and SamplingMetadata

SamplingParams defines the user-facing parameters for text generation (e.g., temperature, top_p, top_k, max_tokens), while SamplingMetadata is an internal structure that holds the processed, tensorized, and batched form of these parameters for use during model inference. SamplingMetadata is built from the SamplingParams of all requests in a batch and adds the extra information needed for efficient sampling, such as token indices and generator states. See SamplingParams source and SamplingMetadata API.
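To make the division of labor concrete, here is a minimal sketch in plain Python. RequestParams and BatchMetadata are hypothetical stand-ins, not vLLM classes: the first mirrors a per-request SamplingParams, the second mirrors how a batch-level metadata object stores one column per parameter so the sampler can apply settings row by row. vLLM itself uses torch tensors rather than lists.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class RequestParams:
    """Hypothetical stand-in for a per-request SamplingParams."""
    temperature: float = 1.0
    top_k: int = -1  # -1 means "no top-k filtering"
    seed: Optional[int] = None


@dataclass
class BatchMetadata:
    """Hypothetical stand-in for a batch-level SamplingMetadata.

    Each list has one entry per row of the batch; vLLM would use tensors.
    """
    temperature: List[float] = field(default_factory=list)
    top_k: List[int] = field(default_factory=list)
    seeds: List[Optional[int]] = field(default_factory=list)
    all_greedy: bool = True


def build_metadata(params: List[RequestParams]) -> BatchMetadata:
    """Turn per-request settings into column-oriented batch metadata."""
    meta = BatchMetadata()
    for p in params:
        meta.temperature.append(p.temperature)
        meta.top_k.append(p.top_k)
        meta.seeds.append(p.seed)
    # A batch-level flag lets a sampler skip random sampling entirely.
    meta.all_greedy = all(p.temperature == 0.0 for p in params)
    return meta


meta = build_metadata([RequestParams(temperature=0.0), RequestParams(0.7, 50, seed=1)])
print(meta.temperature, meta.top_k, meta.all_greedy)  # [0.0, 0.7] [-1, 50] False
```

The batch-level flag is the kind of derived information that exists only on the metadata side: no single SamplingParams can know whether the whole batch is greedy.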

Would you like a step-by-step breakdown of how SamplingParams are converted into SamplingMetadata and used during inference?

Yes, I want to know. Thanks!

Certainly! Here’s a step-by-step breakdown:

  1. User Specifies SamplingParams:
    When you call vLLM’s generation APIs (e.g., LLM.generate), you provide a SamplingParams object. This object contains user-defined generation settings such as temperature, top_p, top_k, max_tokens, seed, etc. (SamplingParams source)

  2. Batching and Preprocessing:
    If you submit a batch of prompts, vLLM collects the corresponding SamplingParams for each prompt. These parameters are then preprocessed and converted into tensors or other efficient representations suitable for GPU computation.

  3. Construction of SamplingMetadata:
    The engine constructs one SamplingMetadata object per batch from the processed SamplingParams. SamplingMetadata includes not only the tensorized sampling parameters but also additional information such as token indices, prompt lengths, and generator states for reproducibility. This structure is optimized for use within the model's sampling loop (SamplingMetadata API).

  4. Model Inference:
    At each step, the SamplingMetadata is passed along with the logits to the sampler. The sampler uses it to apply each request's sampling strategy (e.g., temperature scaling, top-k filtering) and to track which rows of the logits to sample from.

  5. Result Generation:
    The outputs are generated according to the logic encoded in SamplingMetadata, ensuring that each prompt in the batch is sampled according to its original SamplingParams.
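The steps above can be sketched as a per-row sampler in plain Python. This is a hypothetical illustration, not vLLM's Sampler: lists stand in for tensors, and the generators argument mimics the per-request random generators vLLM keeps for seeded requests so that their outputs are reproducible.

```python
import math
import random
from typing import List, Optional


def sample_batch(
    logits: List[List[float]],
    temperature: List[float],
    top_k: List[int],
    generators: List[Optional[random.Random]],
) -> List[int]:
    """Pick one token id per row, applying that row's own sampling settings."""
    tokens: List[int] = []
    for row, temp, k, gen in zip(logits, temperature, top_k, generators):
        if temp == 0.0:
            # Greedy decoding: take the highest-scoring token.
            tokens.append(max(range(len(row)), key=row.__getitem__))
            continue
        scaled = [x / temp for x in row]  # temperature scaling
        if 0 < k < len(scaled):
            # Top-k: mask everything below the k-th largest logit (ties may keep extras).
            kth = sorted(scaled, reverse=True)[k - 1]
            scaled = [x if x >= kth else -math.inf for x in scaled]
        m = max(scaled)
        weights = [math.exp(x - m) for x in scaled]  # unnormalized softmax
        # Seeded requests draw from their own generator, which persists across steps.
        rng = gen if gen is not None else random
        tokens.append(rng.choices(range(len(row)), weights=weights, k=1)[0])
    return tokens


batch_logits = [[1.0, 3.0, 2.0], [0.5, 0.1, 2.5]]
# Row 0 is greedy; row 1 samples from its top-2 tokens with a seeded generator.
print(sample_batch(batch_logits, [0.0, 0.8], [-1, 2], [None, random.Random(0)]))
```

vLLM performs these steps with batched tensor operations instead of a Python loop, but the per-row logic is the same idea: each row is sampled according to the settings of the request it belongs to.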

This separation allows vLLM to efficiently handle batched requests and advanced sampling features while keeping the user API simple.

OK, I understand.
Now I have a requirement: I want to add a custom parameter act_func to sampling_params. If act_func is set to sigmod, I want the logprobs returned in sampler.py to be the result of applying sigmod to the original logits.
Which code should I modify? Please give a concrete implementation.

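You need to add an act_func parameter to SamplingParams and, in sampler.py, process the logits according to that parameter: if act_func is "sigmoid" (the request writes "sigmod", which should be "sigmoid"), apply a sigmoid to the logits when computing logprobs. The steps are: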

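  1. Modify SamplingParams
    Add an act_func parameter to the SamplingParams class and make sure it is passed through to the sampling flow.

  2. Modify sampler.py
    In sampler.py, just before the logprobs are computed in compute_logprobs, check whether sampling_metadata carries an act_func equal to "sigmoid" (the request spells it "sigmod"). If so, apply the sigmoid to the logits; otherwise keep the original logic. Example code: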

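```python
import torch
import torch.nn.functional as F

def compute_logprobs(self, logits: torch.Tensor, sampling_metadata: "SamplingMetadata") -> torch.Tensor:
    # Assumes act_func has been added to sampling_metadata (it is not a built-in field).
    if getattr(sampling_metadata, "act_func", None) == "sigmoid":
        return F.logsigmoid(logits.float())  # log(sigmoid(x)); log_softmax(sigmoid(x)) would flatten the distribution
    return logits.log_softmax(dim=-1, dtype=torch.float32)
```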

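You also need to make sure SamplingMetadata receives act_func and passes it through to the sampler. Because SamplingMetadata describes a whole batch, one batch can mix requests that set act_func with requests that do not, so in practice the activation has to be applied row by row rather than to the entire logits tensor.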
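Because a batch can mix sigmoid and non-sigmoid requests, here is a minimal per-row sketch in plain Python. It is illustrative only: act_funcs is a hypothetical list with one entry per logits row (None meaning the default log_softmax), and in vLLM you would apply the same choice with tensor operations on the corresponding rows. Note that log(sigmoid(x)) scores are computed independently for each token, so unlike log_softmax they do not sum to 1 across the vocabulary.

```python
import math
from typing import List, Optional


def log_softmax(row: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def log_sigmoid(x: float) -> float:
    # Stable log(sigmoid(x)), split by sign so exp() never overflows.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def compute_logprobs(logits: List[List[float]], act_funcs: List[Optional[str]]) -> List[List[float]]:
    """Apply each row's requested activation; default to log_softmax."""
    out = []
    for row, act in zip(logits, act_funcs):
        if act == "sigmoid":
            out.append([log_sigmoid(x) for x in row])  # element-wise, not normalized
        else:
            out.append(log_softmax(row))
    return out


print(compute_logprobs([[0.0, 0.0], [0.0, 2.0]], [None, "sigmoid"]))
```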

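In vLLM V0, logits are produced right after the model's forward pass and consumed by the sampling code: the model runner (ModelRunner) runs the model, and the sampler (Sampler, in the model_executor package) consumes the logits. The model's forward method returns hidden states; a separate compute_logits step turns them into logits, which are then used in sampler.py, where the sampler applies log_softmax to obtain the logprobs.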

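To find this in the code, look at the logprobs computation in sampler.py; after model inference, the logits are passed directly to the sampling layer. (Related code reference)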

In vLLM V0, in which part of the code are the logits calculated?

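In vLLM V0, logits are computed mainly in the model's compute_logits method. Each model file (e.g., glm4.py, ovis.py) defines this method, and it is called after the model's forward pass. For example, compute_logits takes hidden_states, computes the logits through lm_head, and the result is then passed to the sampler. (See the glm4.py code example.)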

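The flow is: the model's forward pass produces hidden_states → compute_logits produces the logits → the logits are passed to the sampler (e.g., sampler.py) for further processing. (See the ovis.py code example.)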
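The lm_head step amounts to a linear projection from the hidden size to the vocabulary size. A minimal sketch in plain Python, assuming an lm_head weight shaped [vocab_size, hidden_size] as in a standard linear layer; compute_logits here is a plain illustrative function, not vLLM's method, and vLLM's real lm_head is a tensor-parallel layer running on the GPU.

```python
from typing import List

Matrix = List[List[float]]


def compute_logits(hidden_states: Matrix, lm_head_weight: Matrix) -> Matrix:
    """logits[t][v] = sum over h of hidden_states[t][h] * lm_head_weight[v][h]."""
    return [
        [sum(h * w for h, w in zip(hs, w_row)) for w_row in lm_head_weight]
        for hs in hidden_states
    ]


# One token with hidden_size=2, projected onto a 3-token vocabulary.
print(compute_logits([[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))  # [[1.0, 2.0, 3.0]]
```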

Where is compute_logits called? For the Qwen3 series models, what does the flow look like after compute_logits? Please give detailed code.

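In vLLM V0, taking the Qwen3 series as an example, compute_logits is called as part of the model inference flow:

  1. Model forward pass
    The model's forward method is called first and returns hidden_states.
  2. Call compute_logits
    Next, the compute_logits method of the model class (in qwen3.py) is called with hidden_states and sampling_metadata, and returns the logits.
  3. Sampling and generation
    The logits are passed to the sampler, which applies softmax, sampling, and related operations to produce the next token.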

Key code snippet (Qwen3 as an example), from qwen3.py:

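```python
from typing import Optional, Union

import torch
from torch import nn

from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors


class Qwen3ForCausalLM(nn.Module):  # excerpt: other base classes, __init__ and methods omitted

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Run the decoder stack; the result is hidden states, not logits.
        hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # The logits processor projects hidden states through lm_head.
        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
        return logits
```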

Flow summary:

  • forward produces hidden_states
  • compute_logits computes the logits
  • the logits are passed to the sampler for further processing
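compute_logits takes sampling_metadata because, in V0, the logits processor uses it to decide which hidden-state rows to project. During prefill there is one hidden-state row per prompt token, but usually only the last token of each sequence needs logits; those row positions come from sampling_metadata.selected_token_indices. A hypothetical plain-Python sketch of that prune, project, scale order (process_logits and its other arguments are illustrative names, not vLLM's API):

```python
from typing import List

Matrix = List[List[float]]


def process_logits(
    hidden_states: Matrix,
    lm_head_weight: Matrix,
    selected_token_indices: List[int],
    scale: float = 1.0,
) -> Matrix:
    # 1. Prune: keep only the rows whose next token will be sampled.
    pruned = [hidden_states[i] for i in selected_token_indices]
    # 2. Project each kept row onto the vocabulary (the lm_head matmul).
    logits = [[sum(h * w for h, w in zip(row, w_row)) for w_row in lm_head_weight] for row in pruned]
    # 3. Optional element-wise scaling.
    return [[x * scale for x in row] for row in logits]


# Two prompts of lengths 3 and 2 flattened into 5 token rows (hidden_size=2).
hs = [[1, 0], [0, 1], [1, 1], [2, 0], [0, 2]]
W = [[1, 0], [0, 1]]  # toy vocabulary of 2 tokens
print(process_logits(hs, W, selected_token_indices=[2, 4]))  # [[1.0, 1.0], [0.0, 2.0]]
```

Pruning before the projection matters because the lm_head matmul is expensive at large vocabulary sizes, so projecting only the rows that will actually be sampled saves work.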

Where is the code that passes the logits to the sampler?

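In vLLM V0, the logits are handed from the model to the sampler by the model runner. For Qwen3, the logits returned by compute_logits are passed on to the sampler during model execution.

In the model runner's execute_model method (ModelRunner.execute_model in vllm/worker/model_runner.py), compute_logits is called after the forward pass, and the resulting logits are then passed to the sampler: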

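```python
# Simplified; the exact call sites differ between vLLM versions.
hidden_states = model(input_ids=input_ids, positions=positions)
logits = model.compute_logits(hidden_states, sampling_metadata)
sampler_output = sampler(logits=logits, sampling_metadata=sampling_metadata)  # nn.Module call runs Sampler.forward
```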

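The sampler is implemented in sampler.py (vllm/model_executor/layers/sampler.py in V0); its forward method takes the logits and the SamplingMetadata and performs the sampling.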
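The call order can be mirrored end to end with toy stand-ins. ToyModel, ToyMetadata, greedy_sampler, and execute_model are all hypothetical; the point is only that the sampler sees nothing but the logits and the metadata, so any per-request option (such as a custom activation) has to reach it through the metadata.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyMetadata:
    """Hypothetical stand-in for SamplingMetadata: which hidden-state rows to sample."""
    selected_token_indices: List[int]


class ToyModel:
    """Hypothetical model: hidden state of token t is [t, 1.0]; vocabulary of 3 tokens."""

    lm_head = [[1.0, 0.0], [0.0, 1.0], [-1.0, 2.0]]  # [vocab_size, hidden_size]

    def forward(self, input_ids: List[int]) -> List[List[float]]:
        return [[float(t), 1.0] for t in input_ids]

    def compute_logits(self, hidden_states: List[List[float]], metadata: ToyMetadata) -> List[List[float]]:
        rows = [hidden_states[i] for i in metadata.selected_token_indices]
        return [[sum(h * w for h, w in zip(r, w_row)) for w_row in self.lm_head] for r in rows]


def greedy_sampler(logits: List[List[float]], metadata: ToyMetadata) -> List[int]:
    # metadata is unused here, but a real sampler reads per-request settings from it.
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def execute_model(model: ToyModel, input_ids: List[int], metadata: ToyMetadata) -> List[int]:
    hidden_states = model.forward(input_ids)                # 1. forward pass -> hidden states
    logits = model.compute_logits(hidden_states, metadata)  # 2. hidden states -> logits
    return greedy_sampler(logits, metadata)                 # 3. logits -> next token ids


print(execute_model(ToyModel(), [0, 2], ToyMetadata([0, 1])))  # [2, 0]
```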

```python
inference_params = {
    "temperature": 0.6,
    "max_tokens": 8192,
    "top_k": -1,
    "top_p": 1,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "logprobs": 60,
    "extra_args": {"activation_func": "sigmoid"},
    # "repetition_penalty": repetition_penalty,
}
```
How can I access extra_args.activation_func in sampler.py?

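You can pass activation_func through the extra_args field of SamplingParams. extra_args lives on each request's SamplingParams rather than on SamplingMetadata itself, so in sampler.py you reach it through sampling_metadata, which in V0 carries one sequence group per request, each with its own sampling_params. For example: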

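```python
# Assumed to run inside the V0 sampler (e.g. Sampler.forward), where
# sampling_metadata.seq_groups holds one entry per request in the batch.
for seq_group in sampling_metadata.seq_groups:
    extra_args = seq_group.sampling_params.extra_args or {}
    activation_func = extra_args.get("activation_func")
    # seq_group.sample_indices lists the logits rows that belong to this request.
# If you already hold a single SamplingParams object:
# activation_func = (sampling_params.extra_args or {}).get("activation_func")
```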

This lets your sampling logic apply custom handling based on activation_func.
See the SamplingParams source for a description of the fields.
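To apply per-request extra_args row by row, each request's setting has to be spread over the logits rows it owns. Here is a sketch with hypothetical stand-in classes (FakeSamplingParams, FakeSeqGroup) that mimic the V0 layout, where each entry of sampling_metadata.seq_groups carries sampling_params and sample_indices; activation_per_row is an illustrative helper, not part of vLLM.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class FakeSamplingParams:
    """Hypothetical stand-in: only the extra_args field matters here."""
    extra_args: Optional[Dict[str, Any]] = None


@dataclass
class FakeSeqGroup:
    """Hypothetical stand-in for one request's entry in seq_groups."""
    sampling_params: FakeSamplingParams
    sample_indices: List[int] = field(default_factory=list)


def activation_per_row(seq_groups: List[FakeSeqGroup], num_rows: int) -> List[Optional[str]]:
    """Map each logits row to the activation requested by the request that owns it."""
    acts: List[Optional[str]] = [None] * num_rows
    for group in seq_groups:
        act = (group.sampling_params.extra_args or {}).get("activation_func")
        for row in group.sample_indices:
            acts[row] = act
    return acts


groups = [
    FakeSeqGroup(FakeSamplingParams({"activation_func": "sigmoid"}), [0]),
    FakeSeqGroup(FakeSamplingParams(), [1, 2]),
]
print(activation_per_row(groups, 3))  # ['sigmoid', None, None]
```

The resulting list has one entry per logits row, which is the shape a per-row logprob routine needs in order to decide between the default log_softmax and the sigmoid path.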
