Certainly! Here’s a step-by-step breakdown of the recommended workarounds for shape-dependent branches in vLLM models with support_torch_compile:
1. Move Shape-Dependent Logic into Custom Operators:
If your model’s forward method has Python-side shape checks (e.g., if batch == 4:), torch.compile will only capture the branch taken during the initial trace. To ensure all branches are available in the compiled graph, encapsulate the shape-dependent logic inside a custom operator. This way, the operator itself can handle the branching at runtime, and torch.compile will treat it as a black box.
Example (from PR 21350):
Instead of:

```python
def forward(self, x):
    if x.shape[0] == 4:
        return special_kernel(x)
    else:
        return default_kernel(x)
```

Do:

```python
def forward(self, x):
    return my_custom_op(x)
```
Where my_custom_op internally dispatches to the correct kernel based on the input shape. This can be registered using vLLM’s direct_register_custom_op utility.
2. Use Piecewise Compilation for Submodules:
If you have submodules that require different compilation graphs for different shapes, you can split your model and apply @support_torch_compile to each submodule. Use set_model_tag to ensure each submodule gets its own cache directory, preventing cache collisions (see test_multiple_graphs.py).
Example:
```python
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.backends import set_model_tag

@support_torch_compile
class BranchA(nn.Module):
    def forward(self, x):
        ...  # logic for batch == 4

@support_torch_compile
class BranchB(nn.Module):
    def forward(self, x):
        ...  # logic for batch != 4

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        with set_model_tag("branch_a"):
            self.branch_a = BranchA()
        with set_model_tag("branch_b"):
            self.branch_b = BranchB()

    def forward(self, x):
        if x.shape[0] == 4:
            return self.branch_a(x)
        return self.branch_b(x)
```
This ensures each branch is compiled and cached separately.
3. Advanced: Conditional Graph Selection (Experimental):
For more advanced use cases, you can extend vLLM’s backend to select different compiled graphs based on runtime shape conditions, as described in RFC 23113. This involves modifying the backend to cache and dispatch graphs based on symbolic shape expressions.
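To illustrate the idea, here is a hypothetical, simplified sketch of shape-keyed graph dispatch; it is not vLLM's actual backend API (the real mechanism in RFC 23113 works on symbolic shape expressions inside the compilation backend), and the ShapeDispatcher class and its bucketing key are assumptions for illustration:

```python
import torch

class ShapeDispatcher:
    """Hypothetical sketch: cache one compiled graph per shape bucket
    and select among them at runtime."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}  # shape-bucket key -> compiled callable

    def _key(self, x):
        # Bucket on batch size; a real backend would evaluate
        # symbolic shape expressions instead.
        return x.shape[0]

    def __call__(self, x):
        key = self._key(x)
        if key not in self.cache:
            # backend="eager" keeps this sketch lightweight; a real
            # deployment would use the default inductor backend.
            self.cache[key] = torch.compile(self.fn, backend="eager")
        return self.cache[key](x)
```

Each distinct batch size triggers one compilation; subsequent calls with the same bucket reuse the cached graph.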
Summary:
- Avoid Python-side shape-dependent branches in forward.
- Encapsulate such logic in custom ops or split into submodules with separate compilation.
- Use set_model_tag to avoid cache collisions.
- For advanced needs, consider backend extensions for conditional graph selection.
Let me know if you need a full custom op registration example or more details on any step!