About clean_up in the example on the vllm-ascend official website

The clean_up() function in the example on the vllm-ascend official website can only clean up the slave workers, not the driver worker. Is there any way to clean up the driver worker as well?

import gc

import torch

from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import (destroy_distributed_environment,
                                             destroy_model_parallel)

def clean_up():
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()

prompts = [
    "Hello, my name is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40)
llm = LLM(model="Qwen/QwQ-32B",
          tensor_parallel_size=4,
          distributed_executor_backend="mp",
          max_model_len=4096)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

del llm
clean_up()

The original question is: when a user sets “mp” as the distributed_executor_backend and creates and runs inference on multiple LLM instances sequentially, the driver process is never killed, so its resources are never cleaned up.

By the way, this user called llm.llm_engine.model_executor.shutdown() by hand; that is why the slave worker processes do get cleaned up.

Here is a reproduction demo:

import gc, os
import torch
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import (destroy_distributed_environment,
                                             destroy_model_parallel)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
prompts = [
    "Hello, my name is",
    "The future of AI is",
]
def clean_up():
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()

def gen():
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
    # Create an LLM.
    llm = LLM(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        tensor_parallel_size=2,
        distributed_executor_backend="mp",
        trust_remote_code=True,
    )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Shutting down the executor by hand is what cleans up the slave
    # worker processes; the driver process itself still stays alive.
    llm.llm_engine.model_executor.shutdown()

    del llm
    clean_up()


if __name__ == "__main__":
    gen()
    gen()

If you want to use “mp” as the distributed_executor_backend and need to run multiple LLM instances in sequence, you can run each vLLM inference in its own subprocess. When the subprocess exits, the operating system reclaims its NPU memory and tears down the worker processes it spawned, so nothing leaks into the next run. For example:

import os

from vllm import LLM, SamplingParams

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

prompts = [
    "Hello, my name is",
    "The future of AI is",
]

def gen():
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
    # Create an LLM.
    llm = LLM(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        tensor_parallel_size=2,
        distributed_executor_backend="mp",
        trust_remote_code=True,
    )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    from multiprocessing import Process

    proc = Process(target=gen)
    proc.start()
    proc.join()
    proc.terminate()

    proc = Process(target=gen)
    proc.start()
    proc.join()
    proc.terminate()

Actually, this user is attempting to run vLLM as a rollout engine, which needs to load new weights at each step. I recommend using external_launcher instead of mp, because it is an SPMD-style executor that vLLM already provides.
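
A minimal sketch of what that could look like, assuming a vLLM version that supports distributed_executor_backend="external_launcher" and that the script is launched with torchrun (the model and prompts are reused from the reproduction demo above; the file name is just an example):

# Launch with: torchrun --nproc-per-node=2 spmd_demo.py
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The future of AI is",
]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

# With external_launcher, every torchrun rank runs this same script (SPMD),
# so the model lives in the launching process on each rank and there is no
# separate driver worker left behind to clean up between steps.
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
    trust_remote_code=True,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    # Note: each rank prints its own copy of the results.
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")

Because the rollout loop and the model share the same process on every rank, you can also update the weights in place at each step instead of recreating the LLM instance.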
