How to support shape-dependent branches with support_torch_compile (PIECEWISE)?

Hi vLLM team,

I have a custom model whose forward does different work depending on the input shape (e.g., a small branch runs only when batch == 4). However, after enabling support_torch_compile, vLLM compiles once during dummy_run (at max_num_tokens). Subsequent requests with different shapes (e.g., batch == 4 vs. batch != 4) appear to reuse the same compiled artifact; only one set of artifacts is created under vllm/torch_compile_cache/.

What is the recommended way to support shape-dependent branches in custom models when support_torch_compile is on?

Really appreciate your help. Thank you.

vLLM’s torch.compile integration captures a single computation graph during the initial run, specializing it for the input shape used at that time. If your model’s forward method contains shape-dependent branches (e.g., different logic for batch==4), only the branch taken during the initial compilation will be included in the compiled graph. Subsequent runs with different shapes will reuse the same compiled artifact, which can lead to incorrect behavior or silent errors, as the other branches are not compiled or executed as intended. This is a known limitation of torch.compile’s graph capture approach in vLLM, as described in vLLM Dynamic Shapes Issue and torch.compile integration docs.

The recommended solution is to avoid Python-side shape-dependent branching in the model’s forward method. Instead, move such logic into custom operators or use torch-level conditionals that can be traced symbolically. Alternatively, you can split the model into submodules and compile each submodule separately for different shape conditions, or use a custom operator to encapsulate the branching logic, as discussed in this PR. If you need to support multiple shapes, consider using vLLM’s piecewise compilation or extending the backend to select different compiled graphs based on runtime conditions, as proposed in this RFC.

Would you like more detail or code examples on how to implement these workarounds?


Thank you for the reply, please give more detail or code examples on how to implement these workarounds.

Certainly! Here’s a step-by-step breakdown of the recommended workarounds for shape-dependent branches in vLLM models with support_torch_compile:

1. Move Shape-Dependent Logic into Custom Operators:
If your model’s forward method has Python-side shape checks (e.g., if batch == 4:), torch.compile will only capture the branch taken during the initial trace. To ensure all branches are available in the compiled graph, encapsulate the shape-dependent logic inside a custom operator. This way, the operator itself can handle the branching at runtime, and torch.compile will treat it as a black box.

Example (from PR 21350):
Instead of:

def forward(self, x):
    if x.shape[0] == 4:
        return special_kernel(x)
    else:
        return default_kernel(x)

Do:

def forward(self, x):
    return my_custom_op(x)

Where my_custom_op internally dispatches to the correct kernel based on the input shape. This can be registered using vLLM’s direct_register_custom_op utility.

2. Use Piecewise Compilation for Submodules:
If you have submodules that require different compilation graphs for different shapes, you can split your model and apply @support_torch_compile to each submodule. Use set_model_tag to ensure each submodule gets its own cache directory, preventing cache collisions (see test_multiple_graphs.py).

Example:

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.backends import set_model_tag

@support_torch_compile
class BranchA(nn.Module):
    def forward(self, x):
        ...  # logic for batch == 4

@support_torch_compile
class BranchB(nn.Module):
    def forward(self, x):
        ...  # logic for batch != 4

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        with set_model_tag("branch_a"):
            self.branch_a = BranchA()
        with set_model_tag("branch_b"):
            self.branch_b = BranchB()

    def forward(self, x):
        if x.shape[0] == 4:
            return self.branch_a(x)
        return self.branch_b(x)

This ensures each branch is compiled and cached separately.

3. Advanced: Conditional Graph Selection (Experimental):
For more advanced use cases, you can extend vLLM’s backend to select different compiled graphs based on runtime shape conditions, as described in RFC 23113. This involves modifying the backend to cache and dispatch graphs based on symbolic shape expressions.

Summary:

  • Avoid Python-side shape-dependent branches in forward.
  • Encapsulate such logic in custom ops or split into submodules with separate compilation.
  • Use set_model_tag to avoid cache collisions.
  • For advanced needs, consider backend extensions for conditional graph selection.

Let me know if you need a full custom op registration example or more details on any step!


Please give some details on “use torch-level conditionals”

“Use torch-level conditionals” means expressing conditional logic using PyTorch operations (such as torch.where, torch.cond, or other tensor-based control flow) instead of Python if statements. This allows torch.compile (and vLLM) to trace both branches and generate a graph that works for all input shapes, rather than specializing to the shape seen during the initial trace. For example, torch.where can select between two tensors based on a condition tensor, and torch.cond (in recent PyTorch versions) can express true/false branches as functions.

Example:

import torch

def forward(x):
    # Instead of: if x.shape[0] == 4: ...
    return torch.cond(
        x.shape[0] == 4,   # predicate
        branch_for_4,      # true branch
        branch_for_other,  # false branch
        (x,),              # operands passed to both branches
    )

This ensures both branches are present in the compiled graph, and the correct one is chosen at runtime based on the input.
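For element-wise conditions, torch.where is often simpler than torch.cond; a minimal sketch (the two arithmetic operations are placeholders for your real branch logic):

```python
import torch

def forward(x):
    # Both candidate results are computed; torch.where selects per element,
    # so the traced graph contains no Python-side branching.
    return torch.where(x > 0, x * 2, x + 1)
```

Note that torch.where evaluates both operands, so it trades some redundant compute for a fully shape-agnostic graph.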

Would you like more code examples or details on using these torch-level conditionals?


please give more code examples or details on using these torch-level conditionals. and give example about torch.cond without other branches.

Torch-level conditionals like torch.cond or torch.where let you express shape-dependent logic in a way that torch.compile can trace both branches, even though only one is taken at runtime. This differs from Python if statements, which only capture the branch taken during the initial compilation. torch.cond (available in recent PyTorch versions) lets you define the true and false branches as functions, and the correct one is executed based on a runtime predicate.

Here's a basic example using torch.cond; note its documented signature is torch.cond(pred, true_fn, false_fn, operands):

import torch

def branch_for_4(x):
    # logic for batch == 4
    return x * 2

def branch_for_other(x):
    # logic for batch != 4
    return x + 1

def forward(x):
    return torch.cond(x.shape[0] == 4, branch_for_4, branch_for_other, (x,))

If you want to use torch.cond without a meaningful “other” branch (i.e., you only want to run code when a condition is true and do nothing otherwise), you can use a pass-through false branch. Note that torch.cond requires branch outputs not to alias the inputs, so return a clone rather than the input itself:

def forward(x):
    return torch.cond(x.shape[0] == 4, branch_for_4, lambda x: x.clone(), (x,))

This way, if the condition is false, the input is returned unchanged. This approach ensures that both branches are present in the computation graph, which is important for torch.compile to avoid specializing only to the initial input shape. More details and discussion can be found in the comments of RFC: Enabling Multiple Graphs Based on pre-defined conditions and vLLM Dynamic Shapes Issue.

Would you like more advanced examples or details on using torch.where for element-wise conditionals?


wrapping the condition into a custom op should be a better way.

Yes, wrapping the shape-dependent condition into a custom op is generally the most robust and recommended approach in vLLM. By encapsulating the conditional logic inside a custom operator, you ensure that torch.compile treats the entire operation as a black box, avoiding issues with Python-side branching and guaranteeing correct dispatch at runtime for all input shapes. This method is explicitly recommended in PR 21350 and related discussions.

This approach is preferred over torch-level conditionals when the logic is complex, involves kernel selection, or cannot be easily expressed with tensor operations. It also avoids the risk of graph specialization to a single shape, which can happen with Python if statements.

Would you like a code example of how to implement such a custom op in vLLM?
