我当前的推理脚本推理Qwen3-VL-235B-A22B-Instruct-FP8,但是发现效果比较差,精度有问题,会有描述错误。我不知道我的脚本有没有问题,特别是在推理输入的构造方式上面。
import os
import pdb
import time
import uuid
import queue
import random
import multiprocessing as mp
from functools import partial
# import numpy as np
from PIL import Image
import json_repair
import imageio.v3 as iio
from vllm import LLM, SamplingParams
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
from src.utils.logger import logger
from src.utils.io_utils import get_presigned_url
# Enable vLLM's mmap-based weight cache. Must be set before the vLLM engine
# process is spawned. NOTE(review): verify this flag name against the
# installed vLLM version — unknown env vars are silently ignored.
os.environ["VLLM_ENABLE_MMAP_CACHE"] = "1"
class VideoLoader:
    """Parallel video-frame loader backed by a multiprocessing pool."""

    def __init__(self, num_workers=None):
        """
        Args:
            num_workers: size of the worker pool; defaults to the CPU count.
        """
        self.num_workers = num_workers or mp.cpu_count()
        self.pool = mp.Pool(processes=self.num_workers)

    def close(self):
        """Shut down the pool. Best-effort: never raises to the caller."""
        try:
            self.pool.close()
            self.pool.join()
        except Exception as e:
            # May run from __del__ during interpreter shutdown; log instead of
            # the previous silent `pass` so teardown failures are visible.
            logger.warning('VideoLoader close failed: {}'.format(e))

    def load_videos(self, func, video_paths, sample_frequency=1):
        """Map `func` over `video_paths` in parallel.

        Args:
            func: worker callable taking (video_info, sample_frequency=...).
            video_paths: list/tuple of video-info dicts.
            sample_frequency: frames per second to sample, forwarded to `func`.
        Returns:
            list of per-video results, in input order.
        Raises:
            ValueError: if `video_paths` is not a list or tuple.
        """
        if not isinstance(video_paths, (list, tuple)):
            # Bug fix: tuples are accepted above, so the message says so too.
            raise ValueError("video_paths must be a list or tuple")
        t0 = time.time()
        worker_func = partial(func, sample_frequency=sample_frequency)
        results = self.pool.map(worker_func, video_paths)
        logger.info('load {} number video cost time:{}s'.format(len(video_paths), time.time() - t0))
        return results
def load_video(video_info, sample_frequency=1):
    """Decode a video and return sampled RGB PIL frames in [start_frame, end_frame].

    Args:
        video_info: dict with 'url', 'start_frame', 'end_frame' (inclusive,
            absolute frame indices).
        sample_frequency: desired samples per second of video.
    Returns:
        list[PIL.Image]: sampled frames, or [] on any decode/metadata error.
    """
    url = video_info['url']
    rel_start = video_info['start_frame']
    rel_end = video_info['end_frame']
    try:
        meta = iio.imopen(uri=url, io_mode="r", plugin="pyav").metadata()
    except Exception as e:
        print(f"[DECODE ERROR] {url}: {e}")
        return []
    # Robustness fix: some containers report fps/duration as None (the key is
    # present), which previously crashed the multiplication below.
    fps = meta.get("fps") or 30
    duration = meta.get("duration") or 1
    frames_num = fps * duration
    if rel_end >= frames_num:
        logger.warning('rel_end:{} >= frames_num: {}, url:{}'.format(rel_end, frames_num, url))
        return []
    # Bug fix: when sample_frequency > fps, int(fps / sample_frequency) was 0
    # and range() raised ValueError; clamp the stride to at least 1 frame.
    step = max(1, int(fps / sample_frequency))
    get_frame_index = list(range(rel_start, rel_end + 1, step))
    if len(get_frame_index) == 0:
        logger.warning('get_frame_index is empty, url:{}'.format(url))
        return []
    wanted = set(get_frame_index)     # O(1) membership test in the decode loop
    last_index = get_frame_index[-1]  # hoist loop-invariant max()
    video = []
    try:
        reader = iio.imiter(uri=url, plugin="pyav")
        for idx, frame in enumerate(reader):
            if idx in wanted:
                video.append(Image.fromarray(frame).convert('RGB'))
            if idx >= last_index:
                # Stop decoding as soon as the last wanted frame is seen.
                break
    except Exception as e:
        print(f"[iio.imiter ERROR] {url}: {e}")
        return []
    return video
class ModelWorker():
    """Hosts a vLLM engine in a dedicated process: consumes inference jobs
    from `req_queue` and publishes results to `res_queue`.
    """
    # Must be set before vLLM spawns its own workers.
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

    def __init__(self, cfg, req_queue: mp.Queue, res_queue: mp.Queue, init_event: mp.Event) -> None:
        """
        Args:
            cfg: config object providing `model_path`.
            req_queue: incoming jobs (uuid, run_types, batch_items, inputs),
                or None as the shutdown sentinel.
            res_queue: outgoing result dicts (status, uuid, batch_items, ...).
            init_event: set once the model has finished loading.
        """
        self.cfg = cfg
        self.req_queue = req_queue
        self.res_queue = res_queue
        # Bug fix: previously self.llm only existed after build_model(), so the
        # guard in run() could raise AttributeError instead of reporting -1.
        self.llm = None
        self.sampling_params = SamplingParams(
            temperature=0,  # greedy decoding: deterministic, fewest hallucinations
            max_tokens=1024,
            stop_token_ids=[151645, 151643],  # <|im_end|> and <|endoftext|>
        )
        self.init_event = init_event

    def build_model(self, model_path):
        """Construct and load the vLLM engine.

        NOTE(review): max_model_len=4096 is tight for video inputs — each
        sampled frame consumes vision tokens, and truncated context degrades
        caption accuracy. Verify against the per-request frame/token budget.
        """
        logger.info(f"Loading model from {model_path}")
        self.llm = None
        t0 = time.time()
        self.llm = LLM(
            model=model_path,
            trust_remote_code=True,
            gpu_memory_utilization=0.9,
            enforce_eager=False,
            enable_expert_parallel=False,
            tensor_parallel_size=4,  # fixed TP degree; was torch.cuda.device_count()
            seed=0,
            max_num_seqs=256,
            max_model_len=4096,
            max_num_batched_tokens=23000
        )
        logger.info('model init cost time:{}'.format(time.time() - t0))

    def get_text_from_outputs(self, outputs: list, run_types: list):
        """Group raw generations into one dict per batch item.

        Outputs are laid out item-major: len(run_types) consecutive entries
        per batch item, in run_types order.

        Returns:
            list[dict]: one {run_type: text} dict per batch item; [] if the
            output count is not a multiple of len(run_types).
        """
        outputs_size = len(outputs)
        step = len(run_types)
        if (outputs_size % step) != 0:
            logger.error('outputs_size: {}, run_types_size: {}'.format(outputs_size, step))
            return []
        results = []
        for i in range(0, outputs_size, step):
            rslt = {}
            for j in range(step):
                output = outputs[i + j]
                generated_text = output.outputs[0].text
                try:
                    # json_repair tolerates truncated / slightly invalid JSON.
                    response = json_repair.loads(generated_text)
                except Exception as e:
                    logger.error(f"Error: {str(e)}")
                    response = "Error: " + str(e)
                if isinstance(response, dict):
                    if 'caption' in response:
                        text = response.get("caption", "")
                    else:
                        # e.g. the "tag" prompt returns a structured dict.
                        text = response
                else:
                    # Non-dict repair result (parse failure / bare string):
                    # emit empty text rather than propagating garbage.
                    text = ''
                rslt[run_types[j]] = text
            results.append(rslt)
        return results

    def run(self):
        """Serve loop: load the model, then consume jobs until a None sentinel."""
        logger.info("Model process start: loading model...")
        self.build_model(self.cfg.model_path)
        self.init_event.set()
        logger.info('init model: {} success'.format(self.cfg.model_path))
        while True:
            item = self.req_queue.get()
            if item is None:  # shutdown sentinel (see Caption.stop)
                print("Model worker process exit")
                break
            uuid_value, run_types, batch_items, inputs = item
            print(f"[Model] infer job {uuid_value}")
            if self.llm is None:  # bug fix: identity check, not `== None`
                self.res_queue.put({"status":-1, "uuid": uuid_value, "batch_items": batch_items})
                continue
            try:
                t0 = time.time()
                outputs = self.llm.generate(inputs, sampling_params=self.sampling_params)
                logger.info('llm generate inputs size:{} cost time:{}s'.format(len(inputs), time.time() - t0))
            except Exception as e:
                # Robustness fix: an uncaught generate() error used to kill this
                # process, leaving the producer blocked until its timeout.
                logger.error('llm generate failed: {}'.format(e))
                self.res_queue.put({"status":-1, "uuid": uuid_value, "batch_items": batch_items})
                continue
            if len(outputs) != (len(run_types) * len(batch_items)):
                logger.error('generate outputs failed, outputs size:{}, batch_data_items:{}'.format(
                    len(outputs), len(batch_items)))
                self.res_queue.put({"status":-1, "uuid": uuid_value, "batch_items": batch_items})
                continue
            results = self.get_text_from_outputs(outputs, run_types)
            if len(results) != len(batch_items):
                logger.error('get text from outputs failed, results size:{} != batch_items:{}'.format(
                    len(results), len(batch_items)))
                self.res_queue.put({"status":-1, "uuid": uuid_value, "batch_items": batch_items})
                continue
            # Attach the source URL so downstream consumers can join results.
            for idx in range(len(batch_items)):
                results[idx]['url'] = batch_items[idx]['url']
            rsp = {
                'status':0,
                'uuid': uuid_value,
                'batch_items': batch_items,
                'run_types': run_types,
                'model_output': results
            }
            self.res_queue.put(rsp)
def start_model_worker(cfg, input_queue: mp.Queue, output_queue: mp.Queue, init_event: mp.Event):
    """Process entry point: construct a ModelWorker and enter its serve loop."""
    ModelWorker(cfg, input_queue, output_queue, init_event).run()
class Caption():
    """End-to-end video captioning pipeline.

    Decodes videos in a process pool, builds Qwen-VL chat inputs, and streams
    them to a ModelWorker process (vLLM) via multiprocessing queues.
    """

    def __init__(self, cfg, worker_num=20, cos_client=None,
                 tos_client=None, aoss_client=None):
        """
        Args:
            cfg: config with `model_path`, `mean`, `std`.
            worker_num: size of the video-decoding process pool.
            cos_client / tos_client / aoss_client: optional object-storage
                clients used to presign private URLs.
        """
        self.IMAGENET_MEAN = cfg.mean
        self.IMAGENET_STD = cfg.std
        self.processor = AutoProcessor.from_pretrained(cfg.model_path)
        # Left padding so generation continues right after each prompt in a batch.
        self.processor.tokenizer.padding_side = "left"
        self.tos_client = tos_client
        self.aoss_client = aoss_client
        self.cos_client = cos_client
        self.video_loader = VideoLoader(worker_num)
        self.input_queue = mp.Queue(maxsize=100)
        self.output_queue = mp.Queue(maxsize=100)
        self.model_init_event = mp.Event()
        # The vLLM engine lives in its own process; jobs/results flow over queues.
        self.model_process = mp.Process(
            target=start_model_worker,
            args=(cfg, self.input_queue, self.output_queue, self.model_init_event),
        )
        self.model_process.start()
        self.prompt_dict = dict(
            caption=
"""
You are an expert video captioning model. Your task is to generate a highly detailed and accurate caption for the given video.
Requirements:
1. Describe everything you observe in the video with rich details - be specific about colors, shapes, sizes, positions, movements, and any visible text or patterns.
2. Capture the temporal flow: describe what happens from beginning to end in chronological order.
3. Be precise and factual: only describe what is clearly visible, avoid guessing or assumptions.
4. Write as a single flowing paragraph.
5. Output in JSON format with one field: "caption".
Example output:
{"caption": "A woman with long brown hair wearing a light pink apron over a white t-shirt stands at a kitchen counter with white marble surface. She picks up a red bell pepper from a ceramic bowl, places it on a wooden cutting board, and slices it into thin strips using a chef's knife. Sunlight from the window illuminates the scene, casting soft shadows across the counter. A small potted plant sits on the windowsill behind her."}
""",
            tag=
"""
Analyze the input video to complete two tasks: 1) Classify it into one primary category from the following options: human/ego/robot/physics/unsafe/general; 2) Generate specific, non-vague sub-tags aligned with the selected category’s guidelines.
Category Definitions:
"human": Third-person perspective of humans ; sub-tags keys include: actions, occupation, gender, age, context.
"ego": First-person perspective; sub-tags keys include: scenes, environments, interaction behaviors.
"robot": Depicts physical, tangible robots interacting with the real world; sub-tags keys include: scenes, actions.
"physics": Clearly demonstrates observable physical laws or phenomena; sub-tags keys include: physics principles (select from or align with the following: Mechanical Motion, Fluid Properties, Energy Conversion, Optical Phenomena, Spatial Relationships, Material Deformation, Electromagnetic Effects, Thermodynamic Changes, Causal Linkage, Flexible Motion, Microscopic Motion, Macroscopic Motion, Cosmic Physics)
"unsafe": Contains violence, pornography, or politically sensitive content; sub-tags keys include: specific type.
"general": No fit for the above categories; sub-tags keys include: content type (choose one from game, animation, movie, short video, others), scene.
Output Format (strict): {"primary_category": "[Selected Category]", "sub_tags": {"sub-tag1_key": "sub-tag1_value", "sub-tag2_key": "sub-tag2_value", ...}}
Example output: {"primary_category": "general", "sub_tags": {"content_type": "animation"}}
"""
        )

    def __del__(self):
        """Close the frame-loader pool on destruction (best-effort)."""
        # Robustness fix: use getattr — __init__ may have failed before
        # video_loader was assigned, and __del__ still runs.
        if getattr(self, 'video_loader', None) is not None:
            self.video_loader.close()

    def generate_messages(self, batch_items: list, video_imgs: list, run_types: list):
        """Build one chat message per (video, run_type) pair.

        Returns:
            (messages, invalid_urls): chat messages for vLLM, plus the URLs of
            videos that failed to decode. Returns ([], []) on a size mismatch.
        """
        if len(batch_items) != len(video_imgs):
            logger.error('batch_items size:{} != video_imgs size:{}'.format(len(batch_items), len(video_imgs)))
            # Bug fix: callers unpack two values; a bare `return []` raised an
            # unpacking error instead of surfacing this failure cleanly.
            return [], []
        messages = []
        invalid_urls = []
        for idx in range(len(batch_items)):
            batch_item = batch_items[idx]
            video = video_imgs[idx]
            if len(video) == 0:
                logger.error('generate messages video path: {} is error'.format(batch_item['url']))
                invalid_urls.append(batch_item['url'])
                continue
            for run_type in run_types:
                text = ''
                if run_type == 'caption':
                    text = self.prompt_dict['caption']
                elif run_type == 'tag':
                    text = self.prompt_dict['tag']
                else:
                    # Still append (with empty text) so the message count stays
                    # aligned with len(run_types) * len(valid items), which the
                    # worker's output-count check relies on.
                    logger.error('Error run_type:{} value is error'.format(run_type))
                message = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "video",
                                "video": video,
                                "min_pixels": 64 * 32 * 32,
                                "max_pixels": 256 * 32 * 32,
                                "total_pixels": 10240 * 32 * 32
                            },
                            {"type": "text", "text": text},
                        ],
                    }]
                messages.append(message)
        return messages, invalid_urls

    def prepare_inputs_for_vllm(self, messages):
        """Render one chat message into the dict vLLM's generate() expects."""
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # qwen_vl_utils 0.0.14+ required (return_video_metadata support).
        image_inputs, video_inputs, video_kwargs = process_vision_info(
            messages,
            image_patch_size=self.processor.image_processor.patch_size,
            return_video_kwargs=True,
            return_video_metadata=True
        )
        mm_data = {}
        if image_inputs is not None:
            mm_data['image'] = image_inputs
        if video_inputs is not None:
            mm_data['video'] = video_inputs
        return {
            'prompt': text,
            'multi_modal_data': mm_data,
            # e.g. do_sample_frames=False — frames were already sampled upstream.
            'mm_processor_kwargs': video_kwargs
        }

    def get_extract_img_from_video(self, batch_items: list):
        """Presign object-storage URLs where needed, then decode frames in parallel."""
        video_paths = []
        for batch_item in batch_items:
            if 's3' in batch_item['url'] or 'tos' in batch_item['url'] or 'cos' in batch_item['url'] :
                url = get_presigned_url(batch_item['url'], tos_client=self.tos_client,
                                        aoss_client=self.aoss_client, cos_client=self.cos_client)
            else:
                url = batch_item['url']
            video_info = {
                'url': url,
                'start_frame': batch_item['start_frame'],
                'end_frame': batch_item['end_frame']
            }
            video_paths.append(video_info)
        return self.video_loader.load_videos(load_video, video_paths)

    def run_single(self, batch_items, uuid_value, run_types) -> bool:
        """Decode one batch, build vLLM inputs, and enqueue the inference job.

        Returns:
            True if the job was queued, False if no valid message was produced.
        """
        logger.info('text_labeling run_types:{}, batch_items size:{}'.format(run_types, len(batch_items)))
        if not isinstance(batch_items, list):
            batch_items = [batch_items]
        video_imgs = self.get_extract_img_from_video(batch_items)
        messages, invalid_urls = self.generate_messages(batch_items, video_imgs, run_types)
        if len(messages) == 0:
            logger.error('generate_messages failed')
            return False
        logger.info('invalid_urls: {}'.format(invalid_urls))
        # Drop items whose video failed to decode so item/output counts align.
        valid_batch_items = []
        for batch_item in batch_items:
            if batch_item['url'] in invalid_urls:
                continue
            valid_batch_items.append(batch_item)
        t0 = time.time()
        inputs = [self.prepare_inputs_for_vllm(message) for message in messages]
        logger.info('prepare_inputs_for_vllm messages size:{} cost time:{}s'.format(len(messages), time.time() - t0))
        self.input_queue.put((uuid_value, run_types, valid_batch_items, inputs))
        logger.info("input_queue put uuid_value:{}".format(uuid_value))
        return True

    def run(self, data_items: list, batch_size: int, run_types=None):
        """Split data_items into batches, dispatch to the model process, and
        collect the responses.

        Args:
            data_items: list of dicts with 'url', 'start_frame', 'end_frame'.
            batch_size: number of items per inference job.
            run_types: prompts to run per item; defaults to ['caption', 'tag'].
        Returns:
            list of per-item result dicts ({run_type: text, 'url': ...}).
        """
        if run_types is None:
            # Bug fix: avoid a shared mutable default argument.
            run_types = ['caption', 'tag']
        if (len(data_items) == 0) or (batch_size <= 0):
            logger.error('data_items size:{} or batch_size:{} is error'.format(len(data_items), batch_size))
            return []
        total_items = len(data_items)
        uuids = []
        for idx in range(0, total_items, batch_size):
            uuid_value = str(uuid.uuid4())
            batch_items = data_items[idx:idx + batch_size]
            print('batch_items size:{}'.format(len(batch_items)))
            if not self.run_single(batch_items, uuid_value, run_types):
                logger.error('Error run batch_items:{} failed'.format(batch_items))
                continue
            uuids.append(uuid_value)
        overtime = 1000
        if not self.model_init_event.wait(timeout=overtime):
            raise RuntimeError("等待模型初始化超时({}s)".format(overtime))
        # Collect results from the model process.
        responses = []
        overtime = 150  # max idle seconds between two results before giving up
        had_once_forward = False
        begin_time = time.time()
        while True:
            output = None
            if had_once_forward:
                # Reset the idle timer after each received result.
                begin_time = time.time()
                had_once_forward = False
            try:
                output = self.output_queue.get(timeout=1)
            except queue.Empty:
                output = None
            except Exception as e:
                print(f"获取队列数据时发生错误: {e}")
                break
            if output is None:
                cost_time = time.time() - begin_time
                logger.info('run cost time:{}'.format(cost_time))
                if cost_time >= overtime:
                    logger.error('It timed out, so it exited.')
                    break
                continue
            had_once_forward = True
            # Robustness fix: an unknown/duplicate uuid used to raise ValueError.
            if output['uuid'] in uuids:
                uuids.remove(output['uuid'])
            if output['status'] != 0:
                logger.error('output status is error, batch_items:{}'.format(output['batch_items']))
            else:
                # TODO: also persist a copy of the results to the mounted disk.
                responses.extend(output['model_output'])
            if len(uuids) == 0:
                logger.info('All tasks have been completed.')
                break
        self.stop()
        if len(uuids) != 0:
            logger.error('not get all model result, uuids:{}'.format(uuids))
        return responses

    def stop(self):
        """Send the shutdown sentinel to the model worker process."""
        self.input_queue.put(None)