vllm_falcon_7b.py
"""
Example of a vLLM prompt completion service based on the Falcon-7b LLM
to get deployed on Ray Serve.
Adapted from the AnyScale team's repository
https://github.com/ray-project/ray/blob\
/cc983fc3e64c1ba215e981a43dd0119c03c74ff1/doc/source/serve/doc_code/vllm_example.py
"""
import json
from typing import AsyncGenerator
from fastapi import BackgroundTasks
from starlette.requests import Request
from starlette.responses import StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from ray import serve


@serve.deployment(ray_actor_options={"num_gpus": 1})
class VLLMPredictDeployment:
    def __init__(self, **kwargs):
        """
        Construct a vLLM deployment.

        Refer to https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
        for the full list of arguments.

        Args:
            model: name or path of the Hugging Face model to use.
            download_dir: directory to download and load the weights;
                defaults to the default Hugging Face cache dir.
            use_np_weights: save a numpy copy of the model weights for
                faster loading. This can increase the disk usage by up to 2x.
            use_dummy_weights: use dummy values for model weights.
            dtype: data type for model weights and activations.
                The "auto" option will use FP16 precision
                for FP32 and FP16 models, and BF16 precision
                for BF16 models.
            seed: random seed.
            worker_use_ray: use Ray for distributed serving; set
                automatically when using more than 1 GPU.
            pipeline_parallel_size: number of pipeline stages.
            tensor_parallel_size: number of tensor parallel replicas.
            block_size: token block size.
            swap_space: CPU swap space size (GiB) per GPU.
            gpu_memory_utilization: the fraction of GPU memory to be used for
                the model executor.
            max_num_batched_tokens: maximum number of batched tokens per iteration.
            max_num_seqs: maximum number of sequences per iteration.
            disable_log_stats: disable logging statistics.
            engine_use_ray: use Ray to start the LLM engine in a separate
                process from the server process.
            disable_log_requests: disable logging requests.
        """
        args = AsyncEngineArgs(**kwargs)
        self.engine = AsyncLLMEngine.from_engine_args(args)
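        # For example, the ``deployment`` binding at the bottom of this file
        # forwards model="tiiuae/falcon-7b-instruct", dtype="bfloat16" and
        # trust_remote_code=True; any other AsyncEngineArgs field listed
        # above can be passed through ``VLLMPredictDeployment.bind(...)`` in
        # the same way.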

    async def stream_results(self, results_generator) -> AsyncGenerator[bytes, None]:
        num_returned = 0
        async for request_output in results_generator:
            text_outputs = [output.text for output in request_output.outputs]
            assert len(text_outputs) == 1
            text_output = text_outputs[0][num_returned:]
            ret = {"text": text_output}
            yield (json.dumps(ret) + "\n").encode("utf-8")
            num_returned += len(text_output)

    async def may_abort_request(self, request_id) -> None:
        await self.engine.abort(request_id)

    async def __call__(self, request: Request) -> Response:
        """Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - stream: whether to stream the results or not.
        - other fields: the sampling parameters (see `SamplingParams` for details).
        """
        request_dict = await request.json()
        prompt = request_dict.pop("prompt")
        stream = request_dict.pop("stream", False)
        sampling_params = SamplingParams(**request_dict)
        request_id = random_uuid()
        results_generator = self.engine.generate(prompt, sampling_params, request_id)
        if stream:
            background_tasks = BackgroundTasks()
            # Use background_tasks to abort the request
            # if the client disconnects.
            background_tasks.add_task(self.may_abort_request, request_id)
            return StreamingResponse(
                self.stream_results(results_generator), background=background_tasks
            )
        # Non-streaming case
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output
        assert final_output is not None
        prompt = final_output.prompt
        text_outputs = [prompt + output.text for output in final_output.outputs]
        ret = {"text": text_outputs}
        return Response(content=json.dumps(ret))


def send_sample_request():
    """
    An example of how to send a prompt completion request to the Ray head URL.
    The completion is printed to standard output.

    :return: None
    """
    import requests
    import json

    prompt = "How do I cook fried rice?"
    sample_input = {"prompt": prompt,
                    "stream": False,
                    "max_tokens": 128,
                    "temperature": 0,
                    }
    # Replace the hostname with the Ray head's hostname or IP address
    ray_url = "http://localhost:8000/"
    output = requests.post(ray_url, json=sample_input)
    for line in output.iter_lines():
        print(json.loads(line.decode("utf-8"))["text"][0])
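

# The sketch below is not part of the original example; it shows one way a
# client could consume the streaming endpoint (``stream=True``), assuming the
# same Ray Serve URL and the newline-delimited JSON chunks produced by
# ``stream_results`` above.
def send_sample_streaming_request():
    """
    Example client for the streaming case: each received line is a JSON
    object of the form {"text": "<newly generated text>"}.

    :return: None
    """
    import requests
    import json

    sample_input = {"prompt": "How do I cook fried rice?",
                    "stream": True,
                    "max_tokens": 128,
                    "temperature": 0,
                    }
    ray_url = "http://localhost:8000/"
    with requests.post(ray_url, json=sample_input, stream=True) as response:
        for line in response.iter_lines():
            if line:
                # Print each text delta as it arrives.
                print(json.loads(line.decode("utf-8"))["text"], end="", flush=True)
    print()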


# Deployment definition for Ray Serve
deployment = VLLMPredictDeployment.bind(model="tiiuae/falcon-7b-instruct",
                                        dtype="bfloat16",
                                        trust_remote_code=True,
                                        )
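
# One way to launch this deployment locally (assuming this file is importable
# as ``vllm_falcon_7b`` and the Ray Serve CLI is available):
#
#     serve run vllm_falcon_7b:deployment
#
# Once the application is up, ``send_sample_request()`` (or the streaming
# variant above) can be used to query it at http://localhost:8000/.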