Integrating async-LLM into OpenRLHF: tensor_parallel_size > 1 causes a Ray error

Hi, I tried to integrate vLLM's AsyncLLMEngine into OpenRLHF (v0.6.4). It works for the 7B/32B models when tensor_parallel_size=1, but when tensor_parallel_size > 1 it raises a Ray error.

Here is the async vLLM engine code:

import asyncio
import json
import os
import queue
import random
import socket
import string
import sys
import uuid
from asyncio import Queue
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import ray
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm import LLM, SamplingParams

from openrlhf.async_pipline.process_request import GenerateRequest, process_batch_requests
from openrlhf.trainer.ray.vllm_engine import (
    LLMRayActor,
    get_all_env_variables,
    batch_vllm_engine_call,
)
from openrlhf.utils.logging_utils import init_logger

from .utils import get_bundle_indices, ray_noset_visible_devices

logger = init_logger(__name__)

sys.path.append(os.getenv('OPENRLHF_PATH', '/cpfs/user/chenhao/debug/OpenRLHF_0304_vllm083'))
from env.env_config import ENV_GENERATE_CONFIG

env_method = os.getenv('GENERATE_METHOD', '')
GENERATE_FUNC = ENV_GENERATE_CONFIG.get(env_method, None)


def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

logger.info({
    'ENV_METHOD': env_method,
    'GENERATE_FUNC': GENERATE_FUNC
})


def _get_free_port():
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]



@ray.remote
class AsyncLLMRayAsyncActor:
    def __init__(self, *args, bundle_indices: list = None, **kwargs):
        noset_visible_devices = kwargs.pop("noset_visible_devices")
        if kwargs.get("distributed_executor_backend") == "ray":
            # a hack to make the script work.
            # stop ray from manipulating CUDA_VISIBLE_DEVICES
            # at the top-level when the distributed_executor_backend is ray.
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        elif noset_visible_devices:
            # We need to set CUDA_VISIBLE_DEVICES to the ray assigned GPU
            # when the distributed_executor_backend is not ray and
            # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0])

        # every worker will use 0.2 GPU, so that we can schedule
        # 2 instances on the same GPUs.
        if bundle_indices is not None:
            os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.2"
            os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
            print(f"creating LLM with bundle_indices={bundle_indices}")

        # Number of actors that will send prompt to this engine
        self.num_actors = kwargs.pop("num_actors")
        self.actor_counter = 0
        self.requests = {}
        # self.responses = {}
        self.response_queues = defaultdict(queue.Queue)
        self.requests_of_ids = {}
        self.requests_labels = {}

        import vllm

        full_determinism = kwargs.pop("full_determinism", False)
        if full_determinism or vllm.__version__ == "0.8.3":
            # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

        # self.llm = LLM(*args, **kwargs)

        logger.info({
            "INFO": "##BEGIN-TO-LOAD-ASYNC-LLM##"
        })

        engine_args = vllm.AsyncEngineArgs(*args, **kwargs)
        self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args)

        logger.info({
            "INFO": "##SUCCEEDED-IN-LOADING-ASYNC-LLM##"
        })

        # async-server
        self.server = None
        self.server_ready = asyncio.Event()
        self.port = None
        self.address = ray._private.services.get_node_ip_address()

        self.batch_size = int(kwargs.get('batch_size', 32))

        self.worker_num = int(kwargs.get('worker_num', 8))
        self.max_queue_size = int(kwargs.get('max_queue_size', 1024))
        self.request_queue: Queue = Queue(maxsize=self.max_queue_size)
        self.pending_requests: Dict[str, asyncio.Future] = {}
        self.workers = []

        self.max_batch_size = int(kwargs.get('max_batch_size', 32))  # maximum number of requests per batch
        self.max_wait_time = float(kwargs.get('max_wait_time', 1e-1))    # maximum batching wait time (seconds)

        # the priority queue stores tuples of (priority, insertion_order, request_data)
        self.priority_queue = []
        self.queue_index = 0
        self.max_retries = 5
        self.retry_delay = 0.1
        
        self.lock = asyncio.Lock()

        asyncio.create_task(self._start_fastapi_server())

    async def start(self):
        """启动工作线程"""
        self._running = True
        for _ in range(self.worker_num):
            self.workers.append(asyncio.create_task(self._worker_loop()))
        print('==Succeeded in starting==')

    async def stop(self):
        """停止服务并清空队列"""
        self._running = False
        await self.request_queue.join()
        for worker in self.workers:
            worker.cancel()
        await asyncio.gather(*self.workers, return_exceptions=True)

    async def generate_async_server(self, prompts, sampling_params, request_id):
        # Send the request to the LLM engine.
        async with asyncio.Semaphore(128):
            stream = self.async_llm.generate(
                request_id=str(request_id),
                prompt=prompts[0],
                sampling_params=sampling_params,
            )

            # Consume the stream until the request is finished.
            async for request_output in stream:
                final_output = request_output
            output = [{
                    'outputs':[
                        {
                            "text": final_output.outputs[0].text,
                            "token_ids": final_output.outputs[0].token_ids,
                            "stop_reason": final_output.outputs[0].stop_reason,
                            "finish_reason": final_output.outputs[0].finish_reason,
                        }
                    ],
                    "prompt_token_ids": final_output.prompt_token_ids
                }]
            return output

    async def async_llm_generate(self, request: GenerateRequest):
        # actual generation logic
        sampling_params = SamplingParams(
            n=request.n,
            repetition_penalty=request.repetition_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            min_p=request.min_p,
            max_tokens=request.max_tokens,
            include_stop_str_in_output=request.include_stop_str_in_output,
            stop=request.stop
        )
        response = await self.generate_async_server(request.prompts, sampling_params, id_generator(10))
        return response

    async def _worker_loop(self):
        """Worker coroutine loop: pull requests off the queue and run generation."""
        while self._running:
            # Dequeue outside the try block so `future` is always bound and
            # task_done() is only called for items actually taken off the queue.
            request_id, request, future = await self.request_queue.get()
            try:
                # actual generation logic
                sampling_params = SamplingParams(
                    n=request.n,
                    repetition_penalty=request.repetition_penalty,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    top_k=request.top_k,
                    min_p=request.min_p,
                    max_tokens=request.max_tokens,
                    include_stop_str_in_output=request.include_stop_str_in_output,
                    stop=request.stop
                )

                response = await self.generate_async_server(request.prompts, sampling_params, id_generator(10))
                # response = await self.generate_async(request.prompts, sampling_params)

                future.set_result(response)
                self.pending_requests.pop(request_id, None)

            except Exception as e:
                if not future.done():
                    future.set_exception(e)
            finally:
                self.request_queue.task_done()

    async def async_generate(self, request: GenerateRequest):
        """Asynchronous generation entry point."""
        # if self.request_queue.qsize() >= self.max_queue_size:
        #     raise RuntimeError("Request queue is full")

        while self.request_queue.full():
            logger.warning("Request queue is full. Waiting for space...")
            await asyncio.sleep(0.1)

        # create an async Future for this request
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        request_id = request.uuids

        # put the request onto the waiting queue
        self.pending_requests[request_id] = future
        await self.request_queue.put((request_id, request, future))

        try:
            return await future
        except Exception as e:
            logger.info(f"Error in async_generate: {e}")
            if not future.done():
                future.set_exception(e)
            self.pending_requests.pop(request_id, None)
            raise
        # finally:
        #     # make sure resources are cleaned up on exception
        #     self.pending_requests.pop(request_id, None)
    
    async def _start_fastapi_server(self):
        import fastapi
        app = fastapi.FastAPI()
        app.router.add_api_route("/health", self.health, methods=["GET"])
        app.router.add_api_route("/async_generate", self.async_llm_generate, methods=["POST"])

        await asyncio.sleep(random.uniform(0, 3))
        self.port = _get_free_port()
        config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port)
        self.server = uvicorn.Server(config)  # keep a handle to the server instance
        self.server_ready.set()
        await self.start()
        await self.server.serve()

    async def health(self):
        return 1

    async def restart_server(self):
        if self.server:
            await self.server.shutdown()
            await asyncio.sleep(0.5)  # give the shutdown time to complete
            self.server = None
        self.server_ready.clear()
        asyncio.create_task(self._start_fastapi_server())

    async def get_server_address(self):
        await self.server_ready.wait()
        return f"{self.address}:{self.port}"

    async def get_server_port(self):
        await self.server_ready.wait()
        return self.port

    def build_requests(self, prompts, prompt_ids, sampling_params, labels=None):
        request_list = []
        for idx, (prompt, prompt_id) in enumerate(zip(prompts, prompt_ids)):
            if labels is not None:
                if labels[idx] is not None:
                    label_dict = json.loads(labels[idx])
                    uuid_str = label_dict.get('uuid', str(uuid.uuid4()))
                    env_func = label_dict.get('env_func', 'math_tir_async')
                else:
                    env_func = 'math_tir_async'
                    uuid_str = str(uuid.uuid4())
                    label_dict = {
                        'uuid': uuid_str,
                        'env_func': env_func
                    }
            else:
                env_func = 'math_tir_async'
                uuid_str = str(uuid.uuid4())
                label_dict = {
                    'uuid': uuid_str,
                    'env_func': env_func
                }
                
            request = GenerateRequest(
                prompts=[prompt],
                prompt_token_ids=prompt_id,
                max_tokens=sampling_params.max_tokens,
                temperature=sampling_params.temperature,
                stop=sampling_params.stop,
                uuids=uuid_str+f'####idx:{idx}',
                env_func=env_func,
                label=json.dumps(label_dict, ensure_ascii=False)
            )
            request_list.append(request)
        return request_list

    def group_requests(self, data_list: List[Dict]):
        requests_dict = {}
        for data in data_list:
            env_func = data.env_func
            if env_func not in requests_dict:
                requests_dict[env_func] = []
            requests_dict[env_func].append(data)
        return requests_dict

    def _create_batches(self, data_list: Union[List[Dict[Any, Any]], Dict[Any, List[Any]]]) -> List[Tuple[int, List[Dict]]]:
        """将数据分成 batch,返回 [(start_idx, batch), ...]"""
        batches = []
        if isinstance(data_list, list):
            for i in range(0, len(data_list), self.batch_size):
                batch = data_list[i:i + self.batch_size]
                batches.append((i, batch))
            if i + self.batch_size < len(data_list) - 1:
                batches.append((i+1, data_list[i + self.batch_size:]))
        elif isinstance(data_list, dict):
            for env_func in data_list:
                for i in range(0, len(data_list[env_func]), self.batch_size):
                    batch = data_list[env_func][i:i + self.batch_size]
                    batches.append((i, batch))
                if i + self.batch_size < len(data_list[env_func]) - 1:
                    batches.append((i+1, data_list[env_func][i + self.batch_size:]))
        else:
            raise ValueError("data_list must be a list or dict")
        return batches

    async def add_env_requests(self, actor_rank, *, sampling_params, prompt_token_ids, 
                prompts=None, tokenizer=None, labels=None):
        """
        Save the requests from actors and generate responses when all actors have sent their requests
        """
        self.requests[actor_rank] = prompts
        self.requests_of_ids[actor_rank] = prompt_token_ids
        self.requests_labels[actor_rank] = labels
        self.actor_counter += 1
        if self.actor_counter == self.num_actors:
            assert len(self.requests) == self.num_actors
            assert len(self.requests_of_ids) == self.num_actors
            assert len(self.requests_labels) == self.num_actors
            num_requests = []
            requests = []
            requests_of_ids = []
            requests_labels = []
            for actor_rank, request in self.requests.items():
                num_requests.append((actor_rank, len(request)))
                requests.extend(request)
            for request_rank, request_ids in self.requests_of_ids.items():
                requests_of_ids.extend(request_ids)
            for request_rank, request_label in self.requests_labels.items():
                requests_labels.extend(request_label)
            
            assert len(requests_of_ids) == len(requests)
            assert len(requests_labels) == len(requests)

            ip_port = await self.get_server_address()
            url = f'http://{ip_port}/async_generate'

            logger.info({
                'IP_PORT': ip_port,
                'URL': url,
                'INFO': '##BEGIN-TO-ROLLOUT##'
            })

            if len(requests_of_ids) > 0:
                all_requests = self.build_requests(prompts=requests, prompt_ids=requests_of_ids, sampling_params=sampling_params, labels=requests_labels)
                if labels is not None:
                    all_requests = self.group_requests(all_requests)
                batches = self._create_batches(all_requests)
                responses_ray = []
                for start_idx, batch in batches:
                    env_func = batch[0].env_func
                    if env_func in ENV_GENERATE_CONFIG:
                        process_fn = ENV_GENERATE_CONFIG[env_func]
                    else:
                        process_fn = None
                    responses_ray.append(process_batch_requests.remote(url, start_idx, batch, process_fn=process_fn, tokenizer=tokenizer))
                
                # results_raw = await asyncio.gather(*responses_ray)
                # flat_results = [item for batch in results_raw for item in batch]

                # # sort by idx
                # flat_results.sort(key=lambda x: x[0])
                # responses = [result[1] for result in flat_results]

                # flat_results = [item for batch in results_raw for item in batch]
                # responses = [result[1] for result in flat_results]
                # responses.sort(key=lambda x: int(x.request_id.split('####idx:')[-1]))

                results_raw = await asyncio.gather(*responses_ray)
                flat_results = []
                for result_raw in results_raw:
                    successful_results, failed_results = result_raw
                    for item in successful_results:
                        flat_results.append(item)
                responses = [result[1][1] for result in flat_results]
                responses.sort(key=lambda x: int(x.request_id.split('####idx:')[-1]))
            else:
                responses = []

            offset = 0
            self.responses = {}
            for actor_rank, num in num_requests:
                # self.responses[actor_rank] = responses[offset : offset + num]
                self.response_queues[actor_rank].put(responses[offset : offset + num])
                offset += num

            self.actor_counter = 0
            self.requests = {}
            self.requests_of_ids = {}
            self.requests_labels = {}

    def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray):
        return self.async_llm.engine.model_executor.collective_rpc(
            "init_process_group",
            args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray),
        )

    def update_weight(self, name, dtype, shape, empty_cache=False):
        return self.async_llm.engine.model_executor.collective_rpc(
            "update_weight", args=(name, dtype, shape, empty_cache)
        )

    def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
        return self.async_llm.engine.model_executor.collective_rpc(
            "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
        )

    def reset_prefix_cache(self):
        self.async_llm.engine.reset_prefix_cache()

    def get_ckp_version(self) -> int:
        return self.ckp_version

    async def sleep(self, level: int = 1):
        await self.async_llm.sleep(level=level)

    async def wake_up(self):
        await self.async_llm.wake_up()

    def get_responses(self, actor_rank):
        """
        Return the responses for the actor with the given rank
        """
        # return self.responses.pop(actor_rank)
        return self.response_queues[actor_rank].get()


def create_async_vllm_engines(
    num_engines: int,
    tensor_parallel_size: int,
    pretrain: str,
    seed: int,
    enable_prefix_caching: bool,
    enforce_eager: bool,
    max_model_len: int,
    num_total_actors: int,
    shared_pg=None,
    gpu_memory_utilization=None,
    vllm_enable_sleep=False,
):
    import vllm

    assert vllm.__version__ >= "0.7.0", "OpenRLHF only supports vllm >= 0.7.0"

    vllm_engines = []
    num_gpus = int(tensor_parallel_size == 1)
    distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray"
    for i in range(num_engines):
        bundle_indices = None
        scheduling_strategy = None

        # Hybrid engine
        if shared_pg is not None:
            assert vllm.__version__ >= "0.7.2", "Only vllm >= 0.7.2 supports hybrid engine"

            if tensor_parallel_size > 1:
                scheduling_strategy = PlacementGroupSchedulingStrategy(
                    placement_group=shared_pg,
                    placement_group_capture_child_tasks=True,
                    placement_group_bundle_index=i * tensor_parallel_size
                )
                bundle_indices = np.arange(i * tensor_parallel_size, (i + 1) * tensor_parallel_size).tolist()
            else:
                num_gpus = 0.2
                scheduling_strategy = PlacementGroupSchedulingStrategy(
                    placement_group=shared_pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=i
                )
        # Distributed RLHF
        elif tensor_parallel_size > 1:
            bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size
            pg = placement_group(bundles)
            ray.get(pg.ready())

            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0
            )

        if num_engines >= num_total_actors:
            num_actors = 1
        else:
            num_actors = num_total_actors // num_engines + int(i < num_total_actors % num_engines)

        vllm_engines.append(
            AsyncLLMRayAsyncActor.options(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
            ).remote(
                model=pretrain,
                enforce_eager=enforce_eager,
                worker_cls="openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap",
                tensor_parallel_size=tensor_parallel_size,
                # seed=seed + i,
                seed=seed,
                distributed_executor_backend=distributed_executor_backend,
                max_model_len=max_model_len,
                enable_prefix_caching=enable_prefix_caching,
                dtype="bfloat16",
                trust_remote_code=True,
                num_actors=num_actors,
                gpu_memory_utilization=gpu_memory_utilization,
                bundle_indices=bundle_indices if shared_pg else None,
                enable_sleep_mode=vllm_enable_sleep,
                noset_visible_devices=ray_noset_visible_devices(),
            )
        )
    
    if vllm_enable_sleep:
        batch_vllm_engine_call(vllm_engines, "sleep", rank_0_only=False)

    # logger.info({
    #     "INFO": "##BEGIN-TO-START-SERVER"
    # })
    # # start_server
    # batch_vllm_engine_call(vllm_engines, "init_llm_server", rank_0_only=True)
    
    # logger.info({
    #     "INFO": "##SUCCEEDED-IN-STARTING-SERVER"
    # })

    return vllm_engines


def batch_vllm_engine_call(engines: List[Any], method_name: str, *args, rank_0_only: bool = True, **kwargs):
    """
    Batch call a method on multiple vLLM engines.
    Args:
        engines: List of vLLM engine instances
        method_name: Name of the method to call
        rank_0_only: Only execute on rank 0 if True
        *args: Positional arguments to pass to the method
        **kwargs: Keyword arguments to pass to the method
    Returns:
        List of results from ray.get() if on rank 0, None otherwise
    """
    import torch

    if rank_0_only and torch.distributed.get_rank() != 0:
        return None

    refs = []
    for engine in engines:
        method = getattr(engine, method_name)
        refs.append(method.remote(*args, **kwargs))

    return ray.get(refs)

It starts a FastAPI server, which can be called via

url = f'http://{ip_port}/async_generate'
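
For reference, here is a minimal client sketch for that endpoint (not part of the engine code above; aiohttp is just one possible HTTP client, and the payload fields are inferred from how GenerateRequest is filled in build_requests, so any field with a default on the pydantic model is omitted here):

import asyncio
import aiohttp


async def call_async_generate(ip_port: str, prompt: str):
    # Field names mirror how GenerateRequest is populated in build_requests();
    # the server-side handler only reads `prompts` plus the sampling fields.
    payload = {
        "prompts": [prompt],
        "max_tokens": 1024,
        "temperature": 1.0,
        "stop": None,
        "uuids": "demo####idx:0",
        "env_func": "math_tir_async",
        "label": "{}",
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"http://{ip_port}/async_generate", json=payload) as resp:
            resp.raise_for_status()
            return await resp.json()


# ip_port is whatever AsyncLLMRayAsyncActor.get_server_address() returns, e.g.:
# print(asyncio.run(call_async_generate("10.0.0.1:12345", "1 + 1 = ?")))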

The error log:

Exception raised in creation task: The actor died because of an error raised in its creation task, ray::AsyncLLMRayAsyncActor.__init__() (pid=657, ip=10.39.3.79, actor_id=4c9da31eb150c91415e6687c02000000, repr=<openrlhf.trainer.ray.async_vllm_engine_async.FunctionActorManager._create_fake_actor_class.<locals>.TemporaryActor object at 0x7f70ec6d5ed0>)
(TemporaryActor pid=657, ip=10.39.3.79) ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
(TemporaryActor pid=657, ip=10.39.3.79) 	class_name: TemporaryActor
(TemporaryActor pid=657, ip=10.39.3.79) 	actor_id: 4c9da31eb150c91415e6687c02000000
(TemporaryActor pid=657, ip=10.39.3.79) Failed to create actor. You set the async flag, but the actor does not have any coroutine functions.
The original cause of the RayTaskError (<class 'ray.exceptions.ActorDiedError'>) isn't serializable: cannot pickle '_struct.Struct' object. Overwriting the cause to a RayError. [repeated 5x across cluster]
(TemporaryActor pid=658, ip=10.39.3.77) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::AsyncLLMRayAsyncActor.__init__() (pid=658, ip=10.39.3.77, actor_id=03f98e9b5697bbcb7fe8017902000000, repr=<openrlhf.trainer.ray.async_vllm_engine_async.FunctionActorManager._create_fake_actor_class.<locals>.TemporaryActor object at 0x7edfe0285de0>) [repeated 5x across cluster]
(TemporaryActor pid=658, ip=10.39.3.77) ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task. [repeated 5x across cluster]

The issue you’re encountering with tensor_parallel_size > 1 causing Ray errors is most likely related to how Ray is configured and how resources are placed. When an engine spans multiple GPUs, the Ray cluster must be set up to schedule the engine’s workers across those GPUs correctly. Here are some steps and considerations:

  1. Ray Initialization: Ensure that Ray is properly initialized and connected to the cluster. If you’re using a multi-node setup, make sure that the Ray head node is correctly set up and that all worker nodes are connected. You can check the status of the Ray cluster using ray status.

  2. Placement Groups: Use Ray’s placement groups to manage resource allocation. Placement groups let you specify how resources are allocated across nodes; for example, PlacementGroupSchedulingStrategy ensures that your actors are placed on nodes with the required resources. This is particularly important when tensor_parallel_size > 1, since the engine then needs multiple GPUs (see the sketch after this list).

  3. Environment Variables: Set the necessary environment variables to control GPU allocation. For instance, VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES can be used to specify how GPUs are allocated to each worker. Ensure that these are set correctly to match your cluster’s configuration.

  4. Error Handling: The error message indicates that the actor died unexpectedly. This could be due to resource allocation issues or other configuration problems. Check the logs for more detailed error messages and ensure that all dependencies and configurations are correct.

  5. Ray Version: Ensure that you are using a compatible version of Ray with vLLM. Some issues might arise from version incompatibilities.

  6. Debugging: Use Ray’s debugging tools to trace the issue. You can enable detailed logging and use Ray’s dashboard to monitor the status of your actors and resources.
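
As a rough illustration of points 2 and 3, here is a minimal sketch of how a multi-GPU engine actor can be scheduled. It mirrors what create_async_vllm_engines above already does; the tensor_parallel_size value, the bundle layout, and the 0.2 GPU fraction are assumptions for a hybrid-engine setup:

import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init(address="auto")  # connect to the running cluster; verify it first with `ray status`

tensor_parallel_size = 4

# Reserve one bundle per tensor-parallel rank so all of the engine's workers
# can be co-scheduled.
pg = placement_group([{"GPU": 1, "CPU": 1}] * tensor_parallel_size)
ray.get(pg.ready())

scheduling_strategy = PlacementGroupSchedulingStrategy(
    placement_group=pg,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# For a hybrid engine that shares GPUs with the training actors, tell vLLM
# which bundles to use and what GPU fraction each Ray worker should claim.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.2"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, range(tensor_parallel_size)))

# The engine actor is then created with
# AsyncLLMRayAsyncActor.options(num_gpus=0, scheduling_strategy=scheduling_strategy).remote(...)

With distributed_executor_backend="ray" and placement_group_capture_child_tasks=True, vLLM launches its tensor-parallel workers as Ray tasks inside this placement group.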

If you need more detailed guidance on any of these steps, please let me know.
