VLLM V1 Scheduler: Inconsistent Request Scheduling Under Token Budget Limit
Problem Description
The VLLM V1 scheduler exhibits unexpected behavior when scheduling requests that should theoretically fit within the token budget limit. Despite having sufficient token budget for all requests in the prefill phase, the scheduler inconsistently schedules only a subset of requests in the first scheduling round.
Environment
- Model: Qwen3-32B
- Hardware: Single H100 (80GB)
- VLLM Version: V1 (VLLM_USE_V1=1)
- Token Budget Limit: 16,384 (default)
- GPU Memory Utilization: 0.95
- Max Model Length: 32,768
Test Configuration
- Input Length: 1,024 tokens
- Output Length: 64 tokens
- Batch Sizes Tested: 4, 16 prompts
Expected Behavior
All requests should be scheduled in the first scheduling round (see the sketch after this list) since:
- Each request requires 1,024 tokens for prefill
- Total tokens needed: 4 × 1,024 = 4,096 tokens (for 4 prompts) or 16 × 1,024 = 16,384 tokens (for 16 prompts)
- Both scenarios are within or exactly at the 16,384 token budget limit
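For reference, a minimal sketch of this arithmetic (a hypothetical check script using the values from the test configuration above; it does not call vLLM):

# Back-of-the-envelope check of the expected first-round token usage.
TOKEN_BUDGET = 16_384  # default token budget reported in the logs
INPUT_LEN = 1_024

for num_prompts in (4, 16):
    total_prefill_tokens = num_prompts * INPUT_LEN
    fits = total_prefill_tokens <= TOKEN_BUDGET
    print(f"{num_prompts} prompts -> {total_prefill_tokens} tokens, fits: {fits}")

# Expected output:
# 4 prompts -> 4096 tokens, fits: True
# 16 prompts -> 16384 tokens, fits: True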
Observed Behavior
4 Prompts Case
- Inconsistent Scheduling: Only 2 or 4 requests scheduled initially (variable)
16 Prompts Case
- Inconsistent Scheduling: Only 2, 5, or 7 requests scheduled initially (variable)
- Should Schedule All: 16 × 1,024 = 16,384 tokens exactly matches the budget limit
Reproduction Steps
export VLLM_USE_V1=1
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export CUDA_VISIBLE_DEVICES=1
export BATCH_SIZE=16 # Test with 4, 16
export INPUT_LENGTH=1024
export OUTPUT_LENGTH=64
python benchmark_throughput.py \
--model /dataset/qwen3-32b/ \
--device cuda \
--tensor-parallel-size 1 \
--gpu-memory-utilization 0.95 \
--max-model-len 32768 \
--input-len $INPUT_LENGTH \
--output-len $OUTPUT_LENGTH \
--trust-remote-code \
--num-prompts $BATCH_SIZE
Debug Logs
Example 1: Only 2 requests scheduled initially
2025-07-28 10:25:09.014 | DEBUG | vllm.v1.core.sched.scheduler:schedule:562 - {
'token_budget': 16384,
'remaining_budget': 14336, # 16384 - (2 × 1024) = 14336
'total_requests': 2,
'requests': [
{'request_id': '0', 'computed_tokens': 0, 'scheduled_tokens': 1024},
{'request_id': '1', 'computed_tokens': 0, 'scheduled_tokens': 1024}
]
}
Example 2: All 16 requests present in second round
2025-07-28 10:25:09.327 | DEBUG | vllm.v1.core.sched.scheduler:schedule:562 - {
'token_budget': 16384,
'remaining_budget': 2046, # 16384 - (2 × 1) - (14 × 1024) = 2046
'total_requests': 16,
'requests': [
{'request_id': '0', 'computed_tokens': 1024, 'scheduled_tokens': 1}, # Decode phase
{'request_id': '1', 'computed_tokens': 1024, 'scheduled_tokens': 1}, # Decode phase
{'request_id': '2', 'computed_tokens': 0, 'scheduled_tokens': 1024}, # Prefill phase
# ... remaining requests in prefill
]
}
Example 3: 5 requests scheduled initially
2025-07-28 10:34:35.982 | DEBUG | vllm.v1.core.sched.scheduler:schedule:562 - {
'token_budget': 16384,
'remaining_budget': 11298, # Still has significant budget remaining
'total_requests': 5,
'requests': [
{'request_id': '0', 'computed_tokens': 0, 'scheduled_tokens': 990},
{'request_id': '1', 'computed_tokens': 0, 'scheduled_tokens': 1024},
{'request_id': '2', 'computed_tokens': 0, 'scheduled_tokens': 1024},
{'request_id': '3', 'computed_tokens': 0, 'scheduled_tokens': 1024},
{'request_id': '4', 'computed_tokens': 0, 'scheduled_tokens': 1024}
]
}
Questions
- Why is the scheduling inconsistent? The number of initially scheduled requests varies (2, 5, 7) despite having sufficient token budget.
- Is this a bug or intended behavior? The scheduler appears to have available budget but doesn't utilize it fully in the first scheduling round (see the configuration sketch below).
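A minimal sketch for ruling out configuration as a factor, assuming the offline LLM entrypoint; max_num_batched_tokens and max_num_seqs are EngineArgs fields, and the prompt construction here is only illustrative:

from vllm import LLM, SamplingParams

# Pin the scheduler limits explicitly instead of relying on defaults, so the
# first scheduling round can be compared against known bounds.
llm = LLM(
    model="/dataset/qwen3-32b/",
    max_model_len=32768,
    gpu_memory_utilization=0.95,
    max_num_batched_tokens=16384,  # per-step token budget (token_budget in the logs)
    max_num_seqs=256,              # cap on concurrently running requests (illustrative value)
)

prompts = ["..."] * 16  # stand-in for 16 prompts of ~1,024 tokens each
outputs = llm.generate(prompts, SamplingParams(max_tokens=64))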
Logging Code
I modified the V1 scheduler (vllm/v1/core/sched/scheduler.py) to produce the debug output above:
def schedule(self) -> SchedulerOutput:
from loguru import logger
from datetime import datetime
import os
# Read the environment variables
batch_size = os.getenv('BATCH_SIZE', 'default')
input_length = os.getenv('INPUT_LENGTH', 'default')
output_length = os.getenv('OUTPUT_LENGTH', 'default')
# Build a file name containing the timestamp and environment variable values
timestamp = datetime.now().strftime("%H%M")
# Create the output directory if it does not exist
os.makedirs("/workspace/vllm_benchmark", exist_ok=True)
# Log file name that includes the environment variable values
log_file = f"/workspace/vllm_benchmark/profile_{timestamp}_bs{batch_size}_in{input_length}_out{output_length}.log"
logger.add(log_file)
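# NOTE: logger.add() registers a new sink every time schedule() is called, so
# long runs will write duplicated log lines; adding the sink once at module
# import avoids this.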
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
# NOTE: structured_output_request_ids maps the request_id of each
# request that uses structured output to its index in the running
# batch. This helps us determine how to slice the grammar bitmask
# and apply a valid mask only to requests that use structured
# decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Scheduling statistics for logging
request_details = []
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
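# This clamp is what chunks a prefill: when the remaining token_budget is
# smaller than the request's outstanding tokens, only part of the prompt is
# scheduled this step and the remainder is carried over to later steps.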
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when PP>1 and
# we have already scheduled all prompt tokens but they are
# not finished yet.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
req_index += 1
continue
num_draft_tokens = max(
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
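# The loop below retries allocation: if the KV cache manager cannot provide
# enough blocks, the lowest-priority running request is preempted and the
# allocation is retried, until it either succeeds or the current request
# itself gets preempted.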
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
# The request can be scheduled.
can_schedule = True
break
if not can_schedule:
break
assert new_blocks is not None
# Record request details for logging
request_details.append({
"request_id": request.request_id,
"computed_tokens": request.num_computed_tokens,
"scheduled_tokens": num_new_tokens
})
# Schedule the request.
scheduled_running_reqs.append(request)
if request.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
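# Waiting requests are admitted only if no running request was preempted in
# this step; any preemption blocks new admissions for the entire step.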
if not preempted_reqs:
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
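# Admission also stops once len(self.running) reaches max_num_running_reqs
# (derived from max_num_seqs), regardless of how much token_budget remains.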
request = self.waiting.peek_request()
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
# constraint.
if (self.lora_config and request.lora_request and
(len(scheduled_loras) == self.lora_config.max_loras and
request.lora_request.lora_int_id not in scheduled_loras)):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_external_computed_tokens = 0
load_kv_async = False
# Get already-cached tokens.
if request.num_computed_tokens == 0:
# Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else:
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list())
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
)
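# allocate_slots() returns None when the KV cache manager cannot find enough
# free blocks; in that case the check below stops admitting further waiting
# requests for this step even though token_budget may remain.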
if new_blocks is None:
# The request cannot be scheduled.
break
# KVTransfer: the connector uses this info to determine
# whether a KV load is needed for this request.
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
# Record request details for logging
if num_new_tokens > 0:
request_details.append({
"request_id": request.request_id,
"computed_tokens": num_computed_tokens,
"scheduled_tokens": num_new_tokens
})
if request.use_structured_output:
structured_output_request_ids[request.request_id] = (
req_index)
req_index += 1
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(
f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# Log scheduling statistics in dictionary format
schedule_info = {
"token_budget": self.max_num_scheduled_tokens,
"remaining_budget": token_budget,
"total_requests": len(request_details),
"requests": request_details,
"max_num_running_reqs": self.max_num_running_reqs
}
logger.debug(schedule_info)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than
# len(self.running).
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) <= len(self.running))
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(
self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)))
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
# 1. Plan the KV cache store
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
events = self.kv_cache_manager.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
self._update_after_schedule(scheduler_output)
return scheduler_output