INFO 07-02 07:06:12 [worker.py:294] Memory profiling takes 8.16 seconds
INFO 07-02 07:06:12 [worker.py:294] the current vLLM instance can use total_gpu_memory (23.64GiB) x gpu_memory_utilization (0.95) = 22.46GiB
INFO 07-02 07:06:12 [worker.py:294] model weights take 15.95GiB; non_torch_memory takes 0.06GiB; PyTorch activation peak memory takes 0.69GiB; the rest of the memory reserved for KV Cache is 5.76GiB.
INFO 07-02 07:06:13 [executor_base.py:113] # cuda blocks: 2948, # CPU blocks: 2048
INFO 07-02 07:06:13 [executor_base.py:118] Maximum concurrency for 4096 tokens per request: 11.52x
INFO 07-02 07:06:14 [llm_engine.py:428] init engine (profile, create kv cache, warmup model) took 10.39 seconds
Adding requests: 0%| | 0/1 [00:04<?, ?it/s]
Failed to evaluate openbmb/MiniCPM-Llama3-V-2_5: Expected there to be 1 prompt updates corresponding to 1 image items, but instead found 0 prompt updates! This is likely because you forgot to include input placeholder tokens (e.g., <image>, <|image_pad|>) in the prompt. If the model has a chat template, make sure you have applied it before calling LLM.generate.
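A quick way to confirm what the error is asking for (a sketch only, reusing the MiniCPM placeholder format that already appears in _prompt_minicpm below; the question string is just a placeholder) is to render the chat template by hand and check that the image placeholder survives in the final prompt:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True
)
messages = [{"role": "user", "content": "(<image>./</image>)\nDescribe the image."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # the rendered prompt should still contain the image placeholder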
“”"Utilities for running inference with various VLM models.
This module exposes the :class:VLMInference
class which can be used with
different multimodal model families. The class automatically selects the
correct prompting scheme based on the provided model name or an explicit
family
argument.
“”"
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# os.environ["VLLM_USE_V1"] = "1"  # Ensure vLLM uses the intended engine version: set to "0" for BLIP, Gemma, InternVL, Ovis

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


class VLMInference:
    """Run inference for different multimodal model families."""
def __init__(self, model_name: str, family: str | None = None) -> None:
self.model_name = model_name
# Infer the family if not explicitly provided
self.family = family or self._detect_family(model_name)
# Set up the LLM and tokenizer once
        self.llm = LLM(
            model=model_name,
            gpu_memory_utilization=0.95,
            enforce_eager=True,
            # max_model_len=1024 * 4,
            max_num_seqs=1,
            limit_mm_per_prompt={"image": 1},  # disable for Gemma
            trust_remote_code=True,
            tensor_parallel_size=1,  # set to 2 for Prometheus
            # dtype="float16",
        )
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Conservative generation parameters for deterministic output
self.sampling_params = SamplingParams(
temperature=0.0,
max_tokens=500,
top_p=1.0,
top_k=1,
)
# ------------------------------------------------------------------
# Prompt helpers
# ------------------------------------------------------------------
def _prompt_internvl(self, question: str) -> str:
"""Prompt format for InternVL models."""
messages = [{"role": "user", "content": f"<image>\n{question}"}]
return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def _prompt_gemma(self, question: str) -> str:
"""Prompt format for Gemma models."""
return (
"<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"
"<start_of_turn>model\n"
)
    def _prompt_minicpm(self, question: str) -> str:
        """Prompt format for MiniCPM-V models."""
messages = [{
"role": "user",
"content": f"(<image>./</image>)\n{question}"
}]
return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def _prompt_blip(self, question: str) -> str:
"""Prompt format for BLIP/BLIP2 models."""
return f"Question: {question} Answer:"
def _prompt_ovis(self, question: str) -> str:
"""Prompt format for Ovis models."""
return f"USER: <image>\n{question}\nASSISTANT:"
def _prompt_llama(self, question: str) -> str:
"""Prompt format for Llama-based models (LLaVA, etc.)."""
return f"USER: <image>\n{question}\nASSISTANT:"
def _prompt_llava(self, question: str) -> str:
"""Prompt format for Llava models."""
return f"USER: <image>\n{question}\nASSISTANT:"
_PROMPT_FUNCS = {
"internvl": _prompt_internvl,
"gemma": _prompt_gemma,
"minicpm": _prompt_minicpm,
"blip2": _prompt_blip,
"ovis": _prompt_ovis,
"llama": _prompt_llama,
"llava": _prompt_llava,
}
    def _detect_family(self, model_name: str) -> str:
        """Best-effort detection of the model family from its name."""
        name = model_name.lower()
        # Check in a fixed order so that, e.g., "llama" still wins over "llava"
        # when both substrings appear in the model name.
        for family in ("internvl", "llama", "llava", "gemma", "minicpm", "blip2", "ovis"):
            if family in name:
                print(f"Detected family: {family} for model {model_name}")
                return family
        # Fall back to InternVL-style prompting when the family cannot be detected.
        return "internvl"
def build_prompt(self, question: str) -> str:
"""Return the correct prompt for question."""
func = self._PROMPT_FUNCS.get(self.family)
if func is None:
raise ValueError(f"No prompt function found for model family: {self.family}")
return func(self, question)
def predict(self, image, question: str) -> str:
"""Generate an answer for the given image and question."""
prompt = self.build_prompt(question)
        outputs = self.llm.generate(
            {"prompt": prompt, "multi_modal_data": {"image": image}},
            sampling_params=self.sampling_params,
        )
return outputs[0].outputs[0].text.strip()
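

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the model name,
    # image path, and question are placeholders, and Pillow is assumed to be
    # available for loading the image.
    from PIL import Image

    runner = VLMInference("OpenGVLab/InternVL2-8B")
    image = Image.open("example.jpg").convert("RGB")
    print(runner.predict(image, "Describe the image in one sentence."))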