When using large batches, the Ray service crashes with:

ray.exceptions.RayChannelTimeoutError: System error: Timed out waiting for object available to read

Give me the ideal prompt template and image-insertion format for a pipeline + InternVL3-78B setup.

The recommended, standard prompt template and image-insertion method for pipeline + InternVL3-78B are as follows:

  1. Prompt template: generate it with tokenizer.apply_chat_template; the content format is (single-turn example):
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_path = "/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "请根据图片内容详细描述图片。"},
            {"type": "image"}
        ]
    }
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  2. Image insertion: pass a PIL.Image object directly via multi_modal_data, e.g.:
from PIL import Image

image = Image.open("your_image_path.jpg").convert("RGB")
inputs = {
    "prompt": prompt,
    "multi_modal_data": {"image": image}
}
  3. Inference call:
llm = LLM(
    model=model_path,
    max_model_len=32768,
    max_num_seqs=8,
    gpu_memory_utilization=0.9,
    trust_remote_code=True,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=512)
outputs = llm.generate([inputs], sampling_params)
print(outputs[0].outputs[0].text)

Note: the prompt must contain {"type": "image"}, and the image must be passed via multi_modal_data. Do not base64-encode the image and splice it into the prompt string, or accuracy will degrade badly. This is the officially recommended approach.

Do you need a multi-turn conversation or batch-inference example?


Adding requests: 100%|████████████████████████████████████████████| 64/64 [00:03<00:00, 16.66it/s]
Processed prompts: 100%|█| 64/64 [01:25<00:00, 1.33s/it, est. speed input: 1994.44 toks/s, output
2025-07-29 23:35:53,352 [INFO] 处理 65‒128 …
Adding requests: 6%|██▊ | 4/64 [00:00<00:03, 16.00it/s]
Traceback (most recent call last):
  File "/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid/stage3_vadlid_pipeline.py", line 404, in <module>
    main()
  File "/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid/stage3_vadlid_pipeline.py", line 376, in main
    results = inference.validate_batch(batch)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid/stage3_vadlid_pipeline.py", line 308, in validate_batch
    outputs = self.llm.generate(requests_list, self.sp)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/vllm/utils.py", line 1218, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 471, in generate
    self._validate_and_add_requests(
  File "/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 1404, in _validate_and_add_requests
    self._add_request(
  File "/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 1424, in _add_request
    self.llm_engine.add_request(
  File "/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/vllm/v1/engine/llm_engine.py", line 195, in add_request
    prompt_str, request = self.processor.process_inputs(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/vllm/v1/engine/processor.py", line 250, in process_inputs
    self._validate_model_inputs(processed_inputs, lora_request)
  File "/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/vllm/v1/engine/processor.py", line 342, in _validate_model_inputs
    self._validate_model_input(decoder_inputs,
  File "/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/vllm/v1/engine/processor.py", line 391, in _validate_model_input
    raise ValueError(
ValueError: The decoder prompt (length 4129) is longer than the maximum model length of 4096. Make sure that max_model_len is no smaller than the number of text tokens plus multimodal tokens. For image inputs, the number of image tokens depends on the number of images, and possibly their aspect ratios as well.
ERROR 07-29 23:35:53 [dump_input.py:68] Dumping input data
ERROR 07-29 23:35:53 [dump_input.py:70] V1 LLM engine (v0.9.0) with config: model=‘/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B’, speculative_config=None, tokenizer=‘/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B’, skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend=‘auto’, disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=‘’), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={“level”: 3, “custom_ops”: [“none”], “splitting_ops”: [“vllm.unified_attention”, “vllm.unified_attention_with_output”], “compile_sizes”: , “inductor_compile_config”: {“enable_auto_functionalized_v2”: false}, “use_cudagraph”: true, “cudagraph_num_of_warmups”: 1, “cudagraph_capture_sizes”: [512, 504, 496, 488, 480, 472, 464, 456, 448, 440, 432, 424, 416, 408, 400, 392, 384, 376, 368, 360, 352, 344, 336, 328, 320, 312, 304, 296, 288, 280, 272, 264, 256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 4, 2, 1], “max_capture_size”: 512},
ERROR 07-29 23:35:53 [dump_input.py:78] Dumping scheduler output for model execution:
ERROR 07-29 23:35:53 [dump_input.py:79] SchedulerOutput(scheduled_new_reqs=[NewRequestData(req_id=64, prompt_token_ids_len=3618, mm_inputs=[{'image_token_id': 151667, 'pixel_values_flat': <tensor dump omitted>, 'image_num_patches': tensor([13])}], mm_hashes=['aa8084f231020807224d7fb4c64aef87e0399037cc85d2a8700f10bc795047f4'], mm_positions=[PlaceholderRange(offset=283, length=3330)], sampling_params=SamplingParams(n=1, temperature=0.0, top_p=1.0, max_tokens=1024, stop_token_ids=[151643, 151645], ...), num_computed_tokens=192, ...)], num_scheduled_tokens={64: 3426}, total_num_scheduled_tokens=3426, scheduled_encoder_inputs={64: [0]}, num_common_prefix_blocks=[227], ...)
ERROR 07-29 23:35:53 [dump_input.py:81] SchedulerStats(num_running_reqs=1, num_waiting_reqs=0, gpu_cache_usage=0.008744678402945572, prefix_cache_stats=PrefixCacheStats(reset=False, requests=1, queries=3618, hits=192), spec_decoding_stats=None)

What is going on here?

Your error is:

ValueError: The decoder prompt (length 4129) is longer than the maximum model length of 4096. Make sure that max_model_len is no smaller than the number of text tokens plus multimodal tokens. For image inputs, the number of image tokens depends on the number of images, and possibly their aspect ratios as well.

Cause analysis:
One of your inputs (prompt + image tokens) has a total length of 4129, which exceeds the max_model_len=4096 of your current vLLM instance. vLLM strictly validates the input token count (text tokens plus image tokens); any input over the limit raises an error immediately and aborts that batch. This validation logic is documented in the source code and the official FAQ.

Common triggers:

  • The number of image tokens depends on the image resolution and the model architecture; some large images consume over a thousand tokens (see the arithmetic below).
  • When the prompt text is long, adding the image tokens easily pushes the total over the limit.
  • max_model_len in your pipeline script is set too small (e.g. 4096) while the actual inputs are much longer.
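
To make this concrete with the request shown in your dump: the image placeholder alone occupies 3330 tokens (13 tiles × 256 tokens per tile, plus what appear to be two boundary tokens), so under max_model_len=4096 fewer than 800 tokens remain for the text part of the prompt; a sample with a somewhat longer question/answer, or an image that tiles into more patches, immediately exceeds the limit (the failing request was 4129 tokens).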

Solutions:

  1. Increase max_model_len
    Set a larger max_model_len when initializing the LLM, e.g. 32768 or 60000 (make sure GPU memory is sufficient):
    llm = LLM(
        model=...,
        max_model_len=32768,
        ...
    )

  2. Truncate over-long inputs
    Before calling llm.generate, use the tokenizer to count the total prompt + image tokens, and truncate the text or drop oversized images when the limit is exceeded (see the sketch after this list).
  3. Adjust batches dynamically
    For each sample in a batch, check its token total individually, and handle or skip the ones that exceed the limit.
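
A minimal pre-filter sketch for item 2 (the names fits_context and MAX_IMAGE_TOKENS are ours, and the 13-tile × 256-token image budget is an assumption based on InternVL3's default dynamic tiling, which matches the 3330-token image placeholder in your dump):

from transformers import AutoTokenizer

model_path = "/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

MAX_MODEL_LEN = 4096
# Worst-case image budget: up to 12 tiles + 1 thumbnail, 256 tokens each,
# plus a couple of image-boundary tokens. The raw prompt only contains the
# <image> placeholder; vLLM expands it to this many tokens at request time.
MAX_IMAGE_TOKENS = 13 * 256 + 2

def fits_context(prompt: str) -> bool:
    """True if the text tokens plus the worst-case image expansion fit max_model_len."""
    n_text = len(tokenizer(prompt).input_ids)
    return n_text + MAX_IMAGE_TOKENS <= MAX_MODEL_LEN

# Drop (or handle separately) request dicts that would exceed the context window
# before passing requests_list to llm.generate().
requests_list = [req for req in requests_list if fits_context(req["prompt"])]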


Do you need a more detailed token-counting and automatic-truncation example?


For the InternVL model, isn't each image a fixed 256 tokens?

No, it is not a fixed 256 tokens. For multimodal models such as InternVL, the number of tokens per image depends on the image resolution, patch size, and model configuration, so it is usually not a constant. Models like Qwen2-VL and InternVL compute the number of patches dynamically from the actual image size, and the final image token count = number of patches × tokens per patch; some models also add special tokens or use a tiling mechanism. The official implementation and API documentation provide a get_num_image_tokens method for computing this dynamically.

If you need the exact token count, use the model's own get_num_image_tokens(image_width, image_height) interface to obtain it dynamically.
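
If your version does not expose such a helper, you can approximate the count yourself. Below is a rough sketch, assuming InternVL3's default preprocessing (448×448 tiles, at most 12 tiles chosen by closest aspect ratio, an extra thumbnail tile whenever more than one tile is used, 256 tokens per tile); estimate_internvl_image_tokens is our own name rather than a library API, and the sketch ignores the tie-breaking area heuristic and the two image-boundary tokens:

from PIL import Image

def estimate_internvl_image_tokens(width: int, height: int,
                                   min_tiles: int = 1, max_tiles: int = 12,
                                   tokens_per_tile: int = 256) -> int:
    """Approximate the image token count under InternVL-style dynamic tiling."""
    aspect = width / height
    # Candidate (cols, rows) grids whose tile count lies in [min_tiles, max_tiles].
    grids = [(c, r) for c in range(1, max_tiles + 1)
             for r in range(1, max_tiles + 1)
             if min_tiles <= c * r <= max_tiles]
    # Pick the grid whose aspect ratio is closest to the image's.
    cols, rows = min(grids, key=lambda g: abs(aspect - g[0] / g[1]))
    tiles = cols * rows
    if tiles > 1:  # a low-resolution thumbnail tile is appended
        tiles += 1
    return tiles * tokens_per_tile

img = Image.open("your_image_path.jpg")
width, height = img.size
print(estimate_internvl_image_tokens(width, height))  # e.g. 13 * 256 = 3328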


Below is my pipeline script. Please help me figure out why my inference accuracy is low.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Stage-3 Valid – InternVL-78B local vLLM batch-inference version
=================================================
Compared with the original client/HTTP version, this script calls `LLM.generate()`
directly on a local vLLM engine, which significantly improves inference throughput.

The core flow is unchanged:
1. Build the prompt for each sample (lenient evaluation criteria)
2. Call vLLM in batches to generate judgments
3. Parse the judgment → confidence / correct or not
4. Save valid / invalid / report

Only the essential features are kept; the remote API / checkpointing are removed to keep the script lean and easy to run.
"""

from __future__ import annotations

import os
import sys
import argparse
import logging
import json
import time
from dataclasses import dataclass
from typing import List, Dict, Tuple

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, EngineArgs  # type: ignore

# ---------------------------------------------------------------------------
# 共用结构 – 直接在本文件中重新实现所需工具
# ---------------------------------------------------------------------------

import re
import io
import requests
from PIL import Image  # type: ignore


@dataclass
class ValidationResult:
    """单条验证输出"""

    judgment: str
    confidence: float
    is_correct: bool
    raw_output: str
    reasoning: str


class JudgmentParser:
    """解析模型自然语言输出,得到宽松判定"""

    correct_indicators = {
        "正确",
        "基本正确",
        "大致正确",
        "部分正确",
        "大部分正确",
        "总体正确",
        "符合要求",
        "accurate",
        "mostly correct",
        "partially correct",
        "true",
        "yes",
    }

    @staticmethod
    def parse_judgment(resp: str) -> Tuple[bool, float, str, str]:
        if not resp:
            return False, 0.0, "无响应", ""

        clean = JudgmentParser._clean(resp)
        judgment = JudgmentParser._extract_judgment(clean)
        confidence = JudgmentParser._extract_confidence(clean)
        reasoning = JudgmentParser._extract_reasoning(clean)
        is_correct = JudgmentParser._is_correct(judgment)
        return is_correct, confidence, judgment, reasoning

    @staticmethod
    def _clean(text: str) -> str:
        text = re.sub(r"<\|.*?\|>", "", text)
        return re.sub(r"\s+", " ", text).strip()

    @staticmethod
    def _extract_judgment(text: str) -> str:
        m = re.search(r"判断[::]\s*(.+?)(?:\n|置信度|理由|$)", text, re.I | re.S)
        return m.group(1).strip() if m else text

    @staticmethod
    def _extract_confidence(text: str) -> float:
        m = re.search(r"置信度[::]\s*(\d*\.?\d+)", text)
        if m:
            try:
                return min(float(m.group(1)), 1.0)
            except Exception:
                pass
        # 粗略估计
        lower = text.lower()
        if any(k in lower for k in ("完全正确", "absolutely", "显然")):
            return 0.95
        if "部分" in text or "partially" in lower:
            return 0.6
        return 0.7

    @staticmethod
    def _extract_reasoning(text: str) -> str:
        m = re.search(r"理由[::]\s*(.+)", text, re.I | re.S)
        return m.group(1).strip() if m else "无详细理由"

    @staticmethod
    def _is_correct(judgment: str) -> bool:
        j = judgment.lower()
        if any(ind in j for ind in JudgmentParser.correct_indicators):
            return not any(neg in j for neg in ("不正确", "错误", "incorrect"))
        return False


# ---------------- 数据集 I/O ----------------


def load_dataset(path: str) -> List[Dict]:
    logger.info(f"加载数据集: {path}")
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        logger.error(f"读取失败: {e}")
        sys.exit(1)

    if not isinstance(data, list):
        logger.error("顶层应为列表")
        sys.exit(1)

    processed: List[Dict] = []
    for item in data:
        if not isinstance(item, dict):
            continue

        # 兼容字段
        image = item.get("image") or item.get("image_path")
        q = item.get("question") or item.get("problem")
        a = item.get("answer") or item.get("gt_answer")

        if not (image and q and a):
            continue

        processed.append({"image": image, "question": q, "answer": a, **item})

    logger.info(f"标准化后样本数: {len(processed)}/{len(data)}")
    return processed


def save_json(path: str, data: List[Dict]):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    logger.info(f"保存 {len(data)} 条 → {path}")


def clean_validation_data(data: List[Dict]) -> List[Dict]:
    return [{k: v for k, v in item.items() if k != "gt_valid"} for item in data]


def generate_report(correct: List[Dict], incorrect: List[Dict], out_dir: str, fname: str, start_time: float, api: str):
    total = len(correct) + len(incorrect)
    duration = time.time() - start_time
    report = (
        f"Stage-3 Valid 报告\n====================\n"
        f"文件: {fname}\n总样本: {total}\n"
        f"正确: {len(correct)}  ({len(correct)/total*100:.2f}%)\n"
        f"错误: {len(incorrect)} ({len(incorrect)/total*100:.2f}%)\n"
        f"耗时: {duration:.2f}s\n"
        f"模式: 本地 vLLM pipeline\n"
    )
    path = os.path.join(out_dir, f"{fname}_report.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(report)
    logger.info(f"报告已保存 → {path}")


logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# ---------------------------------------------------------------------------
# 配置
# ---------------------------------------------------------------------------
@dataclass
class ValidationConfig:
    model_path: str = "/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B"
    batch_size: int = 128
    max_new_tokens: int = 1024
    temperature: float = 0.0
    top_p: float = 1.0
    gpu_util: float = 0.95  # 依据显存情况微调
    # ---- 追加:大上下文 & 批推理参数 ----
    max_model_len: int = 8192          # 上下文长度
    # max_num_batched_tokens: int = 65536 # 每批最大 token 数
    # max_num_seqs: int = 512             # 每批最大序列数
    # ---- GPU 并行 ----
    tensor_parallel_size: int = 8       # 默认使用 8 张 GPU 做张量并行

# ---------------------------------------------------------------------------
# 本地 vLLM 推理器
# ---------------------------------------------------------------------------
class InternVLPipelineInference:
    """本地 vLLM 批推理封装。"""

    def __init__(self, cfg: ValidationConfig):
        self.cfg = cfg

        logger.info("初始化 tokenizer …")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True)

        logger.info("初始化 vLLM 引擎 … (耗时 1-3 分钟)")
        # 兼容旧配置,按需动态加入可选参数
        # 根据环境变量 WORLD_SIZE(分布式场景)或配置中的 tensor_parallel_size 设定卡数
        tp_size = int(os.environ.get("WORLD_SIZE", cfg.tensor_parallel_size))

        eng_kwargs = dict(
            model=cfg.model_path,
            tensor_parallel_size=max(1, tp_size),
            gpu_memory_utilization=cfg.gpu_util,
            trust_remote_code=True,
        )

        if hasattr(cfg, "max_model_len"):
            eng_kwargs["max_model_len"] = cfg.max_model_len
        if hasattr(cfg, "max_num_batched_tokens"):
            eng_kwargs["max_num_batched_tokens"] = cfg.max_num_batched_tokens
        if hasattr(cfg, "max_num_seqs"):
            eng_kwargs["max_num_seqs"] = cfg.max_num_seqs

        eng_args = EngineArgs(**eng_kwargs)
        self.llm = LLM(**eng_args.__dict__)

        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        stop_token_ids = [self.tokenizer.convert_tokens_to_ids(tok) for tok in stop_tokens]
        self.sp = SamplingParams(
            temperature=cfg.temperature,
            top_p=cfg.top_p,
            max_tokens=cfg.max_new_tokens,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=True,
        )

    # -------------------- prompt 构建 --------------------
    def _build_prompt(self, sample: Dict) -> str:
        question = sample.get("question", "")
        answer = sample.get("answer", "")

        # 指令部分
        instr = (
            "你是一个严谨的、注重事实的多模态问答评估员。请基于图片内容,判断\"待判断答案\"是否准确地回答了\"问题\"。\n\n"
            "**要求:**\n"
            "1. **事实为先**: 你的判断必须严格基于图片内容,不能有任何想象或推断。\n"
            "2. **宽松标准**: '完全正确'、'基本正确'、'大致正确'、'部分正确'都视为【正确】。只有'完全错误'、'严重错误'、'明显不符'才视为【错误】。\n"
            "3. **输出格式必须如下**:\n\n"
            "判断:[正确/基本正确/部分正确/错误]\n"
            "置信度:[0.0-1.0]\n"
            "理由:[简要说明判断依据,指出答案中正确或错误的关键点]\n\n"
            "--- 以下是待评估内容 ---\n"
            f"问题:{question}\n待判断答案:{answer}"
        )

        # 某些模型的 chat_template 不支持 list 类型 content,这里改为纯字符串占位符 <image>
        messages = [
            {
                "role": "user",
                "content": f"{instr}\n<image>",  # 占位符,真实图片通过 multi_modal_data 传递
            }
        ]

        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # -------------------- 核心批推理 --------------------
    def validate_batch(self, batch_data: List[Dict]) -> List[ValidationResult]:
        requests_list = []
        for sample in batch_data:
            prompt = self._build_prompt(sample)

            img_path = sample.get("image", "")
            if not img_path:
                # 占位错误
                prompt += "\n[错误: 未提供图像]"
                requests_list.append({"prompt": prompt})
                continue

            try:
                if img_path.startswith("http"):
                    resp = requests.get(img_path, timeout=15)
                    resp.raise_for_status()
                    image = Image.open(io.BytesIO(resp.content)).convert("RGB")
                else:
                    image = Image.open(img_path).convert("RGB")
            except Exception as e:
                logger.warning(f"加载图片失败: {img_path} - {e}")
                prompt += "\n[错误: 图像加载失败]"
                requests_list.append({"prompt": prompt})
                continue

            requests_list.append({
                "prompt": prompt,
                "multi_modal_data": {"image": image},
            })

        outputs = self.llm.generate(requests_list, self.sp)

        results: List[ValidationResult] = []
        for out in outputs:
            raw = out.outputs[0].text.strip()
            is_corr, conf, judgment, reasoning = JudgmentParser.parse_judgment(raw)
            results.append(
                ValidationResult(
                    judgment=judgment,
                    confidence=conf,
                    is_correct=is_corr,
                    raw_output=raw,
                    reasoning=reasoning,
                )
            )
        return results

# ---------------------------------------------------------------------------
# 入口逻辑 (简化,无 checkpoint)
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser("Stage-3 Valid – 本地 vLLM Pipeline 版")
    p.add_argument("--input", required=True, help="输入 JSON 文件")
    p.add_argument(
        "--output",
        default="/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage1_cls/output_stage1_pipeline",
        help="输出根目录 (默认: /mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage1_cls/output_stage1_pipeline)",
    )
    p.add_argument("--model", default="/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B", help="模型路径")
    p.add_argument("--batch-size", type=int, default=128, help="批推理大小")
    p.add_argument("--gpu-util", type=float, default=0.85, help="GPU 显存利用率上限")
    p.add_argument("--tp-size", type=int, default=8, help="Tensor parallel size (GPU 卡数,默认 8)")
    p.add_argument("--debug", action="store_true")
    return p.parse_args()


def main():
    args = parse_args()
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    cfg = ValidationConfig(
        model_path=args.model,
        batch_size=args.batch_size,
        gpu_util=args.gpu_util,
        tensor_parallel_size=args.tp_size,
    )

    inference = InternVLPipelineInference(cfg)

    data = load_dataset(args.input)
    if not data:
        logger.error("数据为空,退出")
        sys.exit(1)

    # 输出目录遵循旧脚本逻辑: <output_root>/<input_filename>/
    input_fname = os.path.splitext(os.path.basename(args.input))[0]
    run_out_dir = os.path.join(args.output, input_fname)
    os.makedirs(run_out_dir, exist_ok=True)

    correct: List[Dict] = []
    incorrect: List[Dict] = []

    total = len(data)
    logger.info(f"共有 {total} 条样本,batch_size={cfg.batch_size}")

    for i in range(0, total, cfg.batch_size):
        batch = data[i : i + cfg.batch_size]
        logger.info(f"处理 {i+1}‒{min(i+cfg.batch_size, total)} …")
        results = inference.validate_batch(batch)
        for sample, res in zip(batch, results):
            sample["gt_valid"] = {
                "validation": {
                    "model": "InternVL-78B-pipeline",
                    "judgment": res.judgment,
                    "confidence": res.confidence,
                    "reasoning": res.reasoning,
                    "is_correct": res.is_correct,
                    "raw_output": res.raw_output,
                }
            }
            (correct if res.is_correct else incorrect).append(sample)

    # 保存,与老脚本命名保持一致
    valid_path = os.path.join(run_out_dir, f"{input_fname}_valid.json")
    invalid_path = os.path.join(run_out_dir, f"{input_fname}_invalid.json")

    save_json(valid_path, clean_validation_data(correct))
    save_json(invalid_path, incorrect)

    # 报告
    generate_report(correct, incorrect, run_out_dir, input_fname, time.time(), "local-pipeline")

    logger.info("✅ 全部完成!")


if __name__ == "__main__":
    main()

The low accuracy of your pipeline script is mainly caused by building the prompt in a way that does not match InternVL's officially recommended chat_template, so the model's actual input differs from what it sees in API/transformers inference. The specific issues are:

  1. The prompt does not use the chat_template multimodal format
    You are manually concatenating the instruction string with an <image> placeholder, whereas InternVL officially recommends the following format (see the official vLLM multimodal examples and source code):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "你的问题"}
    ]}
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

Only this guarantees that the ordering of image and text tokens and the special tokens exactly match what the model expects.

  2. The image insertion itself is fine, but the prompt/image correspondence can be misaligned
    Although you pass the image via multi_modal_data, the prompt is just a hand-built string. chat_template automatically handles the image placeholder and the tokenizer's special treatment; manual concatenation easily gets the alignment wrong.

  3. Keep parameters such as stop_token_ids and max_model_len consistent with the API/official scripts
    Set max_model_len to 32768 or larger, and use <|endoftext|> and <|im_end|> as stop tokens.

Suggested fixes:

  • Build the prompt with chat_template: represent the image as {"type": "image"} and the text as {"type": "text", "text": ...}; do not concatenate strings by hand.
  • Keep passing the image via multi_modal_data: {"image": image}.
  • Refer to the official vLLM multimodal pipeline examples.

Example code:

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from PIL import Image

model_path = "你的模型路径"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
llm = LLM(model=model_path, max_model_len=32768, ...)

image = Image.open("xxx.jpg").convert("RGB")
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "请描述图片内容"}
    ]}
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = {"prompt": prompt, "multi_modal_data": {"image": image}}
outputs = llm.generate([inputs], SamplingParams(max_tokens=512, ...))
print(outputs[0].outputs[0].text)

Conclusion
As long as the prompt and image insertion exactly match the API / official chat_template, the pipeline's accuracy will match API/transformers inference. Your current low accuracy is almost certainly (99%) caused by the wrong prompt format.


After launching, it has been stuck at this stage for 5 minutes with no response.

(SenseRL) root@5a80c4dad3b0:/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid# CUDA_VISIBLE_DEVICES=0,1,2,4 python stage3_vadlid_pipeline.py --input /mnt/afs/smartbrain/zhukangkang/llavaOV_sensenova_inf_0_5w_steps_47000_5000/checkpoints/checkpoint_rank0_300.json
INFO 07-30 22:38:05 [init.py:243] Automatically detected platform cuda.
2025-07-30 22:38:08,151 [INFO] 初始化 tokenizer …
2025-07-30 22:38:08,390 [INFO] 初始化 vLLM 引擎 … (耗时 1-3 分钟)
INFO 07-30 22:38:08 [init.py:31] Available plugins for group vllm.general_plugins:
INFO 07-30 22:38:08 [init.py:33] - lora_filesystem_resolver → vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver
INFO 07-30 22:38:08 [init.py:36] All plugins in this group will be loaded. Set VLLM_PLUGINS to control which plugins to load.
INFO 07-30 22:38:16 [config.py:793] This model supports multiple tasks: {‘classify’, ‘score’, ‘embed’, ‘reward’, ‘generate’}. Defaulting to ‘generate’.
INFO 07-30 22:38:17 [config.py:1875] Defaulting to use ray for distributed inference
INFO 07-30 22:38:17 [config.py:2118] Chunked prefill is enabled with max_num_batched_tokens=16384.
INFO 07-30 22:38:18 [core.py:438] Waiting for init message from front-end.
INFO 07-30 22:38:18 [core.py:65] Initializing a V1 LLM engine (v0.9.0) with config: model=‘/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B’, speculative_config=None, tokenizer=‘/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B’, skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend=‘auto’, disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=‘’), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={“level”: 3, “custom_ops”: [“none”], “splitting_ops”: [“vllm.unified_attention”, “vllm.unified_attention_with_output”], “compile_sizes”: , “inductor_compile_config”: {“enable_auto_functionalized_v2”: false}, “use_cudagraph”: true, “cudagraph_num_of_warmups”: 1, “cudagraph_capture_sizes”: [512, 504, 496, 488, 480, 472, 464, 456, 448, 440, 432, 424, 416, 408, 400, 392, 384, 376, 368, 360, 352, 344, 336, 328, 320, 312, 304, 296, 288, 280, 272, 264, 256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 4, 2, 1], “max_capture_size”: 512}
2025-07-30 22:38:20,109 INFO worker.py:1879 – Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
INFO 07-30 22:38:22 [ray_utils.py:333] No current placement group found. Creating a new placement group.
WARNING 07-30 22:38:22 [ray_utils.py:340] The number of required GPUs exceeds the total number of available GPUs in the placement group.
INFO 07-30 22:38:32 [ray_utils.py:231] Waiting for creating a placement group of specs for 10 seconds. specs=[{‘GPU’: 1.0, ‘node:172.17.0.8’: 0.001}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}]. Check ray status and ray list nodes to see if you have enough resources, and make sure the IP addresses used by ray cluster are the same as VLLM_HOST_IP environment variable specified in each node if you are running on a multi-node.
INFO 07-30 22:38:52 [ray_utils.py:231] Waiting for creating a placement group of specs for 30 seconds. specs=[{‘GPU’: 1.0, ‘node:172.17.0.8’: 0.001}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}]. Check ray status and ray list nodes to see if you have enough resources, and make sure the IP addresses used by ray cluster are the same as VLLM_HOST_IP environment variable specified in each node if you are running on a multi-node.
INFO 07-30 22:39:32 [ray_utils.py:231] Waiting for creating a placement group of specs for 70 seconds. specs=[{‘GPU’: 1.0, ‘node:172.17.0.8’: 0.001}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}]. Check ray status and ray list nodes to see if you have enough resources, and make sure the IP addresses used by ray cluster are the same as VLLM_HOST_IP environment variable specified in each node if you are running on a multi-node.
INFO 07-30 22:40:52 [ray_utils.py:231] Waiting for creating a placement group of specs for 150 seconds. specs=[{‘GPU’: 1.0, ‘node:172.17.0.8’: 0.001}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}]. Check ray status and ray list nodes to see if you have enough resources, and make sure the IP addresses used by ray cluster are the same as VLLM_HOST_IP environment variable specified in each node if you are running on a multi-node.
INFO 07-30 22:43:32 [ray_utils.py:231] Waiting for creating a placement group of specs for 310 seconds. specs=[{‘GPU’: 1.0, ‘node:172.17.0.8’: 0.001}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}, {‘GPU’: 1.0}]. Check ray status and ray list nodes to see if you have enough resources, and make sure the IP addresses used by ray cluster are the same as VLLM_HOST_IP environment variable specified in each node if you are running on a multi-node.

Your log shows vLLM stuck at "Waiting for creating a placement group of specs…". The root cause is that you requested 8 GPUs (tensor_parallel_size=8) while Ray only detected 4 GPUs (CUDA_VISIBLE_DEVICES=0,1,2,4), so the placement group can never be allocated enough resources and waits indefinitely. The official documentation and related issues describe this in detail.

Solutions:

  • Either change tensor_parallel_size to 4 (matching the number of GPUs you actually have available);
  • Or set CUDA_VISIBLE_DEVICES to 8 cards (e.g. 0,1,2,3,4,5,6,7) so that Ray can allocate all 8.

Otherwise Ray can never satisfy vLLM's resource request and the program will hang forever (see the example command below).
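
For example, with the --tp-size argument your script already exposes (and your original input path), a 4-GPU run would look like:

CUDA_VISIBLE_DEVICES=0,1,2,4 python stage3_vadlid_pipeline.py --tp-size 4 --input /mnt/afs/smartbrain/zhukangkang/llavaOV_sensenova_inf_0_5w_steps_47000_5000/checkpoints/checkpoint_rank0_300.json

Note that in your InternVLPipelineInference.__init__ the WORLD_SIZE environment variable, if set, overrides --tp-size, so make sure it is unset or also equal to 4.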

Do you need a detailed multi-GPU / multi-node configuration example?


How do I change it to use 4 GPUs?

My original launch command:

CUDA_VISIBLE_DEVICES=0,1,2,4 python stage3_vadlid_pipeline.py --input /mnt/afs/smartbrain/zhukangkang/llavaOV_sensenova_inf_0_5w_steps_47000_5000/checkpoints/checkpoint_rank0_300.json 

My script:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Stage-3 Valid – InternVL-78B local vLLM batch-inference version
=================================================
Compared with the original client/HTTP version, this script calls `LLM.generate()`
directly on a local vLLM engine, which significantly improves inference throughput.

The core flow is unchanged:
1. Build the prompt for each sample (lenient evaluation criteria)
2. Call vLLM in batches to generate judgments
3. Parse the judgment → confidence / correct or not
4. Save valid / invalid / report

Only the essential features are kept; the remote API / checkpointing are removed to keep the script lean and easy to run.
"""

from __future__ import annotations

import os
import sys
import argparse
import logging
import json
import time
from dataclasses import dataclass
from typing import List, Dict, Tuple

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, EngineArgs  # type: ignore

# ---------------------------------------------------------------------------
# 共用结构 – 直接在本文件中重新实现所需工具
# ---------------------------------------------------------------------------

import re
import io
import requests
from PIL import Image  # type: ignore


@dataclass
class ValidationResult:
    """单条验证输出"""

    judgment: str
    confidence: float
    is_correct: bool
    raw_output: str
    reasoning: str


class JudgmentParser:
    """解析模型自然语言输出,得到宽松判定"""

    correct_indicators = {
        "正确",
        "基本正确",
        "大致正确",
        "部分正确",
        "大部分正确",
        "总体正确",
        "符合要求",
        "accurate",
        "mostly correct",
        "partially correct",
        "true",
        "yes",
    }

    @staticmethod
    def parse_judgment(resp: str) -> Tuple[bool, float, str, str]:
        if not resp:
            return False, 0.0, "无响应", ""

        clean = JudgmentParser._clean(resp)
        judgment = JudgmentParser._extract_judgment(clean)
        confidence = JudgmentParser._extract_confidence(clean)
        reasoning = JudgmentParser._extract_reasoning(clean)
        is_correct = JudgmentParser._is_correct(judgment)
        return is_correct, confidence, judgment, reasoning

    @staticmethod
    def _clean(text: str) -> str:
        text = re.sub(r"<\|.*?\|>", "", text)
        return re.sub(r"\s+", " ", text).strip()

    @staticmethod
    def _extract_judgment(text: str) -> str:
        m = re.search(r"判断[::]\s*(.+?)(?:\n|置信度|理由|$)", text, re.I | re.S)
        return m.group(1).strip() if m else text

    @staticmethod
    def _extract_confidence(text: str) -> float:
        m = re.search(r"置信度[::]\s*(\d*\.?\d+)", text)
        if m:
            try:
                return min(float(m.group(1)), 1.0)
            except Exception:
                pass
        # 粗略估计
        lower = text.lower()
        if any(k in lower for k in ("完全正确", "absolutely", "显然")):
            return 0.95
        if "部分" in text or "partially" in lower:
            return 0.6
        return 0.7

    @staticmethod
    def _extract_reasoning(text: str) -> str:
        m = re.search(r"理由[::]\s*(.+)", text, re.I | re.S)
        return m.group(1).strip() if m else "无详细理由"

    @staticmethod
    def _is_correct(judgment: str) -> bool:
        j = judgment.lower()
        if any(ind in j for ind in JudgmentParser.correct_indicators):
            return not any(neg in j for neg in ("不正确", "错误", "incorrect"))
        return False


# ---------------- 数据集 I/O ----------------


def load_dataset(path: str) -> List[Dict]:
    logger.info(f"加载数据集: {path}")
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        logger.error(f"读取失败: {e}")
        sys.exit(1)

    if not isinstance(data, list):
        logger.error("顶层应为列表")
        sys.exit(1)

    processed: List[Dict] = []
    for item in data:
        if not isinstance(item, dict):
            continue

        # 兼容字段
        image = item.get("image") or item.get("image_path")
        q = item.get("question") or item.get("problem")
        a = item.get("answer") or item.get("gt_answer")

        if not (image and q and a):
            continue

        processed.append({"image": image, "question": q, "answer": a, **item})

    logger.info(f"标准化后样本数: {len(processed)}/{len(data)}")
    return processed


def save_json(path: str, data: List[Dict]):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    logger.info(f"保存 {len(data)} 条 → {path}")


def clean_validation_data(data: List[Dict]) -> List[Dict]:
    return [{k: v for k, v in item.items() if k != "gt_valid"} for item in data]


def generate_report(correct: List[Dict], incorrect: List[Dict], out_dir: str, fname: str, start_time: float, api: str):
    total = len(correct) + len(incorrect)
    duration = time.time() - start_time
    report = (
        f"Stage-3 Valid 报告\n====================\n"
        f"文件: {fname}\n总样本: {total}\n"
        f"正确: {len(correct)}  ({len(correct)/total*100:.2f}%)\n"
        f"错误: {len(incorrect)} ({len(incorrect)/total*100:.2f}%)\n"
        f"耗时: {duration:.2f}s\n"
        f"模式: 本地 vLLM pipeline\n"
    )
    path = os.path.join(out_dir, f"{fname}_report.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(report)
    logger.info(f"报告已保存 → {path}")


logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# ---------------------------------------------------------------------------
# 配置
# ---------------------------------------------------------------------------
@dataclass
class ValidationConfig:
    model_path: str = "/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B"
    batch_size: int = 128
    max_new_tokens: int = 1024
    temperature: float = 0.0
    top_p: float = 1.0
    gpu_util: float = 0.95  # 依据显存情况微调
    # ---- 追加:大上下文 & 批推理参数 ----
    max_model_len: int = 8192          # 上下文长度
    # max_num_batched_tokens: int = 65536 # 每批最大 token 数
    # max_num_seqs: int = 512             # 每批最大序列数
    # ---- GPU 并行 ----
    tensor_parallel_size: int = 8       # 默认使用 8 张 GPU 做张量并行

# ---------------------------------------------------------------------------
# 本地 vLLM 推理器
# ---------------------------------------------------------------------------
class InternVLPipelineInference:
    """本地 vLLM 批推理封装。"""

    def __init__(self, cfg: ValidationConfig):
        self.cfg = cfg

        logger.info("初始化 tokenizer …")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True)

        logger.info("初始化 vLLM 引擎 … (耗时 1-3 分钟)")
        # 兼容旧配置,按需动态加入可选参数
        # 根据环境变量 WORLD_SIZE(分布式场景)或配置中的 tensor_parallel_size 设定卡数
        tp_size = int(os.environ.get("WORLD_SIZE", cfg.tensor_parallel_size))

        eng_kwargs = dict(
            model=cfg.model_path,
            tensor_parallel_size=max(1, tp_size),
            gpu_memory_utilization=cfg.gpu_util,
            trust_remote_code=True,
        )

        if hasattr(cfg, "max_model_len"):
            eng_kwargs["max_model_len"] = cfg.max_model_len
        if hasattr(cfg, "max_num_batched_tokens"):
            eng_kwargs["max_num_batched_tokens"] = cfg.max_num_batched_tokens
        if hasattr(cfg, "max_num_seqs"):
            eng_kwargs["max_num_seqs"] = cfg.max_num_seqs

        eng_args = EngineArgs(**eng_kwargs)
        self.llm = LLM(**eng_args.__dict__)

        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        stop_token_ids = [self.tokenizer.convert_tokens_to_ids(tok) for tok in stop_tokens]
        self.sp = SamplingParams(
            temperature=cfg.temperature,
            top_p=cfg.top_p,
            max_tokens=cfg.max_new_tokens,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=True,
        )

    # -------------------- prompt 构建 --------------------
    def _build_prompt(self, sample: Dict) -> str:
        question = sample.get("question", "")
        answer = sample.get("answer", "")

        # 指令部分
        instr = (
            "你是一个严谨的、注重事实的多模态问答评估员。请基于图片内容,判断\"待判断答案\"是否准确地回答了\"问题\"。\n\n"
            "**要求:**\n"
            "1. **事实为先**: 你的判断必须严格基于图片内容,不能有任何想象或推断。\n"
            "2. **宽松标准**: '完全正确'、'基本正确'、'大致正确'、'部分正确'都视为【正确】。只有'完全错误'、'严重错误'、'明显不符'才视为【错误】。\n"
            "3. **输出格式必须如下**:\n\n"
            "判断:[正确/基本正确/部分正确/错误]\n"
            "置信度:[0.0-1.0]\n"
            "理由:[简要说明判断依据,指出答案中正确或错误的关键点]\n\n"
            "--- 以下是待评估内容 ---\n"
            f"问题:{question}\n待判断答案:{answer}"
        )

        # 使用官方推荐的 chat_template 多模态格式
        messages = [
            {
                "role": "user", 
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": instr}
                ]
            }
        ]

        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # -------------------- 核心批推理 --------------------
    def validate_batch(self, batch_data: List[Dict]) -> List[ValidationResult]:
        requests_list = []
        for sample in batch_data:
            prompt = self._build_prompt(sample)

            img_path = sample.get("image", "")
            if not img_path:
                # 占位错误
                prompt += "\n[错误: 未提供图像]"
                requests_list.append({"prompt": prompt})
                continue

            try:
                if img_path.startswith("http"):
                    resp = requests.get(img_path, timeout=15)
                    resp.raise_for_status()
                    image = Image.open(io.BytesIO(resp.content)).convert("RGB")
                else:
                    image = Image.open(img_path).convert("RGB")
            except Exception as e:
                logger.warning(f"加载图片失败: {img_path} - {e}")
                prompt += "\n[错误: 图像加载失败]"
                requests_list.append({"prompt": prompt})
                continue

            requests_list.append({
                "prompt": prompt,
                "multi_modal_data": {"image": image},
            })

        outputs = self.llm.generate(requests_list, self.sp)

        results: List[ValidationResult] = []
        for out in outputs:
            raw = out.outputs[0].text.strip()
            is_corr, conf, judgment, reasoning = JudgmentParser.parse_judgment(raw)
            results.append(
                ValidationResult(
                    judgment=judgment,
                    confidence=conf,
                    is_correct=is_corr,
                    raw_output=raw,
                    reasoning=reasoning,
                )
            )
        return results

# ---------------------------------------------------------------------------
# 入口逻辑 (简化,无 checkpoint)
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser("Stage-3 Valid – 本地 vLLM Pipeline 版")
    p.add_argument("--input", required=True, help="输入 JSON 文件")
    p.add_argument(
        "--output",
        default="/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage1_cls/output_stage1_pipeline",
        help="输出根目录 (默认: /mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage1_cls/output_stage1_pipeline)",
    )
    p.add_argument("--model", default="/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B", help="模型路径")
    p.add_argument("--batch-size", type=int, default=128, help="批推理大小")
    p.add_argument("--gpu-util", type=float, default=0.85, help="GPU 显存利用率上限")
    p.add_argument("--tp-size", type=int, default=8, help="Tensor parallel size (GPU 卡数,默认 8)")
    p.add_argument("--debug", action="store_true")
    return p.parse_args()


def main():
    args = parse_args()
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    cfg = ValidationConfig(
        model_path=args.model,
        batch_size=args.batch_size,
        gpu_util=args.gpu_util,
        tensor_parallel_size=args.tp_size,
    )

    inference = InternVLPipelineInference(cfg)

    data = load_dataset(args.input)
    if not data:
        logger.error("数据为空,退出")
        sys.exit(1)

    # 输出目录遵循旧脚本逻辑: <output_root>/<input_filename>/
    input_fname = os.path.splitext(os.path.basename(args.input))[0]
    run_out_dir = os.path.join(args.output, input_fname)
    os.makedirs(run_out_dir, exist_ok=True)

    correct: List[Dict] = []
    incorrect: List[Dict] = []

    total = len(data)
    logger.info(f"共有 {total} 条样本,batch_size={cfg.batch_size}")

    for i in range(0, total, cfg.batch_size):
        batch = data[i : i + cfg.batch_size]
        logger.info(f"处理 {i+1}‒{min(i+cfg.batch_size, total)} …")
        results = inference.validate_batch(batch)
        for sample, res in zip(batch, results):
            sample["gt_valid"] = {
                "validation": {
                    "model": "InternVL-78B-pipeline",
                    "judgment": res.judgment,
                    "confidence": res.confidence,
                    "reasoning": res.reasoning,
                    "is_correct": res.is_correct,
                    "raw_output": res.raw_output,
                }
            }
            (correct if res.is_correct else incorrect).append(sample)

    # 保存,与老脚本命名保持一致
    valid_path = os.path.join(run_out_dir, f"{input_fname}_valid.json")
    invalid_path = os.path.join(run_out_dir, f"{input_fname}_invalid.json")

    save_json(valid_path, clean_validation_data(correct))
    save_json(invalid_path, incorrect)

    # 报告
    generate_report(correct, incorrect, run_out_dir, input_fname, time.time(), "local-pipeline")

    logger.info("✅ 全部完成!")


if __name__ == "__main__":
    main()

My error output:
Loading safetensors checkpoint shards: 100% Completed | 33/33 [12:00<00:00, 21.83s/it]
(VllmWorker rank=0 pid=493016)
(VllmWorker rank=0 pid=493016) INFO 07-30 23:15:58 [default_loader.py:280] Loading weights took 720.54 seconds
(VllmWorker rank=1 pid=493017) INFO 07-30 23:15:58 [gpu_model_runner.py:1549] Model loading took 39.5328 GiB and 720.879685 seconds
(VllmWorker rank=2 pid=493018) INFO 07-30 23:15:58 [gpu_model_runner.py:1549] Model loading took 39.5328 GiB and 720.907523 seconds
(VllmWorker rank=3 pid=493019) INFO 07-30 23:15:58 [gpu_model_runner.py:1549] Model loading took 39.5328 GiB and 720.933947 seconds
(VllmWorker rank=0 pid=493016) INFO 07-30 23:15:58 [gpu_model_runner.py:1549] Model loading took 39.5328 GiB and 720.996336 seconds
(VllmWorker rank=2 pid=493018) INFO 07-30 23:15:58 [gpu_model_runner.py:1863] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 4 video items of the maximum feature size.
(VllmWorker rank=0 pid=493016) INFO 07-30 23:15:58 [gpu_model_runner.py:1863] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 4 video items of the maximum feature size.
(VllmWorker rank=1 pid=493017) INFO 07-30 23:15:58 [gpu_model_runner.py:1863] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 4 video items of the maximum feature size.
(VllmWorker rank=3 pid=493019) INFO 07-30 23:15:58 [gpu_model_runner.py:1863] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 4 video items of the maximum feature size.
(VllmWorker rank=3 pid=493019) INFO 07-30 23:16:27 [backends.py:459] Using cache directory: /root/.cache/vllm/torch_compile_cache/c8a7fdfbdd/rank_3_0 for vLLM’s torch.compile
(VllmWorker rank=3 pid=493019) INFO 07-30 23:16:27 [backends.py:469] Dynamo bytecode transform time: 11.93 s
(VllmWorker rank=0 pid=493016) INFO 07-30 23:16:27 [backends.py:459] Using cache directory: /root/.cache/vllm/torch_compile_cache/c8a7fdfbdd/rank_0_0 for vLLM’s torch.compile
(VllmWorker rank=0 pid=493016) INFO 07-30 23:16:27 [backends.py:469] Dynamo bytecode transform time: 12.00 s
(VllmWorker rank=2 pid=493018) INFO 07-30 23:16:27 [backends.py:459] Using cache directory: /root/.cache/vllm/torch_compile_cache/c8a7fdfbdd/rank_2_0 for vLLM’s torch.compile
(VllmWorker rank=2 pid=493018) INFO 07-30 23:16:27 [backends.py:469] Dynamo bytecode transform time: 12.06 s
(VllmWorker rank=1 pid=493017) INFO 07-30 23:16:28 [backends.py:459] Using cache directory: /root/.cache/vllm/torch_compile_cache/c8a7fdfbdd/rank_1_0 for vLLM’s torch.compile
(VllmWorker rank=1 pid=493017) INFO 07-30 23:16:28 [backends.py:469] Dynamo bytecode transform time: 12.16 s
(VllmWorker rank=3 pid=493019) INFO 07-30 23:16:38 [backends.py:132] Directly load the compiled graph(s) for shape None from the cache, took 9.272 s
(VllmWorker rank=0 pid=493016) INFO 07-30 23:16:38 [backends.py:132] Directly load the compiled graph(s) for shape None from the cache, took 9.273 s
(VllmWorker rank=2 pid=493018) INFO 07-30 23:16:38 [backends.py:132] Directly load the compiled graph(s) for shape None from the cache, took 9.598 s
(VllmWorker rank=1 pid=493017) INFO 07-30 23:16:38 [backends.py:132] Directly load the compiled graph(s) for shape None from the cache, took 9.539 s
(VllmWorker rank=3 pid=493019) INFO 07-30 23:16:40 [monitor.py:33] torch.compile takes 11.93 s in total
(VllmWorker rank=0 pid=493016) INFO 07-30 23:16:40 [monitor.py:33] torch.compile takes 12.00 s in total
(VllmWorker rank=2 pid=493018) INFO 07-30 23:16:40 [monitor.py:33] torch.compile takes 12.06 s in total
(VllmWorker rank=1 pid=493017) INFO 07-30 23:16:40 [monitor.py:33] torch.compile takes 12.16 s in total
INFO 07-30 23:16:46 [kv_cache_utils.py:637] GPU KV cache size: 232,720 tokens
INFO 07-30 23:16:46 [kv_cache_utils.py:640] Maximum concurrency for 8,192 tokens per request: 28.41x
INFO 07-30 23:16:46 [kv_cache_utils.py:637] GPU KV cache size: 231,488 tokens
INFO 07-30 23:16:46 [kv_cache_utils.py:640] Maximum concurrency for 8,192 tokens per request: 28.26x
INFO 07-30 23:16:46 [kv_cache_utils.py:637] GPU KV cache size: 231,488 tokens
INFO 07-30 23:16:46 [kv_cache_utils.py:640] Maximum concurrency for 8,192 tokens per request: 28.26x
INFO 07-30 23:16:46 [kv_cache_utils.py:637] GPU KV cache size: 237,632 tokens
INFO 07-30 23:16:46 [kv_cache_utils.py:640] Maximum concurrency for 8,192 tokens per request: 29.01x
(VllmWorker rank=3 pid=493019) INFO 07-30 23:17:25 [custom_all_reduce.py:195] Registering 10560 cuda graph addresses
(VllmWorker rank=2 pid=493018) INFO 07-30 23:17:26 [custom_all_reduce.py:195] Registering 10560 cuda graph addresses
(VllmWorker rank=0 pid=493016) INFO 07-30 23:17:26 [custom_all_reduce.py:195] Registering 10560 cuda graph addresses
(VllmWorker rank=1 pid=493017) INFO 07-30 23:17:34 [custom_all_reduce.py:195] Registering 10560 cuda graph addresses
(VllmWorker rank=3 pid=493019) INFO 07-30 23:17:34 [gpu_model_runner.py:1933] Graph capturing finished in 48 secs, took 4.59 GiB
(VllmWorker rank=0 pid=493016) INFO 07-30 23:17:34 [gpu_model_runner.py:1933] Graph capturing finished in 48 secs, took 4.59 GiB
(VllmWorker rank=2 pid=493018) INFO 07-30 23:17:34 [gpu_model_runner.py:1933] Graph capturing finished in 48 secs, took 4.59 GiB
(VllmWorker rank=1 pid=493017) INFO 07-30 23:17:34 [gpu_model_runner.py:1933] Graph capturing finished in 48 secs, took 4.59 GiB
INFO 07-30 23:17:34 [core.py:167] init engine (profile, create kv cache, warmup model) took 96.21 seconds
Token indices sequence length is longer than the specified maximum sequence length for this model (8318 > 8192). Running this sequence through the model will result in indexing errors
2025-07-30 23:17:35,649 [INFO] 加载数据集: /mnt/afs/smartbrain/zhukangkang/llavaOV_sensenova_inf_0_5w_steps_47000_5000/checkpoints/checkpoint_rank0_300.json
2025-07-30 23:17:35,659 [INFO] 标准化后样本数: 300/300
2025-07-30 23:17:35,661 [INFO] 共有 300 条样本,batch_size=128
2025-07-30 23:17:35,661 [INFO] 处理 1‒128 …
Traceback (most recent call last):
File “/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid/stage3_vadlid_pipeline.py”, line 442, in
main()
File “/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid/stage3_vadlid_pipeline.py”, line 414, in main
results = inference.validate_batch(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid/stage3_vadlid_pipeline.py”, line 317, in validate_batch
prompt = self._build_prompt(sample)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid/stage3_vadlid_pipeline.py”, line 311, in _build_prompt
return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/transformers/tokenization_utils_base.py”, line 1652, in apply_chat_template
rendered_chat, generation_indices = render_jinja_template(
^^^^^^^^^^^^^^^^^^^^^^
File “/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/transformers/utils/chat_template_utils.py”, line 498, in render_jinja_template
rendered_chat = compiled_template.render(
^^^^^^^^^^^^^^^^^^^^^^^^^
File “/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/jinja2/environment.py”, line 1295, in render
self.environment.handle_exception()
File “/usr/local/lib/miniconda3/envs/SenseRL/lib/python3.11/site-packages/jinja2/environment.py”, line 942, in handle_exception
raise rewrite_traceback_stack(source=source)
File “”, line 23, in top-level template code
TypeError: can only concatenate str (not “list”) to str
(SenseRL) root@5a80c4dad3b0:/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage3_valid#

My script:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Stage-3 Valid – InternVL-78B 本地 vLLM 批推理版本
=================================================
相比原来的 client/HTTP 版本,本脚本直接在本地通过 vLLM 引擎调用 `LLM.generate()`,
可显著提升推理吞吐。

核心流程保持不变:
1. 逐条构建提示词 (宽松评估标准)
2. 批量调用 vLLM 生成判断
3. 解析判断 → 置信度 / 正确与否
4. 保存 valid / invalid / report

仅保留最关键功能,去掉远程 API / checkpoint,保持脚本精简易跑。
"""

from __future__ import annotations

import os
import sys
import argparse
import logging
import json
import time
from dataclasses import dataclass
from typing import List, Dict, Tuple

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, EngineArgs  # type: ignore

# ---------------------------------------------------------------------------
# 共用结构 – 直接在本文件中重新实现所需工具
# ---------------------------------------------------------------------------

import re
import io
import requests
from PIL import Image  # type: ignore


@dataclass
class ValidationResult:
    """单条验证输出"""

    judgment: str
    confidence: float
    is_correct: bool
    raw_output: str
    reasoning: str


class JudgmentParser:
    """解析模型自然语言输出,得到宽松判定"""

    correct_indicators = {
        "正确",
        "基本正确",
        "大致正确",
        "部分正确",
        "大部分正确",
        "总体正确",
        "符合要求",
        "accurate",
        "mostly correct",
        "partially correct",
        "true",
        "yes",
    }

    @staticmethod
    def parse_judgment(resp: str) -> Tuple[bool, float, str, str]:
        if not resp:
            return False, 0.0, "无响应", ""

        clean = JudgmentParser._clean(resp)
        judgment = JudgmentParser._extract_judgment(clean)
        confidence = JudgmentParser._extract_confidence(clean)
        reasoning = JudgmentParser._extract_reasoning(clean)
        is_correct = JudgmentParser._is_correct(judgment)
        return is_correct, confidence, judgment, reasoning

    @staticmethod
    def _clean(text: str) -> str:
        text = re.sub(r"<\|.*?\|>", "", text)
        return re.sub(r"\s+", " ", text).strip()

    @staticmethod
    def _extract_judgment(text: str) -> str:
        m = re.search(r"判断[::]\s*(.+?)(?:\n|置信度|理由|$)", text, re.I | re.S)
        return m.group(1).strip() if m else text

    @staticmethod
    def _extract_confidence(text: str) -> float:
        m = re.search(r"置信度[::]\s*(\d*\.?\d+)", text)
        if m:
            try:
                return min(float(m.group(1)), 1.0)
            except Exception:
                pass
        # 粗略估计
        lower = text.lower()
        if any(k in lower for k in ("完全正确", "absolutely", "显然")):
            return 0.95
        if "部分" in text or "partially" in lower:
            return 0.6
        return 0.7

    @staticmethod
    def _extract_reasoning(text: str) -> str:
        m = re.search(r"理由[::]\s*(.+)", text, re.I | re.S)
        return m.group(1).strip() if m else "无详细理由"

    @staticmethod
    def _is_correct(judgment: str) -> bool:
        j = judgment.lower()
        if any(ind in j for ind in JudgmentParser.correct_indicators):
            return not any(neg in j for neg in ("不正确", "错误", "incorrect"))
        return False


# ---------------- 数据集 I/O ----------------


def load_dataset(path: str) -> List[Dict]:
    logger.info(f"加载数据集: {path}")
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        logger.error(f"读取失败: {e}")
        sys.exit(1)

    if not isinstance(data, list):
        logger.error("顶层应为列表")
        sys.exit(1)

    processed: List[Dict] = []
    for item in data:
        if not isinstance(item, dict):
            continue

        # 兼容字段
        image = item.get("image") or item.get("image_path")
        q = item.get("question") or item.get("problem")
        a = item.get("answer") or item.get("gt_answer")

        if not (image and q and a):
            continue

        processed.append({"image": image, "question": q, "answer": a, **item})

    logger.info(f"标准化后样本数: {len(processed)}/{len(data)}")
    return processed


def save_json(path: str, data: List[Dict]):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    logger.info(f"保存 {len(data)} 条 → {path}")


def clean_validation_data(data: List[Dict]) -> List[Dict]:
    return [{k: v for k, v in item.items() if k != "gt_valid"} for item in data]


def generate_report(correct: List[Dict], incorrect: List[Dict], out_dir: str, fname: str, start_time: float, api: str):
    total = len(correct) + len(incorrect)
    duration = time.time() - start_time
    report = (
        f"Stage-3 Valid 报告\n====================\n"
        f"文件: {fname}\n总样本: {total}\n"
        f"正确: {len(correct)}  ({len(correct)/total*100:.2f}%)\n"
        f"错误: {len(incorrect)} ({len(incorrect)/total*100:.2f}%)\n"
        f"耗时: {duration:.2f}s\n"
        f"模式: 本地 vLLM pipeline\n"
    )
    path = os.path.join(out_dir, f"{fname}_report.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(report)
    logger.info(f"报告已保存 → {path}")


logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# ---------------------------------------------------------------------------
# 配置
# ---------------------------------------------------------------------------
@dataclass
class ValidationConfig:
    model_path: str = "/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B"
    batch_size: int = 128
    max_new_tokens: int = 1024
    temperature: float = 0.0
    top_p: float = 1.0
    gpu_util: float = 0.95  # 依据显存情况微调
    # ---- 追加:大上下文 & 批推理参数 ----
    max_model_len: int = 8192          # 上下文长度
    # max_num_batched_tokens: int = 65536 # 每批最大 token 数
    # max_num_seqs: int = 512             # 每批最大序列数
    # ---- GPU 并行 ----
    tensor_parallel_size: int = 4       # 默认使用 4 张 GPU 做张量并行

# ---------------------------------------------------------------------------
# 本地 vLLM 推理器
# ---------------------------------------------------------------------------
class InternVLPipelineInference:
    """本地 vLLM 批推理封装。"""

    def __init__(self, cfg: ValidationConfig):
        self.cfg = cfg

        logger.info("初始化 tokenizer …")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True)

        logger.info("初始化 vLLM 引擎 … (耗时 1-3 分钟)")
        # 兼容旧配置,按需动态加入可选参数
        # 自动检测可用GPU数量
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if visible_devices:
            # 从 CUDA_VISIBLE_DEVICES 中计算可用GPU数量
            gpu_count = len([x.strip() for x in visible_devices.split(",") if x.strip()])
            tp_size = gpu_count
            logger.info(f"检测到 CUDA_VISIBLE_DEVICES={visible_devices},自动设置 tensor_parallel_size={tp_size}")
        else:
            # 使用配置中的默认值或环境变量 WORLD_SIZE
            tp_size = int(os.environ.get("WORLD_SIZE", cfg.tensor_parallel_size))
            logger.info(f"使用配置的 tensor_parallel_size={tp_size}")

        eng_kwargs = dict(
            model=cfg.model_path,
            tensor_parallel_size=max(1, tp_size),
            gpu_memory_utilization=cfg.gpu_util,
            trust_remote_code=True,
        )

        if hasattr(cfg, "max_model_len"):
            eng_kwargs["max_model_len"] = cfg.max_model_len
        if hasattr(cfg, "max_num_batched_tokens"):
            eng_kwargs["max_num_batched_tokens"] = cfg.max_num_batched_tokens
        if hasattr(cfg, "max_num_seqs"):
            eng_kwargs["max_num_seqs"] = cfg.max_num_seqs

        eng_args = EngineArgs(**eng_kwargs)
        self.llm = LLM(**eng_args.__dict__)

        # 正确获取 stop token IDs
        stop_token_ids = []
        if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)
        
        # 尝试获取 <|im_end|> 的 token ID
        try:
            im_end_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
            if im_end_id != self.tokenizer.unk_token_id:  # 确保不是 unknown token
                stop_token_ids.append(im_end_id)
        except:
            pass
            
        # 尝试获取 <|endoftext|> 的 token ID  
        try:
            endoftext_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
            if endoftext_id != self.tokenizer.unk_token_id:  # 确保不是 unknown token
                stop_token_ids.append(endoftext_id)
        except:
            pass

        self.sp = SamplingParams(
            temperature=cfg.temperature,
            top_p=cfg.top_p,
            max_tokens=cfg.max_new_tokens,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=True,
        )

    # -------------------- prompt 构建 --------------------
    def _build_prompt(self, sample: Dict) -> str:
        question = sample.get("question", "")
        answer = sample.get("answer", "")

        # 指令部分
        instr = (
            "你是一个严谨的、注重事实的多模态问答评估员。请基于图片内容,判断\"待判断答案\"是否准确地回答了\"问题\"。\n\n"
            "**要求:**\n"
            "1. **事实为先**: 你的判断必须严格基于图片内容,不能有任何想象或推断。\n"
            "2. **宽松标准**: '完全正确'、'基本正确'、'大致正确'、'部分正确'都视为【正确】。只有'完全错误'、'严重错误'、'明显不符'才视为【错误】。\n"
            "3. **输出格式必须如下**:\n\n"
            "判断:[正确/基本正确/部分正确/错误]\n"
            "置信度:[0.0-1.0]\n"
            "理由:[简要说明判断依据,指出答案中正确或错误的关键点]\n\n"
            "--- 以下是待评估内容 ---\n"
            f"问题:{question}\n待判断答案:{answer}"
        )

        # 使用官方推荐的 chat_template 多模态格式
        messages = [
            {
                "role": "user", 
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": instr}
                ]
            }
        ]

        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # -------------------- 核心批推理 --------------------
    def validate_batch(self, batch_data: List[Dict]) -> List[ValidationResult]:
        requests_list = []
        for sample in batch_data:
            prompt = self._build_prompt(sample)

            img_path = sample.get("image", "")
            if not img_path:
                # 占位错误
                prompt += "\n[错误: 未提供图像]"
                requests_list.append({"prompt": prompt})
                continue

            try:
                if img_path.startswith("http"):
                    resp = requests.get(img_path, timeout=15)
                    resp.raise_for_status()
                    image = Image.open(io.BytesIO(resp.content)).convert("RGB")
                else:
                    image = Image.open(img_path).convert("RGB")
            except Exception as e:
                logger.warning(f"加载图片失败: {img_path} - {e}")
                prompt += "\n[错误: 图像加载失败]"
                requests_list.append({"prompt": prompt})
                continue

            requests_list.append({
                "prompt": prompt,
                "multi_modal_data": {"image": image},
            })

        outputs = self.llm.generate(requests_list, self.sp)

        results: List[ValidationResult] = []
        for out in outputs:
            raw = out.outputs[0].text.strip()
            is_corr, conf, judgment, reasoning = JudgmentParser.parse_judgment(raw)
            results.append(
                ValidationResult(
                    judgment=judgment,
                    confidence=conf,
                    is_correct=is_corr,
                    raw_output=raw,
                    reasoning=reasoning,
                )
            )
        return results

# ---------------------------------------------------------------------------
# 入口逻辑 (简化,无 checkpoint)
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser("Stage-3 Valid – 本地 vLLM Pipeline 版")
    p.add_argument("--input", required=True, help="输入 JSON 文件")
    p.add_argument(
        "--output",
        default="/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage1_cls/output_stage1_pipeline",
        help="输出根目录 (默认: /mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage1_cls/output_stage1_pipeline)",
    )
    p.add_argument("--model", default="/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B", help="模型路径")
    p.add_argument("--batch-size", type=int, default=128, help="批推理大小")
    p.add_argument("--gpu-util", type=float, default=0.85, help="GPU 显存利用率上限")
    p.add_argument("--tp-size", type=int, default=4, help="Tensor parallel size (GPU 卡数,默认 4)")
    p.add_argument("--debug", action="store_true")
    return p.parse_args()


def main():
    args = parse_args()
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    cfg = ValidationConfig(
        model_path=args.model,
        batch_size=args.batch_size,
        gpu_util=args.gpu_util,
        tensor_parallel_size=args.tp_size,
    )

    inference = InternVLPipelineInference(cfg)

    data = load_dataset(args.input)
    if not data:
        logger.error("数据为空,退出")
        sys.exit(1)

    # 输出目录遵循旧脚本逻辑: <output_root>/<input_filename>/
    input_fname = os.path.splitext(os.path.basename(args.input))[0]
    run_out_dir = os.path.join(args.output, input_fname)
    os.makedirs(run_out_dir, exist_ok=True)

    correct: List[Dict] = []
    incorrect: List[Dict] = []

    total = len(data)
    logger.info(f"共有 {total} 条样本,batch_size={cfg.batch_size}")

    for i in range(0, total, cfg.batch_size):
        batch = data[i : i + cfg.batch_size]
        logger.info(f"处理 {i+1}‒{min(i+cfg.batch_size, total)} …")
        results = inference.validate_batch(batch)
        for sample, res in zip(batch, results):
            sample["gt_valid"] = {
                "validation": {
                    "model": "InternVL-78B-pipeline",
                    "judgment": res.judgment,
                    "confidence": res.confidence,
                    "reasoning": res.reasoning,
                    "is_correct": res.is_correct,
                    "raw_output": res.raw_output,
                }
            }
            (correct if res.is_correct else incorrect).append(sample)

    # 保存,与老脚本命名保持一致
    valid_path = os.path.join(run_out_dir, f"{input_fname}_valid.json")
    invalid_path = os.path.join(run_out_dir, f"{input_fname}_invalid.json")

    save_json(valid_path, clean_validation_data(correct))
    save_json(invalid_path, incorrect)

    # 报告
    generate_report(correct, incorrect, run_out_dir, input_fname, time.time(), "local-pipeline")

    logger.info("✅ 全部完成!")


if __name__ == "__main__":
    main()




You only need two steps to run on 4 GPUs

  1. Command-line setup
    Change CUDA_VISIBLE_DEVICES=0,1,2,4 in your launch command to the 4 GPUs you actually want to use (e.g. 0,1,2,3), and make sure --tp-size 4 (or tensor_parallel_size=4) matches:
CUDA_VISIBLE_DEVICES=0,1,2,4 python stage3_vadlid_pipeline.py --input ... --tp-size 4

Note: when CUDA_VISIBLE_DEVICES lists 4 GPUs, --tp-size must equal 4.

  2. Script auto-adaptation
    Your script already counts the GPUs in CUDA_VISIBLE_DEVICES and sets tensor_parallel_size automatically; as long as the command line passes --tp-size 4 (or the environment variable is consistent), no manual script change is needed.
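
A minimal sketch of that auto-detection (the model path is the one from your script; the fallback value of 4 is illustrative):

import os
from vllm import LLM

# Count the GPUs exposed to this process; fall back to 4 if the variable is unset.
visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
tp_size = len([d for d in visible.split(",") if d.strip()]) or 4

llm = LLM(
    model="/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B",
    tensor_parallel_size=tp_size,  # must equal the number of visible GPUs
    trust_remote_code=True,
)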

Common issues:

  • If CUDA_VISIBLE_DEVICES lists 4 GPUs, --tp-size must equal 4, otherwise the engine may hang.
  • Prefer CUDA_VISIBLE_DEVICES=0,1,2,3; non-contiguous device IDs can cause subtle problems.

References:

Let me know if you need examples for multi-node or other special distributed setups.

Sources:

Hi @Jasper have you tried --enforce-eager or using later versions of vllm? There may be a bug in torch.compile that manifest in distributed inference including using ray. See more discussions in [Bug]: 0.8.0(V1) RayChannelTimeoutError when inferencing DeepSeekV3 on 16 H20 with large batch size · Issue #15102 · vllm-project/vllm · GitHub

Yes. The vLLM maintainers and community have repeatedly suggested adding --enforce-eager when distributed inference (especially with the Ray backend) hits instability, timeouts, or crashes. The flag disables torch.compile and forces PyTorch eager mode, which works around several NCCL / Ray / torch.compile compatibility bugs and improves distributed stability; this is stated explicitly in the discussion of issue #15102.

In addition, upgrading to the latest vLLM release fixes some known bugs. Do you need upgrade recommendations as well?
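
A minimal sketch of both ways to apply --enforce-eager, assuming the same model path and 4-way tensor parallelism as earlier in the thread; the remaining values are illustrative:

from vllm import LLM

# Offline engine: enforce_eager=True disables torch.compile / CUDA graph capture.
llm = LLM(
    model="/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B",
    tensor_parallel_size=4,
    trust_remote_code=True,
    enforce_eager=True,
)

# API server: the equivalent CLI flag.
#   vllm serve /mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B \
#       --tensor-parallel-size 4 --enforce-eager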

Sources:

Is my offline pipeline inference script written correctly, especially the prompt template and the batch construction?

import math
import numpy as np
import os
from transformers import AutoTokenizer
import json
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import MediaConnector
import re, copy, random

PREFIX_IMG_PATH='/mnt/afs/smartbrain/zhukangkang/code/eval/data'
connector = MediaConnector(allowed_local_media_path=PREFIX_IMG_PATH)


def load_json_data(path):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data

def save_json_data(save_path, data):
     if not data:
         return
     print("save data to `{}`".format(save_path))
     print(data[0])
     os.makedirs(os.path.dirname(save_path), exist_ok=True)
     with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

def get_all_file_paths(folder_path):
    file_paths = []
    # 支持的图片格式
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')
    
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.lower().endswith(image_extensions):
                file_paths.append(os.path.join(root, file))
    return file_paths


def get_samples_from_images(image_folder_path, max_count=None):
    """直接从图片文件夹获取样本"""
    all_image_paths = get_all_file_paths(image_folder_path)
    print(f"找到 {len(all_image_paths)} 张图片")
    
    result = []
    for img_path in tqdm(all_image_paths):
        if not os.path.exists(img_path):
            continue
        # 获取相对于PREFIX_IMG_PATH的相对路径
        rel_img_path = os.path.relpath(img_path, PREFIX_IMG_PATH)
        sample_data = {
            "img_path": rel_img_path,
            "abs_img_path": img_path,
            "isPos": None,  # 未知,因为没有标签信息
            "scenes": []    # 未知,因为没有场景信息
        }
        result.append(sample_data)
        
        # 如果设置了最大数量限制
        if max_count and len(result) >= max_count:
            break
    
    print("最终处理的图片数量: {}".format(len(result)))
    return result




if __name__ == "__main__":
    # load model
    path = "/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B/"
    temperature = 0.5
    max_tokens = 1024
    GPUS = 4

    batch_size = 500
    save_steps = batch_size*6
    
    # 图片文件夹路径
    image_folder_path = '/mnt/afs/smartbrain/zhukangkang/code/eval/data/eval_chosen'
    max_images = 10000  # 最大处理图片数量,可以根据需要调整
    save_path = '/mnt/afs/smartbrain/zhukangkang/code/eval/query/eval_query.json'


    # load vllm model
    llm = LLM(
        model=path,
        trust_remote_code=True,
        tensor_parallel_size=GPUS,
        # limit_mm_per_prompt={"image": 1},
    )
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]
    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens, stop_token_ids=stop_token_ids)


    # set default prompt
    response_format = {
        "question1":"问题1", 
        "question2":"问题2",
        "question3":"问题3",
    }
    question = '<image>\n'
    question +=f'''假设你是一个新手司机,对公路上的各种情况还不熟悉,目前遇到了如图所示的场景,你可能会提出那些问题,请输出3个问题。
要求:
1、问题尽量口语化一些,像司机在车上可能会问出的问题
2、输出要满足以下格式,确保回复能被python中json.loads解析,不要提供其他的解释与分析: 
{json.dumps(response_format, ensure_ascii=False, indent=4)}
'''
    print(question)

    
    
    
    image_list = get_samples_from_images(image_folder_path, max_images)
    
    previous_file = ""
    save_data_list = []
    for i in tqdm(range(0, len(image_list), batch_size)):
        
        start_index = i
        end_index = min(i+batch_size, len(image_list))
        batch_img = image_list[start_index:end_index]
        valid_data_list, inputs = [], []
        for temp_data in batch_img:
            temp_path = temp_data["abs_img_path"]
            try:
                image = connector.fetch_image("file://" + temp_path)
                messages = [
                    {
                        "role": "user",
                        "content": f"{question}"
                    }
                ]
                prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                inputs.append(
                    {
                        "prompt": prompt,
                        "multi_modal_data": {"image": image},
                    }
                )
                valid_data_list.append(temp_data)
            except Exception as e:
                print(f"Error processing image {temp_path}: {e}")
        if not inputs:
            print(f"No valid images in batch {start_index} to {end_index}. Skipping...")
            continue
        model_outputs = llm.generate(inputs, sampling_params=sampling_params)
        for valid_data, model_output in zip(valid_data_list, model_outputs):
            response = model_output.outputs[0].text
            try:
                match = re.search(r'\{.*?\}', response, re.DOTALL)
                if not match:
                    print("No valid JSON object found.")
                    continue
                response = json.loads(match.group(0))
            except Exception as e:
                print("json load error:{}".format(e))
            new_data = copy.deepcopy(valid_data)
            new_data["queries"] = response
            save_data_list.append(new_data)
        print("*"*50)
        print(f'Image: {valid_data["abs_img_path"]}\nAssistant: {model_output.outputs[0].text}')

        current_used_data_size = len(save_data_list)
        if current_used_data_size==batch_size or (current_used_data_size > 0 and current_used_data_size % save_steps == 0):
            step_save_path = save_path.replace(".json", "_steps_{}.json".format(current_used_data_size))
            save_json_data(step_save_path, save_data_list)
            print(f"Saved {current_used_data_size} items to {step_save_path}")
            if previous_file and os.path.exists(previous_file) and previous_file != step_save_path:
                os.remove(previous_file)
                print(f"Removed previous file: {previous_file}")
            previous_file = step_save_path

    print("total data size:{}".format(len(image_list)))
    print("validity data size:{}".format(len(save_data_list)))
    final_save_path = save_path.replace(".json", "_{}.json".format(len(save_data_list)))
    print("Saving final data to:", final_save_path)
    save_json_data(final_save_path, save_data_list)

        

Your pipeline script is structurally sound overall, but a few details, especially the prompt and the batch construction, can further improve the accuracy and robustness of multimodal inference:


1. Prompt construction

  • What you currently have:

    messages = [
        {
            "role": "user",
            "content": f"{question}"
        }
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    

    Here content is a plain string, which does not follow the multimodal chat_template best practice.

  • Recommended form (for InternVL, Qwen2-VL, and other multimodal models):

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question}
            ]
        }
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    

    Only this form guarantees that the image tokens, text tokens, and special tokens are laid out exactly as the model expects. See the official example.


2. Batch construction

  • Your batch construction is the standard one: inputs is a list whose elements are dicts containing "prompt" and "multi_modal_data".
  • Keep it that way: do not base64-encode images into the prompt string; pass the image object only via multi_modal_data.

3. Other details

  • The SamplingParams and stop_token_ids settings are fine.
  • If you hit over-length input errors, check the total prompt + image token count dynamically and truncate or skip samples that exceed the limit (see the sketch after this list).
  • If the model supports it, you can pass limit_mm_per_prompt={"image": 1} when constructing LLM.

4. Summary

  • As long as the prompt uses the chat_template multimodal format, images are passed via multi_modal_data, and the batch is a plain list, this is exactly the offline multimodal pipeline pattern vLLM recommends.
  • Your script only needs the prompt part changed to the recommended format above.
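
A minimal sketch of the length check and the limit_mm_per_prompt option from point 3, assuming the same model path as your script; the token budget reserved for the expanded image tokens is a rough assumption, not an exact count:

from transformers import AutoTokenizer
from vllm import LLM

model_path = "/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

llm = LLM(
    model=model_path,
    trust_remote_code=True,
    tensor_parallel_size=4,
    limit_mm_per_prompt={"image": 1},  # at most one image per request
)

MAX_LEN = 8192             # keep in sync with the engine's max_model_len
IMAGE_TOKEN_BUDGET = 3072  # rough reserve for the expanded image tokens (assumption)

def fits(prompt: str) -> bool:
    """Return False for samples whose text plus reserved image budget exceeds the context."""
    n_text = len(tokenizer(prompt).input_ids)
    return n_text + IMAGE_TOKEN_BUDGET <= MAX_LEN

# Inside the batch loop:
# if not fits(prompt):
#     continue  # or shorten question/answer and rebuild the prompt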

Do you need a complete code example, or multi-turn / multi-image batch examples?

Sources:

My mental model of batching with a CLI (API-server) deployment:
I send one batch of requests at a time, each request carrying a single query, and my client only starts the next batch after it has received results for the whole current batch.
Because my client is fully synchronous, the finishing time of a whole batch is determined by the weakest link (the slowest request) in that batch.
If I turn all the requests into asynchronous ones, would that dramatically improve efficiency?
Could it then rival the local pipeline?

Your understanding is correct. With a CLI (API-server) deployment, if the client sends synchronous batches (every request in a batch must return before the next batch is sent), the wall-clock time of each batch is dragged down by its slowest request; that is the weakest-link effect you describe. Switching to asynchronous, concurrent requests (e.g. asyncio or a thread pool) raises throughput and utilization substantially, because each request returns independently and the overall wait is roughly the latency of the slowest in-flight request rather than the sum of serialized waits.

Even with async concurrency, however, the API server's peak efficiency stays below the local pipeline, because the server adds network, serialization, and inter-process-communication overhead that keeps it from matching pipeline throughput. The official issue discussion is clear that the pipeline mode is always faster, but async concurrency lets the API service get close, especially under high concurrency.

Do you need an async-concurrency code example or performance comparison numbers?
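
For reference, a minimal sketch of the async approach against a vLLM OpenAI-compatible server; the base_url, model name, and concurrency cap are illustrative assumptions:

import asyncio
from openai import AsyncOpenAI

# Assumes `vllm serve ...` is already running and exposing the OpenAI-compatible API.
client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
sem = asyncio.Semaphore(64)  # cap the number of in-flight requests

async def ask(image_url: str, text: str) -> str:
    async with sem:
        resp = await client.chat.completions.create(
            model="OpenGVLab/InternVL3-78B",
            messages=[{
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": text},
                ],
            }],
            temperature=0.0,
            max_tokens=512,
        )
        return resp.choices[0].message.content

async def run_all(samples):
    # All requests are in flight at once, so total time is roughly the slowest
    # request, not the sum of per-request latencies.
    return await asyncio.gather(*(ask(s["image_url"], s["question"]) for s in samples))

# results = asyncio.run(run_all(samples))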

Sources:

Is there anything else I can do to speed up my script, for example engine parameters, the prompt template, or the way images are passed in?
My script:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Stage-3 Valid – InternVL-78B 本地 vLLM 批推理版本
=================================================
相比原来的 client/HTTP 版本,本脚本直接在本地通过 vLLM 引擎调用 `LLM.generate()`,
可显著提升推理吞吐。

核心流程保持不变:
1. 逐条构建提示词 (宽松评估标准)
2. 批量调用 vLLM 生成判断
3. 解析判断 → 置信度 / 正确与否
4. 保存 valid / invalid / report

仅保留最关键功能,去掉远程 API / checkpoint,保持脚本精简易跑。
"""

from __future__ import annotations

import os
import sys
import argparse
import logging
import json
import time
from dataclasses import dataclass
from typing import List, Dict, Tuple

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, EngineArgs  # type: ignore

# ---------------------------------------------------------------------------
# 共用结构 – 直接在本文件中重新实现所需工具
# ---------------------------------------------------------------------------

import re
import copy
import io
import requests
from PIL import Image  # type: ignore
from tqdm import tqdm


@dataclass
class ValidationResult:
    """单条验证输出"""

    judgment: str
    confidence: float
    is_correct: bool
    raw_output: str
    reasoning: str


class JudgmentParser:
    """解析模型自然语言输出,得到宽松判定"""

    correct_indicators = {
        "正确",
        "基本正确",
        "大致正确",
        "部分正确",
        "大部分正确",
        "总体正确",
        "符合要求",
        "accurate",
        "mostly correct",
        "partially correct",
        "true",
        "yes",
    }

    @staticmethod
    def parse_judgment(resp: str) -> Tuple[bool, float, str, str]:
        if not resp:
            return False, 0.0, "无响应", ""

        clean = JudgmentParser._clean(resp)
        judgment = JudgmentParser._extract_judgment(clean)
        confidence = JudgmentParser._extract_confidence(clean)
        reasoning = JudgmentParser._extract_reasoning(clean)
        is_correct = JudgmentParser._is_correct(judgment)
        return is_correct, confidence, judgment, reasoning

    @staticmethod
    def _clean(text: str) -> str:
        # 移除特殊标记和vLLM日志污染
        text = re.sub(r"<\|.*?\|>", "", text)
        # 移除vLLM性能统计信息
        text = re.sub(r'est\. speed input:.*?toks/s.*?output:.*?toks/s.*?\]', '', text, flags=re.DOTALL)
        text = re.sub(r'\[[^\]]*toks/s[^\]]*\]', '', text, flags=re.DOTALL)
        text = re.sub(r'\d+%\|[█▉▊▋▌▍▎▏]*\|.*?\[.*?\]', '', text)
        return re.sub(r"\s+", " ", text).strip()

    @staticmethod
    def _extract_judgment(text: str) -> str:
        """改进的判断提取,避免过早截断"""
        # 方法1:先尝试提取"判断:"到下一个字段之间的完整内容
        patterns = [
            # 匹配"判断:"到"置信度:"之间的内容(贪婪匹配,允许换行)
            r"判断[::]\s*(.+?)(?=\s*置信度[::])",
            # 匹配"判断:"到"理由:"之间的内容
            r"判断[::]\s*(.+?)(?=\s*理由[::])",
            # 匹配"判断:"到字符串结尾的内容(如果没有其他字段)
            r"判断[::]\s*(.+?)$",
            # 兜底:匹配"判断:"后面的一行内容
            r"判断[::]\s*([^\n\r]+)",
        ]
        
        for pattern in patterns:
            m = re.search(pattern, text, re.I | re.S | re.M)
            if m:
                result = m.group(1).strip()
                # 清理可能的尾部干扰内容
                result = re.sub(r'\s*(置信度|理由)[::].*$', '', result, flags=re.I | re.S)
                if result:
                    return result
        
        # 方法2:如果没有找到"判断:",尝试从文本开头提取判断性内容
        # 寻找可能的判断词汇
        judgment_words = ["正确", "错误", "基本正确", "部分正确", "大致正确", "完全正确", "不正确"]
        lines = text.split('\n')
        for line in lines:
            line = line.strip()
            if any(word in line for word in judgment_words) and len(line) < 200:  # 避免太长的行
                return line
        
        # 方法3:兜底返回原文本的前100个字符
        return (text[:100] + "...") if len(text) > 100 else text

    @staticmethod
    def _extract_confidence(text: str) -> float:
        """改进的置信度提取"""
        # 尝试多种模式匹配置信度
        patterns = [
            r"置信度[::]\s*(\d*\.?\d+)",
            r"confidence[::]?\s*(\d*\.?\d+)",
            r"信心[::]\s*(\d*\.?\d+)",
            r"(\d+\.?\d*)\s*[%%]",  # 百分比形式
        ]
        
        for pattern in patterns:
            m = re.search(pattern, text, re.I)
            if m:
                try:
                    val = float(m.group(1))
                    # 如果是百分比形式,转换为小数
                    if "%" in m.group(0) or "%" in m.group(0):
                        val = val / 100.0
                    return min(val, 1.0)
                except Exception:
                    continue
        
        # 粗略估计(保持原逻辑)
        lower = text.lower()
        if any(k in lower for k in ("完全正确", "absolutely", "显然", "确定")):
            return 0.95
        if any(k in lower for k in ("部分", "partially", "基本", "大致")):
            return 0.75
        if any(k in lower for k in ("错误", "incorrect", "wrong", "不正确")):
            return 0.8  # 错误判断也可能有较高置信度
        return 0.7

    @staticmethod
    def _extract_reasoning(text: str) -> str:
        """改进的理由提取"""
        patterns = [
            # 匹配"理由:"到文本结尾的内容
            r"理由[::]\s*(.+)$",
            # 匹配"原因:"
            r"原因[::]\s*(.+)$",
            # 匹配"because"
            r"because\s*[::]?\s*(.+)$",
            # 匹配"解释:"
            r"解释[::]\s*(.+)$",
        ]
        
        for pattern in patterns:
            m = re.search(pattern, text, re.I | re.S)
            if m:
                result = m.group(1).strip()
                if result:
                    return result
        
        # 如果没有找到明确的理由字段,尝试从文本中找到解释性内容
        lines = text.split('\n')
        for line in lines:
            line = line.strip()
            # 寻找包含解释性词汇的行
            if any(word in line.lower() for word in ["因为", "由于", "所以", "because", "since", "as"]) and len(line) > 10:
                return line
        
        return "无详细理由"

    @staticmethod
    def _is_correct(judgment: str) -> bool:
        j = judgment.lower()
        if any(ind in j for ind in JudgmentParser.correct_indicators):
            return not any(neg in j for neg in ("不正确", "错误", "incorrect", "wrong", "false"))
        return False


# ---------------- 数据集 I/O ----------------


def load_dataset(path: str) -> List[Dict]:
    logger.info(f"加载数据集: {path}")
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        logger.error(f"读取失败: {e}")
        sys.exit(1)

    if not isinstance(data, list):
        logger.error("顶层应为列表")
        sys.exit(1)

    processed: List[Dict] = []
    for item in data:
        if not isinstance(item, dict):
            continue

        # 兼容字段
        image = item.get("image") or item.get("image_path")
        q = item.get("question") or item.get("problem")
        a = item.get("answer") or item.get("gt_answer")

        if not (image and q and a):
            continue

        processed.append({"image": image, "question": q, "answer": a, **item})

    logger.info(f"标准化后样本数: {len(processed)}/{len(data)}")
    return processed


def save_json(path: str, data: List[Dict]):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    logger.info(f"保存 {len(data)} 条 → {path}")


def clean_validation_data(data: List[Dict]) -> List[Dict]:
    return [{k: v for k, v in item.items() if k != "gt_valid"} for item in data]


def generate_report(correct: List[Dict], incorrect: List[Dict], out_dir: str, fname: str, start_time: float, api: str):
    total = len(correct) + len(incorrect)
    duration = time.time() - start_time
    
    # 避免除零错误
    if total > 0:
        correct_pct = len(correct) / total * 100
        incorrect_pct = len(incorrect) / total * 100
    else:
        correct_pct = 0.0
        incorrect_pct = 0.0
    
    report = (
        f"Stage-3 Valid 报告\n====================\n"
        f"文件: {fname}\n总样本: {total}\n"
        f"正确: {len(correct)}  ({correct_pct:.2f}%)\n"
        f"错误: {len(incorrect)} ({incorrect_pct:.2f}%)\n"
        f"耗时: {duration:.2f}s\n"
        f"模式: 本地 vLLM pipeline\n"
    )
    path = os.path.join(out_dir, f"{fname}_report.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(report)
    logger.info(f"报告已保存 → {path}")


logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# ---------------------------------------------------------------------------
# 配置
# ---------------------------------------------------------------------------
@dataclass
class ValidationConfig:
    model_path: str = "/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B"
    batch_size: int = 256
    temperature: float = 0.0
    top_p: float = 1.
    tensor_parallel_size: int = 8       # 默认使用 8 张 GPU 做张量并行

# ---------------------------------------------------------------------------
# 本地 vLLM 推理器
# ---------------------------------------------------------------------------
class InternVLPipelineInference:
    """本地 vLLM 批推理封装。"""

    def __init__(self, cfg: ValidationConfig):
        self.cfg = cfg

        logger.info("初始化 tokenizer …")
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, trust_remote_code=True, use_fast=False)

        logger.info("初始化 vLLM 引擎 … (耗时 1-3 分钟)")
        # 兼容旧配置,按需动态加入可选参数
        # 自动检测可用GPU数量
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if visible_devices:
            # 从 CUDA_VISIBLE_DEVICES 中计算可用GPU数量
            gpu_count = len([x.strip() for x in visible_devices.split(",") if x.strip()])
            tp_size = gpu_count
            logger.info(f"检测到 CUDA_VISIBLE_DEVICES={visible_devices},自动设置 tensor_parallel_size={tp_size}")
        else:
            # 使用配置中的默认值或环境变量 WORLD_SIZE
            tp_size = int(os.environ.get("WORLD_SIZE", cfg.tensor_parallel_size))
            logger.info(f"使用配置的 tensor_parallel_size={tp_size}")

        eng_kwargs = dict(
            model=cfg.model_path,
            tensor_parallel_size=max(1, tp_size),
            trust_remote_code=True,
            dtype="half",  # 使用半精度提升显存利用率
            enable_chunked_prefill=True,  # 启用 chunked prefill
            enable_prefix_caching=True,   # 启用 prefix caching
        )

        if hasattr(cfg, "max_num_batched_tokens"):
            eng_kwargs["max_num_batched_tokens"] = cfg.max_num_batched_tokens

        eng_args = EngineArgs(**eng_kwargs)
        self.llm = LLM(**eng_args.__dict__)

        # 按照 get_query.py 的方式设置 stop_tokens
        stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
        stop_token_ids = [self.tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
        stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]

        self.sp = SamplingParams(
            temperature=cfg.temperature,
            top_p=cfg.top_p,
            max_tokens=1024,  # 添加最大输出长度限制,与vadlid版本保持一致
            stop_token_ids=stop_token_ids,
        )

    # -------------------- prompt 构建 --------------------
    def _build_prompt(self, sample: Dict) -> str:
        question = sample.get("question", "")
        answer = sample.get("answer", "")

        # 指令部分
        instr = (
            "你是一个严谨的、注重事实的多模态问答评估员。请基于图片内容,判断\"待判断答案\"是否准确地回答了\"问题\"。\n\n"
            "**要求:**\n"
            "1. **事实为先**: 你的判断必须严格基于图片内容,不能有任何想象或推断。\n"
            "2. **宽松标准**: '完全正确'、'基本正确'、'大致正确'、'部分正确'都视为【正确】。只有'完全错误'、'严重错误'、'明显不符'才视为【错误】。\n"
            "3. **输出格式必须如下**:\n\n"
            "判断:[正确/基本正确/部分正确/错误]\n"
            "置信度:[0.0-1.0]\n"
            "理由:[简要说明判断依据,指出答案中正确或错误的关键点]\n\n"
            "--- 以下是待评估内容 ---\n"
            f"问题:{question}\n待判断答案:{answer}"
        )

        # 按照 get_query.py 的方式构建 messages 和 prompt
        messages = [
            {
                "role": "user",
                "content": f"<image>\n{instr}"
            }
        ]

        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # -------------------- 核心批推理 --------------------
    def validate_batch(self, batch_data: List[Dict]) -> Tuple[List[ValidationResult], List[Dict]]:
        valid_data_list = []
        inputs = []
        
        for sample in batch_data:
            prompt = self._build_prompt(sample)
            img_path = sample.get("image", "")
            
            if not img_path:
                logger.warning("未提供图像路径,跳过该样本")
                continue

            try:
                # 同时支持本地路径和URL
                if img_path.startswith("http"):
                    # HTTP/HTTPS URL
                    resp = requests.get(img_path, timeout=15)
                    resp.raise_for_status()
                    image = Image.open(io.BytesIO(resp.content)).convert("RGB")
                else:
                    # 本地文件路径
                    image = Image.open(img_path).convert("RGB")
                
                inputs.append({
                    "prompt": prompt,
                    "multi_modal_data": {"image": image},
                })
                valid_data_list.append(sample)
                
            except Exception as e:
                logger.warning(f"加载图片失败: {img_path} - {e}")
                continue

        if not inputs:
            logger.warning("批次中没有有效图片,返回空结果")
            return [], []

        # 批量推理
        outputs = self.llm.generate(inputs, self.sp)

        results: List[ValidationResult] = []
        for valid_data, out in zip(valid_data_list, outputs):
            raw = out.outputs[0].text.strip()
            is_corr, conf, judgment, reasoning = JudgmentParser.parse_judgment(raw)
            results.append(
                ValidationResult(
                    judgment=judgment,
                    confidence=conf,
                    is_correct=is_corr,
                    raw_output=raw,
                    reasoning=reasoning,
                )
            )
        
        return results, valid_data_list

# ---------------------------------------------------------------------------
# 入口逻辑 (简化,无 checkpoint)
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser("Stage-3 Valid – 本地 vLLM Pipeline 版")
    p.add_argument("--input", required=True, help="输入 JSON 文件")
    p.add_argument(
        "--output",
        default="/mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage1_cls/output_stage1_pipeline",
        help="输出根目录 (默认: /mnt/afs/smartbrain/zhukangkang/code/common_data_pipeline/stage1_cls/output_stage1_pipeline)",
    )
    p.add_argument("--model", default="/mnt/afs/share_models/git_models/OpenGVLab/InternVL3-78B", help="模型路径")
    p.add_argument("--batch-size", type=int, help="批推理大小 (默认: 512)")
    p.add_argument("--tp-size", type=int, default=8, help="Tensor parallel size (GPU 卡数,默认 8)")
    p.add_argument("--debug", action="store_true")
    return p.parse_args()


def main():
    args = parse_args()
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    cfg = ValidationConfig(
        model_path=args.model,
        tensor_parallel_size=args.tp_size,
    )
    
    # 如果命令行指定了batch_size,则覆盖默认值
    if args.batch_size is not None:
        cfg.batch_size = args.batch_size

    inference = InternVLPipelineInference(cfg)

    data = load_dataset(args.input)
    if not data:
        logger.error("数据为空,退出")
        sys.exit(1)

    # 输出目录遵循旧脚本逻辑: <output_root>/<input_filename>/
    input_fname = os.path.splitext(os.path.basename(args.input))[0]
    run_out_dir = os.path.join(args.output, input_fname)
    os.makedirs(run_out_dir, exist_ok=True)

    correct: List[Dict] = []
    incorrect: List[Dict] = []

    total = len(data)
    logger.info(f"共有 {total} 条样本,batch_size={cfg.batch_size}")

    start_time = time.time()
    
    for i in tqdm(range(0, total, cfg.batch_size)):
        batch = data[i : i + cfg.batch_size]
        logger.info(f"处理 {i+1}‒{min(i+cfg.batch_size, total)} …")
        
        try:
            results, valid_samples = inference.validate_batch(batch)
            
            for sample, res in zip(valid_samples, results):
                sample["gt_valid"] = {
                    "validation": {
                        "model": "InternVL-78B-pipeline",
                        "judgment": res.judgment,
                        "confidence": res.confidence,
                        "reasoning": res.reasoning,
                        "is_correct": res.is_correct,
                        "raw_output": res.raw_output,
                    }
                }
                (correct if res.is_correct else incorrect).append(sample)
                
        except Exception as e:
            logger.error(f"批次处理失败: {e}")
            continue

    # 保存,与老脚本命名保持一致
    valid_path = os.path.join(run_out_dir, f"{input_fname}_valid.json")
    invalid_path = os.path.join(run_out_dir, f"{input_fname}_invalid.json")

    save_json(valid_path, clean_validation_data(correct))
    save_json(invalid_path, incorrect)

    # 报告
    generate_report(correct, incorrect, run_out_dir, input_fname, start_time, "local-pipeline")

    logger.info("✅ 全部完成!")


if __name__ == "__main__":
    main()