
[Bug]: Engine is gracefully shutting down #11873

Open
Bryce1010 opened this issue Jan 9, 2025 · 7 comments
Labels
bug Something isn't working

Comments

@Bryce1010
Contributor

Bryce1010 commented Jan 9, 2025

Your current environment

The output of `python collect_env.py`
Your output of `python collect_env.py` here

Model Input Dumps

No response

🐛 Describe the bug

My goal is to create an asynchronous engine and keep it running. I want to continuously add requests to the engine and retrieve their outputs. At the same time, I need the ability to abort certain requests during generation if they no longer require additional tokens.

Case 1: plain generate (the exception only shows up after the final call completes)

import asyncio

from vllm.engine.async_llm_engine import (
    AsyncLLMEngine,
    AsyncEngineArgs,
    SamplingParams,
)

async def generate_text(engine, prompt: str, request_id: str):
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=16,
    )
    final_output = None
    async for output in engine.generate(prompt, sampling_params, request_id=request_id):
        final_output = output
    return final_output

async def main():

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m") 
    )

    print("===> Start first async inference")
    output1 = await generate_text(engine, "Test first call.", request_id="1")
    print(f"[First result] {output1}")

    print("===> Call async inference again (expect graceful shutdown error)")
    output2 = await generate_text(engine, "Test second call.", request_id="2")
    print(f"[Second result] {output2}")

    print("===> Start third async inference")
    output3 = await generate_text(engine, "Test third call.", request_id="3")
    print(f"[Third result] {output3}")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except Exception as e:
        print("Catch exception:", e)
$ python examples/offline_inference_asnyc.py
DEBUG 01-09 11:33:06 __init__.py:26] No plugins for group vllm.platform_plugins found.
INFO 01-09 11:33:07 __init__.py:179] Automatically detected platform cuda.
DEBUG 01-09 11:33:07 __init__.py:26] No plugins for group vllm.general_plugins found.
INFO 01-09 11:33:14 config.py:517] This model supports multiple tasks: {'generate', 'reward', 'embed', 'score', 'classify'}. Defaulting to 'generate'.
INFO 01-09 11:33:14 llm_engine.py:234] Initializing an LLM engine (v0.6.6.post2.dev91+gf70bbb36.d20250106) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"candidate_compile_sizes":[],"compile_sizes":[],"capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False, 
INFO 01-09 11:33:15 selector.py:120] Using Flash Attention backend.
DEBUG 01-09 11:33:15 parallel_state.py:952] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://11.166.9.235:34745 backend=nccl
INFO 01-09 11:33:15 model_runner.py:1144] Starting to load model facebook/opt-125m...
DEBUG 01-09 11:33:15 decorators.py:105] Inferred dynamic dimensions for forward method of <class 'vllm.model_executor.models.opt.OPTModel'>: ['input_ids', 'positions', 'intermediate_tensors', 'inputs_embeds']
DEBUG 01-09 11:33:15 config.py:3328] enabled custom ops: Counter()
DEBUG 01-09 11:33:15 config.py:3330] disabled custom ops: Counter()
INFO 01-09 11:33:16 weight_utils.py:251] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
/root/PublicProject/vllm/vllm/model_executor/model_loader/weight_utils.py:450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state = torch.load(bin_file, map_location="cpu")
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  8.15it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  8.13it/s]

INFO 01-09 11:33:16 model_runner.py:1149] Loading model weights took 0.2389 GB
INFO 01-09 11:33:17 worker.py:241] Memory profiling takes 0.34 seconds
INFO 01-09 11:33:17 worker.py:241] the current vLLM instance can use total_gpu_memory (44.42GiB) x gpu_memory_utilization (0.90) = 39.98GiB
INFO 01-09 11:33:17 worker.py:241] model weights take 0.24GiB; non_torch_memory takes 0.08GiB; PyTorch activation peak memory takes 0.47GiB; the rest of the memory reserved for KV Cache is 39.19GiB.
INFO 01-09 11:33:17 gpu_executor.py:76] # GPU blocks: 71335, # CPU blocks: 7281
INFO 01-09 11:33:17 gpu_executor.py:80] Maximum concurrency for 2048 tokens per request: 557.30x
INFO 01-09 11:33:18 model_runner.py:1466] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:07<00:00,  4.88it/s]
INFO 01-09 11:33:25 model_runner.py:1591] Graph capturing finished in 7 secs, took 0.12 GiB
INFO 01-09 11:33:25 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 9.14 seconds
===> Start first async inference
INFO 01-09 11:33:26 async_llm_engine.py:969] self.is_running: False, self.start_engine_loop: True
INFO 01-09 11:33:26 async_llm_engine.py:211] Added request 1.
DEBUG 01-09 11:33:26 async_llm_engine.py:859] Waiting for new requests...
DEBUG 01-09 11:33:26 async_llm_engine.py:878] Got new requests!
INFO 01-09 11:33:26 async_llm_engine.py:179] Finished request 1.
[First result] RequestOutput(request_id=1, prompt='Test first call.', prompt_token_ids=[2, 34603, 78, 486, 4], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='                ', token_ids=(1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437, 1437), cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1736393606.5197597, last_token_time=1736393606.5460815, first_scheduled_time=1736393606.5213075, first_token_time=1736393606.5267124, time_in_queue=0.0015478134155273438, finished_time=1736393606.546158, scheduler_time=0.0008088539470918477, model_forward_time=None, model_execute_time=None), lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})
===> Call async inference again (expect graceful shutdown error)
INFO 01-09 11:33:26 async_llm_engine.py:969] self.is_running: True, self.start_engine_loop: True
INFO 01-09 11:33:26 async_llm_engine.py:211] Added request 2.
DEBUG 01-09 11:33:26 async_llm_engine.py:859] Waiting for new requests...
DEBUG 01-09 11:33:26 async_llm_engine.py:878] Got new requests!
INFO 01-09 11:33:26 async_llm_engine.py:179] Finished request 2.
[Second result] RequestOutput(request_id=2, prompt='Test second call.', prompt_token_ids=[2, 34603, 200, 486, 4], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text="\nI'm not sure if I should be worried about the second call or not", token_ids=(50118, 100, 437, 45, 686, 114, 38, 197, 28, 3915, 59, 5, 200, 486, 50, 45), cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1736393606.547024, last_token_time=1736393606.569789, first_scheduled_time=1736393606.547554, first_token_time=1736393606.5508611, time_in_queue=0.0005300045013427734, finished_time=1736393606.5698507, scheduler_time=0.0006494068657048047, model_forward_time=None, model_execute_time=None), lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})
===> Start third async inference
INFO 01-09 11:33:26 async_llm_engine.py:969] self.is_running: True, self.start_engine_loop: True
INFO 01-09 11:33:26 async_llm_engine.py:211] Added request 3.
DEBUG 01-09 11:33:26 async_llm_engine.py:859] Waiting for new requests...
DEBUG 01-09 11:33:26 async_llm_engine.py:878] Got new requests!
INFO 01-09 11:33:26 async_llm_engine.py:179] Finished request 3.
[Third result] RequestOutput(request_id=3, prompt='Test third call.', prompt_token_ids=[2, 34603, 371, 486, 4], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text="\nI'm not sure if I'm the only one who thinks this is a", token_ids=(50118, 100, 437, 45, 686, 114, 38, 437, 5, 129, 65, 54, 4265, 42, 16, 10), cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1736393606.5706193, last_token_time=1736393606.5931807, first_scheduled_time=1736393606.571072, first_token_time=1736393606.5743005, time_in_queue=0.0004527568817138672, finished_time=1736393606.5932353, scheduler_time=0.0005872009787708521, model_forward_time=None, model_execute_time=None), lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})
INFO 01-09 11:33:26 async_llm_engine.py:65] Engine is gracefully shutting down
INFO 01-09 11:33:26 async_llm_engine.py:65] Traceback (most recent call last):
INFO 01-09 11:33:26 async_llm_engine.py:65]   File "/root/PublicProject/vllm/vllm/engine/async_llm_engine.py", line 58, in _log_task_completion
INFO 01-09 11:33:26 async_llm_engine.py:65]     return_value = task.result()
INFO 01-09 11:33:26 async_llm_engine.py:65]                    ^^^^^^^^^^^^^
INFO 01-09 11:33:26 async_llm_engine.py:65]   File "/root/PublicProject/vllm/vllm/engine/async_llm_engine.py", line 889, in run_engine_loop
INFO 01-09 11:33:26 async_llm_engine.py:65]     done, _ = await asyncio.wait(
INFO 01-09 11:33:26 async_llm_engine.py:65]               ^^^^^^^^^^^^^^^^^^^
INFO 01-09 11:33:26 async_llm_engine.py:65]   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/tasks.py", line 464, in wait
INFO 01-09 11:33:26 async_llm_engine.py:65]     return await _wait(fs, timeout, return_when, loop)
INFO 01-09 11:33:26 async_llm_engine.py:65]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
INFO 01-09 11:33:26 async_llm_engine.py:65]   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/tasks.py", line 550, in _wait
INFO 01-09 11:33:26 async_llm_engine.py:65]     await waiter
INFO 01-09 11:33:26 async_llm_engine.py:65] asyncio.exceptions.CancelledError
[rank0]:[W109 11:33:26.422312123 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())

Case 2: generate with abort (the exception is raised on the very first call)

import asyncio

from vllm.engine.async_llm_engine import (
    AsyncLLMEngine,
    AsyncEngineArgs,
    SamplingParams,
)

async def generate_text(engine, prompt: str, request_id: str):
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=16,
    )
    final_output = None
    async for output in engine.generate(prompt, sampling_params, request_id=request_id):
        final_output = output
        print(f"request_id: {request_id}, output: {output}, abort")
        await engine.abort(request_id=request_id)
    return final_output

async def main():

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m") 
    )

    print("===> Start first async inference")
    output1 = await generate_text(engine, "Test first call.", request_id="1")
    print(f"[First result] {output1}")

    print("===> Call async inference again (expect graceful shutdown error)")
    output2 = await generate_text(engine, "Test second call.", request_id="2")
    print(f"[Second result] {output2}")

    print("===> Start third async inference")
    output3 = await generate_text(engine, "Test third call.", request_id="3")
    print(f"[Third result] {output3}")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except Exception as e:
        print("Catch exception:", e)
$ python examples/offline_inference_asnyc.py
DEBUG 01-09 11:32:33 __init__.py:26] No plugins for group vllm.platform_plugins found.
INFO 01-09 11:32:33 __init__.py:179] Automatically detected platform cuda.
DEBUG 01-09 11:32:33 __init__.py:26] No plugins for group vllm.general_plugins found.
INFO 01-09 11:32:40 config.py:517] This model supports multiple tasks: {'embed', 'score', 'reward', 'generate', 'classify'}. Defaulting to 'generate'.
INFO 01-09 11:32:40 llm_engine.py:234] Initializing an LLM engine (v0.6.6.post2.dev91+gf70bbb36.d20250106) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"candidate_compile_sizes":[],"compile_sizes":[],"capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False, 
INFO 01-09 11:32:41 selector.py:120] Using Flash Attention backend.
DEBUG 01-09 11:32:41 parallel_state.py:952] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://11.166.9.235:46077 backend=nccl
INFO 01-09 11:32:41 model_runner.py:1144] Starting to load model facebook/opt-125m...
DEBUG 01-09 11:32:41 decorators.py:105] Inferred dynamic dimensions for forward method of <class 'vllm.model_executor.models.opt.OPTModel'>: ['input_ids', 'positions', 'intermediate_tensors', 'inputs_embeds']
DEBUG 01-09 11:32:41 config.py:3328] enabled custom ops: Counter()
DEBUG 01-09 11:32:41 config.py:3330] disabled custom ops: Counter()
INFO 01-09 11:32:42 weight_utils.py:251] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
/root/PublicProject/vllm/vllm/model_executor/model_loader/weight_utils.py:450: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state = torch.load(bin_file, map_location="cpu")
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  7.94it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  7.93it/s]

INFO 01-09 11:32:42 model_runner.py:1149] Loading model weights took 0.2389 GB
INFO 01-09 11:32:43 worker.py:241] Memory profiling takes 0.34 seconds
INFO 01-09 11:32:43 worker.py:241] the current vLLM instance can use total_gpu_memory (44.42GiB) x gpu_memory_utilization (0.90) = 39.98GiB
INFO 01-09 11:32:43 worker.py:241] model weights take 0.24GiB; non_torch_memory takes 0.08GiB; PyTorch activation peak memory takes 0.47GiB; the rest of the memory reserved for KV Cache is 39.19GiB.
INFO 01-09 11:32:43 gpu_executor.py:76] # GPU blocks: 71335, # CPU blocks: 7281
INFO 01-09 11:32:43 gpu_executor.py:80] Maximum concurrency for 2048 tokens per request: 557.30x
INFO 01-09 11:32:44 model_runner.py:1466] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:07<00:00,  4.64it/s]
INFO 01-09 11:32:52 model_runner.py:1591] Graph capturing finished in 8 secs, took 0.12 GiB
INFO 01-09 11:32:52 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 9.54 seconds
===> Start first async inference
INFO 01-09 11:32:52 async_llm_engine.py:969] self.is_running: False, self.start_engine_loop: True
INFO 01-09 11:32:52 async_llm_engine.py:211] Added request 1.
DEBUG 01-09 11:32:52 async_llm_engine.py:859] Waiting for new requests...
DEBUG 01-09 11:32:52 async_llm_engine.py:878] Got new requests!
INFO 01-09 11:32:52 async_llm_engine.py:223] Aborted request 1.
INFO 01-09 11:32:52 async_llm_engine.py:223] Aborted request 1.
INFO 01-09 11:32:52 async_llm_engine.py:65] Engine is gracefully shutting down
INFO 01-09 11:32:52 async_llm_engine.py:65] Traceback (most recent call last):
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/runners.py", line 194, in run
INFO 01-09 11:32:52 async_llm_engine.py:65]     return runner.run(main)
INFO 01-09 11:32:52 async_llm_engine.py:65]            ^^^^^^^^^^^^^^^^
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/runners.py", line 118, in run
INFO 01-09 11:32:52 async_llm_engine.py:65]     return self._loop.run_until_complete(task)
INFO 01-09 11:32:52 async_llm_engine.py:65]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/base_events.py", line 686, in run_until_complete
INFO 01-09 11:32:52 async_llm_engine.py:65]     return future.result()
INFO 01-09 11:32:52 async_llm_engine.py:65]            ^^^^^^^^^^^^^^^
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/PublicProject/vllm/examples/offline_inference_asnyc.py", line 28, in main
INFO 01-09 11:32:52 async_llm_engine.py:65]     output1 = await generate_text(engine, "Test first call.", request_id="1")
INFO 01-09 11:32:52 async_llm_engine.py:65]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/PublicProject/vllm/examples/offline_inference_asnyc.py", line 15, in generate_text
INFO 01-09 11:32:52 async_llm_engine.py:65]     async for output in engine.generate(prompt, sampling_params, request_id=request_id):
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/PublicProject/vllm/vllm/engine/async_llm_engine.py", line 1076, in generate
INFO 01-09 11:32:52 async_llm_engine.py:65]     async for output in await self.add_request(
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/PublicProject/vllm/vllm/engine/async_llm_engine.py", line 116, in generator
INFO 01-09 11:32:52 async_llm_engine.py:65]     raise result
INFO 01-09 11:32:52 async_llm_engine.py:65] asyncio.exceptions.CancelledError
INFO 01-09 11:32:52 async_llm_engine.py:65] 
INFO 01-09 11:32:52 async_llm_engine.py:65] During handling of the above exception, another exception occurred:
INFO 01-09 11:32:52 async_llm_engine.py:65] 
INFO 01-09 11:32:52 async_llm_engine.py:65] Traceback (most recent call last):
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/PublicProject/vllm/vllm/engine/async_llm_engine.py", line 58, in _log_task_completion
INFO 01-09 11:32:52 async_llm_engine.py:65]     return_value = task.result()
INFO 01-09 11:32:52 async_llm_engine.py:65]                    ^^^^^^^^^^^^^
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/PublicProject/vllm/vllm/engine/async_llm_engine.py", line 889, in run_engine_loop
INFO 01-09 11:32:52 async_llm_engine.py:65]     done, _ = await asyncio.wait(
INFO 01-09 11:32:52 async_llm_engine.py:65]               ^^^^^^^^^^^^^^^^^^^
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/tasks.py", line 464, in wait
INFO 01-09 11:32:52 async_llm_engine.py:65]     return await _wait(fs, timeout, return_when, loop)
INFO 01-09 11:32:52 async_llm_engine.py:65]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
INFO 01-09 11:32:52 async_llm_engine.py:65]   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/tasks.py", line 550, in _wait
INFO 01-09 11:32:52 async_llm_engine.py:65]     await waiter
INFO 01-09 11:32:52 async_llm_engine.py:65] asyncio.exceptions.CancelledError
[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/PublicProject/vllm/examples/offline_inference_asnyc.py", line 41, in <module>
[rank0]:     asyncio.run(main())
[rank0]:   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/runners.py", line 194, in run
[rank0]:     return runner.run(main)
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/runners.py", line 118, in run
[rank0]:     return self._loop.run_until_complete(task)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniforge3/envs/vllm-dev/lib/python3.12/asyncio/base_events.py", line 686, in run_until_complete
[rank0]:     return future.result()
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/root/PublicProject/vllm/examples/offline_inference_asnyc.py", line 28, in main
[rank0]:     output1 = await generate_text(engine, "Test first call.", request_id="1")
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/PublicProject/vllm/examples/offline_inference_asnyc.py", line 15, in generate_text
[rank0]:     async for output in engine.generate(prompt, sampling_params, request_id=request_id):
[rank0]:   File "/root/PublicProject/vllm/vllm/engine/async_llm_engine.py", line 1076, in generate
[rank0]:     async for output in await self.add_request(
[rank0]:   File "/root/PublicProject/vllm/vllm/engine/async_llm_engine.py", line 116, in generator
[rank0]:     raise result
[rank0]: asyncio.exceptions.CancelledError
[rank0]:[W109 11:32:53.735763655 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
Bryce1010 added the bug (Something isn't working) label on Jan 9, 2025
@Hyunnicolou

The issue occurs because the AsyncLLMEngine is not properly managing its lifecycle across multiple calls. The engine's background loop (run_engine_loop) is mistakenly stopped after the first request, which causes an asyncio.exceptions.CancelledError when the second request is sent.
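One way to check this between calls is to inspect the loop state before adding the next request. A minimal sketch, assuming the is_running/errored properties and start_background_loop() of AsyncLLMEngine behave as in vllm/engine/async_llm_engine.py (please verify these names against your version):

async def submit(engine, prompt, request_id, sampling_params):
    # Sketch: if the background loop was stopped (but did not error out),
    # bring it back up before adding another request.
    if not engine.is_running and not engine.errored:
        engine.start_background_loop()
    final = None
    async for out in engine.generate(prompt, sampling_params, request_id=request_id):
        final = out
    return final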

@Hyunnicolou

The AsyncLLMEngine is designed for long-running services. Once started, it should keep processing requests until manually shut down. If you need to handle multiple requests in one session, the engine should stay active instead of being repeatedly stopped and restarted.
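For example, one pattern that keeps a single engine alive across many requests is to run everything on one long-lived event loop and submit requests as concurrent tasks. A rough sketch reusing the generate_text helper from the report above (not an official vLLM recipe):

import asyncio

async def serve_batches(engine):
    # One engine instance, many requests: each request becomes a task on the
    # same event loop, so the engine's background loop is never torn down
    # between calls. In a real service the prompts would come from a queue
    # or an HTTP handler instead of this hard-coded list.
    tasks = [
        asyncio.create_task(generate_text(engine, f"Test call {i}.", request_id=str(i)))
        for i in range(3)
    ]
    for result in await asyncio.gather(*tasks):
        print(result)
    # Only shut the engine down when the whole service exits.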

@Bryce1010
Contributor Author

Bryce1010 commented Jan 9, 2025

@Hyunnicolou, thank you for your answer. You described exactly the issue I'm encountering. However, do you think this is a bug, or is it simply a misuse of the feature?

(Updated) By the way, I also tested the abort function, but an exception was thrown on the first call.
CC: @DarkLight1337 @youkaichao @WoosukKwon

@lyblsgo

lyblsgo commented Jan 9, 2025

The AsyncLLMEngine is designed for long-running services. Once started, it should keep processing requests until manually shut down. If you need to handle multiple requests in one session, the engine should stay active instead of being repeatedly stopped and restarted.

I encountered the same problem. How do I keep the engine active?

@Hyunnicolou

If you want to restart the engine for each request, make sure to stop and restart the engine properly between calls:
async def main():
    for i in range(2):
        print(f"===> Start async inference {i+1}")
        engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
        try:
            output = await generate_text(engine, f"Test call {i+1}.", request_id=str(i+1))
            print(f"[Result {i+1}] {output}")
        finally:
            engine.shutdown_background_loop()

@Bryce1010
Contributor Author

@Hyunnicolou
My goal is to create an asynchronous engine and keep it running. I want to continuously add requests to the engine and retrieve their outputs. At the same time, I need the ability to abort certain requests during generation if they no longer require additional tokens.
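Concretely, the usage pattern I am after looks roughly like this: consume the generator in one task and abort the request from another, treating the CancelledError seen by the consumer as the normal "request aborted" signal. This is only a sketch of the intended usage, not something I have confirmed is the supported way to abort:

import asyncio
from vllm import SamplingParams

async def consume(engine, prompt, request_id):
    params = SamplingParams(temperature=0.0, max_tokens=256)
    final = None
    try:
        async for out in engine.generate(prompt, params, request_id=request_id):
            final = out
    except asyncio.CancelledError:
        # Raised when the request is aborted mid-generation; keep whatever
        # partial output was already received instead of propagating it.
        pass
    return final

async def run_with_abort(engine):
    task = asyncio.create_task(consume(engine, "Test abort call.", request_id="abort-1"))
    await asyncio.sleep(0.05)       # let generation start
    await engine.abort("abort-1")   # stop generating tokens for this request
    print(await task)               # partial RequestOutput (or None)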

@fxmarty-amd

fxmarty-amd commented Jan 15, 2025

This looks to be a duplicate of #11603.
