推理 Qwen3-VL-235B-A22B-Instruct-FP8 时精度有问题

我当前的推理脚本用于推理 Qwen3-VL-235B-A22B-Instruct-FP8,但是发现效果比较差,精度有问题,会有描述错误。我不知道我的脚本有没有问题,特别是在推理代码的构造方式上面。

import os
import pdb
import time
import uuid
import queue
import random
import multiprocessing as mp
from functools import partial

# import numpy as np
from PIL import Image

import json_repair
import imageio.v3 as iio
from vllm import LLM, SamplingParams
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info

from src.utils.logger import logger
from src.utils.io_utils import get_presigned_url

# Exported at import time so that child processes started afterwards inherit it.
os.environ.update(VLLM_ENABLE_MMAP_CACHE="1")


class VideoLoader:
    """Decode videos in parallel with a ``multiprocessing`` pool.

    Call ``close()`` when done, or use the loader as a context manager.
    """

    def __init__(self, num_workers=None):
        # A falsy num_workers (None or 0) means one worker per CPU.
        self.num_workers = num_workers or mp.cpu_count()
        self.pool = mp.Pool(processes=self.num_workers)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def close(self):
        """Stop the pool and wait for its workers to exit.

        Failures are logged rather than raised, because this can run from an
        owner's ``__del__``.
        """
        try:
            self.pool.close()
            self.pool.join()
        except Exception as e:
            logger.warning('close video loader pool failed: {}'.format(e))

    def load_videos(self, func, video_paths, sample_frequency=1):
        """Apply ``func(path, sample_frequency=...)`` to every entry in parallel.

        Args:
            func: callable sent to the pool workers, so it must be picklable
                (e.g. a module-level function).
            video_paths: list or tuple of per-video arguments for ``func``.
            sample_frequency: forwarded to ``func`` as a keyword argument.

        Returns:
            The ``func`` results, in the same order as ``video_paths``.

        Raises:
            ValueError: if ``video_paths`` is not a list or tuple.
        """
        if not isinstance(video_paths, (list, tuple)):
            raise ValueError("video_paths must be a list or tuple")

        t0 = time.time()
        worker_func = partial(func, sample_frequency=sample_frequency)
        results = self.pool.map(worker_func, video_paths)
        logger.info('load {} number video cost time:{}s'.format(len(video_paths), time.time() - t0))
        return results

def load_video(video_info, sample_frequency=1):
    """Decode the frames of one clip at roughly ``sample_frequency`` frames per second.

    This runs inside VideoLoader's pool workers, so problems are reported via
    the log and an empty list instead of an exception (an exception here would
    abort the whole ``pool.map`` batch).

    Args:
        video_info: dict with ``url`` and inclusive ``start_frame`` /
            ``end_frame`` indices into the decoded frame sequence.
        sample_frequency: target number of frames kept per second; must be > 0.

    Returns:
        List of RGB ``PIL.Image`` frames, or ``[]`` if the video cannot be read,
        the frame range is out of bounds, or no frame is selected.
    """
    url = video_info['url']
    rel_start = video_info['start_frame']
    rel_end = video_info['end_frame']

    if sample_frequency <= 0:
        logger.warning('invalid sample_frequency:{}, url:{}'.format(sample_frequency, url))
        return []

    try:
        # Use the handle as a context manager so the container gets closed.
        with iio.imopen(uri=url, io_mode="r", plugin="pyav") as video_file:
            meta = video_file.metadata()
    except Exception as e:
        logger.error(f"[DECODE ERROR] {url}: {e}")
        return []

    # `or` also covers keys that are present but None/0.
    fps = meta.get("fps") or 30
    duration = meta.get("duration") or 1
    frames_num = fps * duration
    if rel_end >= frames_num:
        logger.warning('rel_end:{} >= frames_num: {}, url:{}'.format(rel_end, frames_num, url))
        return []

    # Clamp to 1: when sample_frequency > fps, int() yields 0 and range() raises.
    step = max(1, int(fps / sample_frequency))
    wanted = set(range(rel_start, rel_end + 1, step))  # set: O(1) membership per frame
    if not wanted:
        logger.warning('get_frame_index is empty, url:{}'.format(url))
        return []
    last_index = max(wanted)

    video = []
    reader = iio.imiter(uri=url, plugin="pyav")
    try:
        for idx, frame in enumerate(reader):
            if idx in wanted:
                video.append(Image.fromarray(frame).convert('RGB'))
            if idx >= last_index:
                break
    except Exception as e:
        logger.error(f"[iio.imiter ERROR] {url}: {e}")
        return []
    finally:
        # Closing the generator ends its iteration now (releasing the decoder)
        # instead of waiting for garbage collection after the early break.
        reader.close()
    return video


class ModelWorker():
    """Run the vLLM engine in its own process and serve batched requests.

    Each item on ``req_queue`` is ``(uuid_value, run_types, batch_items, inputs)``
    where ``inputs`` holds ``len(batch_items) * len(run_types)`` prompts ordered
    video-major; a ``None`` item stops the loop.  Exactly one dict is put on
    ``res_queue`` per request: ``status`` 0 with ``model_output`` on success,
    ``status`` -1 on failure.
    """
    # Executed once at class-definition (import) time; presumably selects the
    # start method vLLM uses for its own worker processes.
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

    def __init__(self, cfg, req_queue: mp.Queue, res_queue: mp.Queue, init_event: mp.Event) -> None:
        self.cfg = cfg
        self.req_queue = req_queue
        self.res_queue = res_queue
        # Set by build_model(); None means no engine is available.
        self.llm = None
        self.sampling_params = SamplingParams(
            temperature=0,         # greedy decoding: deterministic, no sampling randomness
            max_tokens=1024,
            stop_token_ids=[151645, 151643],  # <|im_end|> and <|endoftext|>
        )
        self.init_event = init_event

    def build_model(self, model_path):
        """Create the vLLM engine for ``model_path`` and store it on ``self.llm``."""
        logger.info(f"Loading model from {model_path}")
        self.llm = None
        t0 = time.time()
        self.llm = LLM(
            model=model_path,
            trust_remote_code=True,
            gpu_memory_utilization=0.9,
            enforce_eager=False,
            enable_expert_parallel=False,
            tensor_parallel_size=4,
            seed=0,
            max_num_seqs=256,
            # NOTE(review): video tokens + prompt + reply must fit in this
            # budget; long or high-resolution clips may exceed 4096 tokens.
            max_model_len=4096,
            max_num_batched_tokens=23000
        )
        logger.info('model init cost time:{}'.format(time.time() - t0))

    def get_text_from_outputs(self, outputs: list, run_types: list):
        """Regroup flat vLLM outputs into one dict per video.

        ``outputs`` must follow the prompt order: for each video, one output per
        entry of ``run_types``.  Each reply is parsed with ``json_repair``: a
        dict with a ``caption`` key yields that caption, any other dict is kept
        whole, and anything else becomes ``''``.

        Returns:
            List of ``{run_type: parsed_reply}`` dicts, or ``[]`` when
            ``run_types`` is empty or ``len(outputs)`` is not a multiple of it.
        """
        step = len(run_types)
        if step == 0 or len(outputs) % step != 0:
            logger.error('outputs_size: {}, run_types_size: {}'.format(len(outputs), step))
            return []

        results = []
        for start in range(0, len(outputs), step):
            rslt = {}
            for run_type, output in zip(run_types, outputs[start:start + step]):
                completion = output.outputs[0]
                if getattr(completion, 'finish_reason', None) == 'length':
                    # Hit max_tokens: the JSON is probably cut off and
                    # json_repair can only salvage part of it.
                    logger.warning('output truncated by max_tokens, run_type:{}'.format(run_type))
                rslt[run_type] = self._parse_reply(completion.text)
            results.append(rslt)
        return results

    @staticmethod
    def _parse_reply(generated_text):
        """Parse one reply (caption string, whole dict, or '' as described above)."""
        try:
            response = json_repair.loads(generated_text)
        except Exception as e:
            logger.error(f"Error: {str(e)}")
            return ''
        if isinstance(response, dict):
            return response["caption"] if 'caption' in response else response
        return ''

    def _reply_failure(self, uuid_value, batch_items):
        """Report a failed request so the parent stops waiting for it."""
        self.res_queue.put({"status": -1, "uuid": uuid_value, "batch_items": batch_items})

    def run(self):
        """Load the model, signal readiness, then serve requests until ``None`` arrives."""
        logger.info("Model process start: loading model...")
        self.build_model(self.cfg.model_path)
        self.init_event.set()
        logger.info('init model: {} success'.format(self.cfg.model_path))
        while True:
            item = self.req_queue.get()
            if item is None:
                logger.info("Model worker process exit")
                break

            uuid_value, run_types, batch_items, inputs = item
            logger.info(f"[Model] infer job {uuid_value}")
            if self.llm is None:
                self._reply_failure(uuid_value, batch_items)
                continue

            t0 = time.time()
            try:
                outputs = self.llm.generate(inputs, sampling_params=self.sampling_params)
            except Exception as e:
                # One bad batch (e.g. a prompt longer than max_model_len) must not
                # kill this process: the parent would wait until its timeout and
                # every request still queued would be lost.
                logger.error('llm generate failed, uuid:{}, error:{}'.format(uuid_value, e))
                self._reply_failure(uuid_value, batch_items)
                continue
            logger.info('llm generate inputs size:{} cost time:{}s'.format(len(inputs), time.time() - t0))

            if len(outputs) != (len(run_types) * len(batch_items)):
                logger.error('generate outputs failed, outputs size:{}, batch_data_items:{}'.format(
                    len(outputs), len(batch_items)))
                self._reply_failure(uuid_value, batch_items)
                continue

            results = self.get_text_from_outputs(outputs, run_types)
            if len(results) != len(batch_items):
                logger.error('get text from outputs failed, results size:{} != batch_items:{}'.format(
                    len(results), len(batch_items)))
                self._reply_failure(uuid_value, batch_items)
                continue

            for result, batch_item in zip(results, batch_items):
                result['url'] = batch_item['url']
            self.res_queue.put({
                'status': 0,
                'uuid': uuid_value,
                'batch_items': batch_items,
                'run_types': run_types,
                'model_output': results
            })

def start_model_worker(cfg, input_queue: mp.Queue, output_queue: mp.Queue, init_event: mp.Event):
    """Process entry point: build a ModelWorker on the given queues and run its serve loop."""
    ModelWorker(cfg, input_queue, output_queue, init_event).run()


class Caption():
    """Front end of the video captioning pipeline.

    Frames are decoded in a VideoLoader process pool, turned into vLLM inputs
    in this process, and sent to a ModelWorker process through ``input_queue``;
    parsed results come back on ``output_queue``.
    """

    # Frames per second kept when decoding (passed to load_video as
    # sample_frequency).  It is also sent to qwen_vl_utils as the video's
    # ``sample_fps``: for a list of frames the library cannot infer the real
    # sampling rate and otherwise uses its own default (2.0 in the versions
    # checked -- verify for yours), which gives the model wrong timestamps.
    VIDEO_SAMPLE_FPS = 1

    def __init__(self, cfg, worker_num=20, cos_client=None,
            tos_client=None, aoss_client=None):
        # Defined first so __del__ stays safe if anything below raises.
        self.video_loader = None
        self.IMAGENET_MEAN = cfg.mean
        self.IMAGENET_STD = cfg.std
        self.processor = AutoProcessor.from_pretrained(cfg.model_path)
        self.processor.tokenizer.padding_side = "left"

        self.tos_client = tos_client
        self.aoss_client = aoss_client
        self.cos_client = cos_client

        self.video_loader = VideoLoader(worker_num)
        self.input_queue = mp.Queue(maxsize=100)
        self.output_queue = mp.Queue(maxsize=100)

        # The vLLM engine lives in its own process, which sets this event once loaded.
        self.model_init_event = mp.Event()
        self.model_process = mp.Process(
            target=start_model_worker,
            args=(cfg, self.input_queue, self.output_queue, self.model_init_event),
        )
        self.model_process.start()

        self.prompt_dict = dict(
           caption =
            """
You are an expert video captioning model. Your task is to generate a highly detailed and accurate caption for the given video.

Requirements:
1. Describe everything you observe in the video with rich details - be specific about colors, shapes, sizes, positions, movements, and any visible text or patterns.
2. Capture the temporal flow: describe what happens from beginning to end in chronological order.
3. Be precise and factual: only describe what is clearly visible, avoid guessing or assumptions.
4. Write as a single flowing paragraph.
5. Output in JSON format with one field: "caption".

Example output:
{"caption": "A woman with long brown hair wearing a light pink apron over a white t-shirt stands at a kitchen counter with white marble surface. She picks up a red bell pepper from a ceramic bowl, places it on a wooden cutting board, and slices it into thin strips using a chef's knife. Sunlight from the window illuminates the scene, casting soft shadows across the counter. A small potted plant sits on the windowsill behind her."}
""",

            tag=
            """
            Analyze the input video to complete two tasks: 1) Classify it into one primary category from the following options: human/ego/robot/physics/unsafe/general; 2) Generate specific, non-vague sub-tags aligned with the selected category’s guidelines.
            Category Definitions:
            "human": Third-person perspective of humans ; sub-tags keys include: actions, occupation, gender, age, context.
            "ego": First-person perspective; sub-tags keys include: scenes, environments, interaction behaviors.
            "robot": Depicts physical, tangible robots interacting with the real world; sub-tags keys include: scenes, actions.
            "physics": Clearly demonstrates observable physical laws or phenomena; sub-tags keys include: physics principles (select from or align with the following: Mechanical Motion, Fluid Properties, Energy Conversion, Optical Phenomena, Spatial Relationships, Material Deformation, Electromagnetic Effects, Thermodynamic Changes, Causal Linkage, Flexible Motion, Microscopic Motion, Macroscopic Motion, Cosmic Physics) 
            "unsafe": Contains violence, pornography, or politically sensitive content; sub-tags keys include: specific type.
            "general": No fit for the above categories; sub-tags keys include: content type (choose one from game, animation, movie, short video, others), scene.
            Output Format (strict): {"primary_category": "[Selected Category]", "sub_tags": {"sub-tag1_key": "sub-tag1_value", "sub-tag2_key": "sub-tag2_value", ...}}
            Example output: {"primary_category": "general", "sub_tags": {"content_type": "animation"}}
            """
            )

    def __del__(self):
        """Close the frame-decoding pool when the object is collected."""
        video_loader = getattr(self, 'video_loader', None)
        if video_loader is not None:
            video_loader.close()

    def generate_messages(self, batch_items: list, video_imgs:list, run_types:list):
        """Build one chat message per (decodable video, run_type) pair.

        Messages are ordered video-major (all run_types for the first valid
        video, then the next), which is the order
        ModelWorker.get_text_from_outputs regroups by.

        Returns:
            ``(messages, invalid_urls)``; both are empty when the two input
            lists differ in length.
        """
        if len(batch_items) != len(video_imgs):
            logger.error('batch_items size:{} != video_imgs size:{}'.format(len(batch_items), len(video_imgs)))
            return [], []  # callers unpack two values
        messages = []
        invalid_urls = []
        for batch_item, video in zip(batch_items, video_imgs):
            if len(video) == 0:
                logger.error('generate messages video path: {} is error'.format(batch_item['url']))
                invalid_urls.append(batch_item['url'])
                continue

            for run_type in run_types:
                if run_type in ('caption', 'tag'):
                    text = self.prompt_dict[run_type]
                else:
                    # Still emit a message so the output count stays aligned.
                    logger.error('Error run_type:{} vaule is error'.format(run_type))
                    text = ''

                messages.append([
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "video",
                                "video": video,
                                # Real sampling rate of `video`, used for frame timestamps.
                                "sample_fps": self.VIDEO_SAMPLE_FPS,
                                "min_pixels": 64 * 32 * 32,
                                "max_pixels": 256 * 32 * 32,
                                "total_pixels": 10240 * 32 * 32
                            },
                            {"type": "text", "text": text},
                        ],
                    }])
        return messages, invalid_urls

    def prepare_inputs_for_vllm(self, messages):
        """Turn one chat message into the prompt dict passed to ``LLM.generate``."""
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Requires qwen_vl_utils 0.0.14+.
        image_inputs, video_inputs, video_kwargs = process_vision_info(
            messages,
            image_patch_size=self.processor.image_processor.patch_size,
            return_video_kwargs=True,
            return_video_metadata=True
        )

        mm_data = {}
        if image_inputs is not None:
            mm_data['image'] = image_inputs
        if video_inputs is not None:
            mm_data['video'] = video_inputs

        return {
            'prompt': text,
            'multi_modal_data': mm_data,
            'mm_processor_kwargs': video_kwargs
        }

    def get_extract_img_from_video(self, batch_items:list):
        """Decode the frame range of every item; returns one frame list per item ([] on failure)."""
        video_paths = []
        for batch_item in batch_items:
            url = batch_item['url']
            # NOTE(review): plain substring test, so a local path such as
            # ".../photos/x.mp4" also matches 'tos'; consider checking the URL scheme.
            if 's3' in url or 'tos' in url or 'cos' in url:
                url = get_presigned_url(url, tos_client=self.tos_client,
                    aoss_client=self.aoss_client, cos_client=self.cos_client)
            video_paths.append({
                'url': url,
                'start_frame': batch_item['start_frame'],
                'end_frame': batch_item['end_frame'],
            })
        return self.video_loader.load_videos(load_video, video_paths,
                                             sample_frequency=self.VIDEO_SAMPLE_FPS)

    def run_single(self, batch_items, uuid_value, run_types) -> bool:
        """Decode one batch, build its vLLM inputs and enqueue them for the model worker.

        Items whose frames could not be decoded are dropped.  Returns False if
        no input could be built (nothing is enqueued), True otherwise.
        """
        if not isinstance(batch_items, list):
            batch_items = [batch_items]
        logger.info('text_labeling run_types:{}, batch_items size:{}'.format(run_types, len(batch_items)))

        video_imgs = self.get_extract_img_from_video(batch_items)
        messages, invalid_urls = self.generate_messages(batch_items, video_imgs, run_types)
        if len(messages) == 0:
            logger.error('generate_messages failed')
            return False
        logger.info('invalid_urls: {}'.format(invalid_urls))

        # Select by position, with the same rule generate_messages uses to skip
        # items.  Matching by URL dropped a valid item whose URL also belonged to
        # a failed one, and the worker then rejected the whole batch.
        valid_batch_items = [item for item, frames in zip(batch_items, video_imgs) if len(frames) > 0]
        t0 = time.time()
        inputs = [self.prepare_inputs_for_vllm(message) for message in messages]
        logger.info('prepare_inputs_for_vllm messages size:{} cost time:{}s'.format(len(messages), time.time() - t0))
        self.input_queue.put((uuid_value, run_types, valid_batch_items, inputs))
        logger.info("input_queue put uuid_value:{}".format(uuid_value))
        return True

    def run(self, data_items:list, batch_size:int, run_types=None):
        """Process ``data_items`` in batches of ``batch_size`` and collect the results.

        ``run_types`` defaults to ``['caption', 'tag']``.  Returns the
        ``model_output`` entries of every successful batch; failed batches are
        logged and skipped.  The model worker is told to exit before returning,
        so an instance can only run once.

        Raises:
            RuntimeError: if the model worker does not signal readiness in time.
        """
        if run_types is None:
            # Fresh list per call instead of a shared mutable default.
            run_types = ['caption', 'tag']
        if (len(data_items) == 0) or (batch_size <= 0):
            logger.error('data_items size:{} or batch_size:{} is error'.format(len(data_items), batch_size))
            return []

        pending = []
        for start in range(0, len(data_items), batch_size):
            uuid_value = str(uuid.uuid4())
            batch_items = data_items[start:start + batch_size]
            print('batch_items size:{}'.format(len(batch_items)))
            if not self.run_single(batch_items, uuid_value, run_types):
                logger.error('Error run batch_items:{} failed'.format(batch_items))
                continue
            pending.append(uuid_value)

        overtime = 1000
        if not self.model_init_event.wait(timeout=overtime):
            raise RuntimeError("等待模型初始化超时({}s)".format(overtime))

        # Collect results; give up after `overtime` seconds without a new result.
        responses = []
        overtime = 150
        last_progress = time.time()
        # Looping on `pending` means nothing is awaited when every batch already
        # failed in run_single (no idle wait until the timeout).
        while pending:
            try:
                output = self.output_queue.get(timeout=1)
            except queue.Empty:
                cost_time = time.time() - last_progress
                logger.info('run cost time:{}'.format(cost_time))
                if cost_time >= overtime:
                    logger.error('It timed out, so it exited.')
                    break
                continue
            except Exception as e:
                print(f"获取队列数据时发生错误: {e}")
                break
            last_progress = time.time()

            if output['uuid'] not in pending:
                # Stale or foreign result; list.remove would raise ValueError.
                logger.warning('unexpected uuid in output: {}'.format(output['uuid']))
                continue
            pending.remove(output['uuid'])
            if output['status'] != 0:
                logger.error('output status is error, batch_items:{}'.format(output['batch_items']))
            else:
                # TODO: also persist a copy on the mounted disk.
                responses.extend(output['model_output'])
            if len(pending) == 0:
                logger.info('All tasks have been completed.')

        self.stop()
        if len(pending) != 0:
            logger.error('not get all model result, uuids:{}'.format(pending))
        return responses

    def stop(self):
        """Send the shutdown sentinel (``None``) to the model worker."""
        self.input_queue.put(None)


你的推理脚本整体结构和调用方式与 vLLM 官方推荐的多进程/多模态推理范式基本一致,主流程没有明显结构性错误。你采用了 `process_vision_info`、`AutoProcessor.apply_chat_template` 等标准方法,推理输入格式也符合 Qwen3-VL 的要求,模型加载和多进程通信部分也较为规范。不过,推理精度差、描述错误主要不是由脚本结构导致,而更可能与以下因素有关:

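  1. 视频帧采样与预处理:你的视频帧提取逻辑(如 `sample_frequency`、`min_pixels`、`max_pixels` 等)会直接影响模型输入的有效信息量,若采样过稀或分辨率设置不合理,模型理解力会下降。另外,把帧列表作为 `"video"` 传入时,`qwen_vl_utils` 无法得知真实的抽帧率,会按其默认的 `sample_fps`(常见版本为 2.0)计算时间戳,而脚本实际按约 1 fps 抽帧,建议在视频项中显式传入 `"sample_fps"`,否则模型看到的时间信息会失真。
  2. prompt 设计:你的 prompt 虽然详细,但如果视频内容复杂或与 prompt 描述不符,模型输出也会偏离。
  3. 模型本身与参数:Qwen3-VL-235B-A22B-Instruct-FP8 本身的 FP8 量化精度、模型权重、推理参数(如 `max_model_len`、`max_num_batched_tokens` 等)也会影响表现。例如 `max_model_len=4096` 对多帧视频输入可能偏小。
  4. 多进程/多卡配置:`tensor_parallel_size`、`max_num_seqs` 等参数需与硬件资源匹配,否则可能导致推理效率低或 OOM。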

建议:1)检查视频帧采样策略,确保采样帧能覆盖视频主要内容;2)适当调整 `min_pixels`、`max_pixels` 等参数,保证输入分辨率不过低;3)尝试更换或精简 prompt,排查是否为 prompt 工程问题;4)如有条件,尝试用 vLLM 官方 `vision_language.py` 示例或 Qwen3-VL 官方推理脚本对比输出,排查模型本身问题。

需要更详细的逐步排查建议吗?

Sources:

Hint: Mention @RunLLM in the post for followups.

如果设置 `stop_token_ids=[]`,vLLM 会使用默认的停止符吗?还是会影响推理精度?

代码如下
class ModelWorker():
    """Follow-up snippet: ModelWorker constructor with sampled decoding."""

    def __init__(self, cfg, req_queue: mp.Queue, res_queue: mp.Queue, init_event: mp.Event) -> None:
        self.cfg = cfg
        self.req_queue = req_queue
        self.res_queue = res_queue
        # stop_token_ids=[] only means "no extra stop ids".  vLLM still ends a
        # sequence at the model's EOS token unless ignore_eos=True, and (in the
        # versions checked -- verify for yours) also adds the eos_token_id list
        # from generation_config.json, so an empty list should not by itself
        # hurt accuracy.
        # NOTE(review): the Qwen3-VL model card recommends different sampling
        # values for the Instruct model (e.g. top_p=0.8, top_k=20,
        # temperature=0.7, repetition_penalty=1.0, presence_penalty=1.5);
        # confirm against the card before tuning.
        self.sampling_params = SamplingParams(
            temperature=0.6,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.05,
            max_tokens=1024,
            stop_token_ids=[],
        )
        self.init_event = init_event

@RunLLM