diff --git a/llm/fastdeploy_llm/Client.py b/llm/fastdeploy_llm/Client.py
new file mode 100644
index 0000000000..53965365b3
--- /dev/null
+++ b/llm/fastdeploy_llm/Client.py
@@ -0,0 +1,164 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import queue
+import json
+from functools import partial
+import os
+import time
+import numpy as np
+import subprocess
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import *
+
+
+class UserData:
+
+    def __init__(self):
+        self._completed_requests = queue.Queue()
+
+    def reset(self):
+        # Drop any results left over from a previous request.
+        self._completed_requests = queue.Queue()
+
+
+def callback(user_data, result, error):
+    if error:
+        user_data._completed_requests.put(error)
+    else:
+        user_data._completed_requests.put(result)
+
+
+class grpcClient:
+
+    def __init__(self,
+                 base_url: str,
+                 model_name: str,
+                 model_version: str = "1",
+                 timeout: int = 100,
+                 openai_port: int = None):
+        """
+        Args:
+            base_url (`str`): inference server grpc url
+            model_name (`str`): model name on the inference server
+            model_version (`str`): model version, default "1"
+            timeout (`int`): inference timeout in seconds
+            openai_port (`int`): if set, launch an OpenAI-compatible HTTP
+                service (api_client.py) on this port
+        """
+        self._model_name = model_name
+        self._model_version = model_version
+        self.timeout = timeout
+        self._server_url = base_url
+        self._client = grpcclient.InferenceServerClient(base_url,
+                                                        verbose=False)
+
+        error = self._verify_triton_state(self._client)
+        if error:
+            raise RuntimeError(
+                f"Could not communicate with Triton Server: {error}")
+
+        self.inputs = [
+            grpcclient.InferInput("IN", [1], np_to_triton_dtype(np.object_))
+        ]
+        self.outputs = [grpcclient.InferRequestedOutput("OUT")]
+        self.has_init = False
+        self.user_data = UserData()
+
+        if openai_port is not None:
+            pd_cmd = "python3 api_client.py --url {0} --port {1} --model {2}".format(
+                base_url, openai_port, model_name)
+            subprocess.Popen(pd_cmd,
+                             shell=True,
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.STDOUT,
+                             preexec_fn=os.setsid)
+            time.sleep(5)
+
+    def _verify_triton_state(self, triton_client):
+        if not triton_client.is_server_live():
+            return f"Triton server {self._server_url} is not live"
+        elif not triton_client.is_server_ready():
+            return f"Triton server {self._server_url} is not ready"
+        elif not triton_client.is_model_ready(self._model_name,
+                                              self._model_version):
+            return f"Model {self._model_name}:{self._model_version} is not ready"
+        return None
+
+    def generate(self,
+                 prompt: str,
+                 request_id: str = "0",
+                 top_p: float = 0.0,
+                 temperature: float = 1.0,
+                 max_dec_len: int = 1024,
+                 min_dec_len: int = 2,
+                 penalty_score: float = 1.0,
+                 frequency_score: float = 0.99,
+                 eos_token_id: int = 2,
+                 presence_score: float = 0.0,
+                 stream: bool = False):
+
+        req_dict = {
+            "text": prompt,
+            "topp": top_p,
+            "temperature": temperature,
+            "max_dec_len": max_dec_len,
+            "min_dec_len": min_dec_len,
+            "penalty_score": penalty_score,
+            "frequency_score": frequency_score,
+            "eos_token_id": eos_token_id,
+            "model_test": "test",
+            "presence_score": presence_score
+        }
+
+        try:
+            if not self.has_init:
+                self._client.start_stream(
+                    callback=partial(callback, self.user_data))
+                self.has_init = True
+            else:
+                self.user_data.reset()
+                self.inputs = [
+                    grpcclient.InferInput("IN", [1],
+                                          np_to_triton_dtype(np.object_))
+                ]
+                self.outputs = [grpcclient.InferRequestedOutput("OUT")]
+
+            in_data = np.array([json.dumps(req_dict)], dtype=np.object_)
+            self.inputs[0].set_data_from_numpy(in_data)
+
+            self._client.async_stream_infer(model_name=self._model_name,
+                                            inputs=self.inputs,
+                                            request_id=request_id,
+                                            outputs=self.outputs)
+            if stream:
+                completion = []
+            else:
+                completion = ""
+            while True:
+                data_item = self.user_data._completed_requests.get(
+                    timeout=self.timeout)
+                if isinstance(data_item, InferenceServerException):
+                    print('Exception:', 'status', data_item.status(), 'msg',
+                          data_item.message())
+                    break
+                else:
+                    results = data_item.as_numpy("OUT")[0]
+                    data = json.loads(results)
+                    if stream:
+                        completion.append(data["result"])
+                    else:
+                        completion += data["result"]
+                    if data.get("is_end", False):
+                        break
+            return completion
+        except Exception as e:
+            print(f"Client infer error: {e}")
+            raise e
diff --git a/llm/fastdeploy_llm/__init__.py b/llm/fastdeploy_llm/__init__.py
index 4fe63367d9..6bcc2b0e43 100644
--- a/llm/fastdeploy_llm/__init__.py
+++ b/llm/fastdeploy_llm/__init__.py
@@ -17,3 +17,4 @@
 from .task import Task, BatchTask
 from .config import Config
 from . import utils
+from . import Client
diff --git a/llm/fastdeploy_llm/api_client.py b/llm/fastdeploy_llm/api_client.py
new file mode 100644
index 0000000000..718a06b49a
--- /dev/null
+++ b/llm/fastdeploy_llm/api_client.py
@@ -0,0 +1,331 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+from logging.handlers import TimedRotatingFileHandler
+import numpy as np
+import argparse
+import os
+
+import time
+import random
+from http import HTTPStatus
+import tornado
+from tornado import web
+from tornado.concurrent import run_on_executor
+from concurrent.futures import ThreadPoolExecutor
+
+from utils.conversation import *
+from Client import *
+
+parse = argparse.ArgumentParser()
+parse.add_argument('--url', type=str, help='grpc server url')
+parse.add_argument('--port', type=int, help='openai http port', default=2001)
+parse.add_argument('--model', type=str, help='model name', default="model")
+
+
+def parse_parameters(parameters_config, name, default_value):
+    if name not in parameters_config:
+        return default_value
+    return parameters_config[name]
+
+
+def create_error_response(status_code, msg):
+    output = {
+        "status": status_code,
+        "errResponse": {
+            "message": msg,
+            "type": "invalid_request_error"
+        }
+    }
+    return output
+
+
+class ChatCompletionApiHandler(web.RequestHandler):
+    """
+    This handler mimics OpenAI's ChatCompletion API.
+
+    See https://platform.openai.com/docs/api-reference/chat/create
+    for the API specification.
+
+    NOTE: Currently we do not support the following features:
+        - n (only n=1 is supported)
+        - logit_bias
+        - logprobs
+        - stop (only stop token ids are supported)
+        - function_call (users should implement this themselves)
+        - function (users should implement this themselves)
+    """
+    executor = ThreadPoolExecutor(20)
+
+    def __init__(self, application, request, **kwargs):
+        web.RequestHandler.__init__(self, application, request, **kwargs)
+
+    def initialize(self, url, model_name):
+        self._client = grpcClient(base_url=url, model_name=model_name)
+
+    @tornado.gen.coroutine
+    def post(self):
+        """
+        POST METHOD
+        """
+        body = self.request.body
+        remote_ip = self.request.remote_ip
+        start_time = time.time()
+        if not body:
+            out_json = {"errorCode": 4000101}
+            result_str = json.dumps(out_json, ensure_ascii=False)
+            logging.warning(
+                f"request received from remote ip:{remote_ip}, body=None, "
+                f"result={result_str}, time_cost={time.time() - start_time:.5f}"
+            )
+            self.write(result_str)
+        else:
+            body = json.loads(body)
+            logging.info(
+                f"request received from remote ip:{remote_ip}, body={json.dumps(body, ensure_ascii=False)}"
+            )
+            err = self.valid_body(body)
+            if err is None:
+                data = yield self.run_req(body)
+                if data is None:
+                    out_json = create_error_response(4000102,
+                                                     "result is empty")
+                else:
+                    out_json = {"outputs": [data], "status": 0}
+                result_str = json.dumps(out_json, ensure_ascii=False)
+            else:
+                result_str = json.dumps(err, ensure_ascii=False)
+
+            logging.info(
+                f"request returned, result={result_str}, time_cost={time.time() - start_time:.5f}"
+            )
+            self.write(result_str)
+
+    def valid_body(self, request):
+        """
+        Check whether the request body is valid
+
+        Args:
+            request (dict):
+
+        Returns:
+            Union[dict, None]:
+            If the request body is valid, return None;
+            otherwise, return a dict with the error message
+        """
+        if request['model'] != self._client._model_name:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "requested model is not served by this service")
+        if 'n' in request and request['n'] != 1:
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "only n=1 is supported")
+        if 'logit_bias' in request and request['logit_bias'] is not None:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "logit_bias is not currently supported")
+        if 'functions' in request and request['functions'] is not None:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST, "functions is not currently supported")
+        if 'function_call' in request and request['function_call'] is not None:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "function_call is not currently supported")
+        return None
+
+    def gen_prompt(self, request):
+        conv = get_conv_template(request['model'])
+        if isinstance(request['messages'], str):
+            prompt = request['messages']
+        else:
+            for message in request['messages']:
+                msg_role = message["role"]
+                if msg_role == "system":
+                    conv.system_message = message["content"]
+                elif msg_role == "user":
+                    conv.append_message(conv.roles[0], message["content"])
+                elif msg_role == "assistant":
+                    conv.append_message(conv.roles[1], message["content"])
+                else:
+                    raise ValueError(f"Unknown role: {msg_role}")
+
+            # Add a blank message for the assistant.
+            conv.append_message(conv.roles[1], None)
+            prompt = conv.get_prompt()
+
+        return prompt
+
+    @run_on_executor
+    def run_req(self, body):
+        req_id = random.randint(0, 100000)
+        prompt = self.gen_prompt(body)
+        result = self._client.generate(
+            request_id=str(req_id),
+            prompt=prompt,
+            top_p=parse_parameters(body, 'top_p', 0.0),
+            temperature=parse_parameters(body, 'temperature', 1.0),
+            max_dec_len=parse_parameters(body, 'max_tokens', 1024),
+            frequency_score=parse_parameters(body, 'frequency_penalty', 0.99),
+            presence_score=parse_parameters(body, 'presence_penalty', 0.0),
+            stream=parse_parameters(body, 'stream', False))
+        return result
+
+
+class CompletionApiHandler(web.RequestHandler):
+    """
+    This handler mimics OpenAI's Completion API.
+
+    See https://platform.openai.com/docs/api-reference/completions/create
+    for the API specification.
+
+    NOTE: Currently we do not support the following features:
+        - best_of (only best_of=1 is supported)
+        - n (only n=1 is supported)
+        - echo (getting the logprobs of prompt tokens is not supported)
+        - suffix (the language models we currently support do not accept a
+          suffix)
+        - logit_bias
+        - logprobs
+        - stop (only stop token ids are supported)
+    """
+    executor = ThreadPoolExecutor(20)
+
+    def __init__(self, application, request, **kwargs):
+        web.RequestHandler.__init__(self, application, request, **kwargs)
+
+    def initialize(self, url, model_name):
+        self._client = grpcClient(base_url=url, model_name=model_name)
+
+    @tornado.gen.coroutine
+    def post(self):
+        """
+        POST METHOD
+        """
+        body = self.request.body
+        remote_ip = self.request.remote_ip
+        start_time = time.time()
+        if not body:
+            out_json = {"errorCode": 4000101}
+            result_str = json.dumps(out_json, ensure_ascii=False)
+            logging.warning(
+                f"request received from remote ip:{remote_ip}, body=None, "
+                f"result={result_str}, time_cost={time.time() - start_time:.5f}"
+            )
+
+            self.write(result_str)
+        else:
+            body = json.loads(body)
+            logging.info(
+                f"request received from remote ip:{remote_ip}, body={json.dumps(body, ensure_ascii=False)}"
+            )
+            err = self.valid_body(body)
+            if err is None:
+                data = yield self.run_req(body)
+                if data is None:
+                    out_json = create_error_response(4000102,
+                                                     "result is empty")
+                else:
+                    out_json = {"outputs": [data], "status": 0}
+                result_str = json.dumps(out_json, ensure_ascii=False)
+            else:
+                result_str = json.dumps(err, ensure_ascii=False)
+
+            logging.info(
+                f"request returned, result={result_str}, time_cost={time.time() - start_time:.5f}"
+            )
+            self.write(result_str)
+
+    def valid_body(self, request):
+        """
+        Check whether the request body is valid
+
+        Args:
+            request (dict):
+
+        Returns:
+            Union[dict, None]:
+            If the request body is valid, return None;
+            otherwise, return a dict with the error message
+        """
+        if request['model'] != self._client._model_name:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "requested model is not served by this service")
+        if 'n' in request and request['n'] != 1:
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "only n=1 is supported")
+        if 'best_of' in request and request['best_of'] != 1:
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "only best_of=1 is supported")
+        if 'echo' in request and request['echo']:
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "echo is not currently supported")
+        if 'suffix' in request and request['suffix'] is not None:
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "suffix is not currently supported")
+        if 'logit_bias' in request and request['logit_bias'] is not None:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "logit_bias is not currently supported")
+        if 'logprobs' in request and request['logprobs'] is not None:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST, "logprobs is not currently supported")
+
+        return None
+
+    @run_on_executor
+    def run_req(self, body):
+        req_id = random.randint(0, 100000)
+        result = self._client.generate(
+            request_id=str(req_id),
+            prompt=body['prompt'],
+            top_p=parse_parameters(body, 'top_p', 0.0),
+            temperature=parse_parameters(body, 'temperature', 1.0),
+            max_dec_len=parse_parameters(body, 'max_tokens', 1024),
+            frequency_score=parse_parameters(body, 'frequency_penalty', 0.99),
+            presence_score=parse_parameters(body, 'presence_penalty', 0.0),
+            stream=parse_parameters(body, 'stream', False))
+        return result
+
+
+if __name__ == '__main__':
+    args = parse.parse_args()
+    port = args.port
+    app = web.Application([("/v1/completions", CompletionApiHandler,
+                            dict(url=args.url, model_name=args.model)),
+                           ("/v1/chat/completions", ChatCompletionApiHandler,
+                            dict(url=args.url, model_name=args.model))])
+
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+
+    formatter = tornado.log.LogFormatter(
+        fmt=
+        '%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S')
+
+    # Make sure the log directory exists before the file handler opens it.
+    os.makedirs('log', exist_ok=True)
+    file_handler = TimedRotatingFileHandler(filename='log/server.log',
+                                            when='D',
+                                            interval=3,
+                                            backupCount=90,
+                                            encoding='utf-8',
+                                            delay=False)
+    file_handler.setFormatter(formatter)
+
+    logger.addHandler(file_handler)
+    app.listen(port)
+    print("Server started")
+    logging.info(f"Server started at port:{port}")
+    tornado.ioloop.IOLoop.current().start()
\ No newline at end of file
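For clarity, below is a minimal sketch of calling the two OpenAI-compatible endpoints over plain HTTP, without the openai SDK. It is not part of the patch: it assumes api_client.py is already serving on the default port 2001, that the served model is named "llama-ptuning" as in the tests further down, and that the third-party requests package is installed. Note that on success the handlers wrap the generation as {"outputs": [...], "status": 0} rather than the exact OpenAI response schema.

import json

import requests  # assumed to be installed; any HTTP client works

BASE = "http://0.0.0.0:2001/v1"  # api_client.py's default --port is 2001
MODEL = "llama-ptuning"          # must match the model name the handlers check

# Completion endpoint: CompletionApiHandler reads model, prompt, top_p,
# temperature, max_tokens, frequency_penalty, presence_penalty and stream.
resp = requests.post(BASE + "/completions",
                     data=json.dumps({
                         "model": MODEL,
                         "prompt": "A robot may not injure a human being",
                         "max_tokens": 64
                     }))
print(resp.json())  # e.g. {"outputs": ["..."], "status": 0}

# Chat endpoint: messages may be a plain string or a list of {role, content} dicts.
resp = requests.post(BASE + "/chat/completions",
                     data=json.dumps({
                         "model": MODEL,
                         "messages": [{"role": "user", "content": "Hello!"}]
                     }))
print(resp.json())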
diff --git a/llm/test/README.md b/llm/test/README.md
new file mode 100644
index 0000000000..7f10ac70ba
--- /dev/null
+++ b/llm/test/README.md
@@ -0,0 +1,9 @@
+## Client
+
+- Supports the two OpenAI-style APIs: ChatCompletion and Completion
+
+### Usage
+- A sample demo of the FastDeploy client is provided; see test_client.py under the test directory
+- Example code for the two OpenAI-style endpoints is in test_openai.py under the test directory
+
+
diff --git a/llm/test/test_client.py b/llm/test/test_client.py
new file mode 100644
index 0000000000..0e7c22bb19
--- /dev/null
+++ b/llm/test/test_client.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from fastdeploy_llm.Client import grpcClient
+
+client = grpcClient(base_url="0.0.0.0:8812",
+                    model_name="llama-ptuning",
+                    timeout=100)
+result = client.generate("Hello, how are you")
+print(result)
\ No newline at end of file
diff --git a/llm/test/test_openai.py b/llm/test/test_openai.py
new file mode 100644
index 0000000000..b0165ee34a
--- /dev/null
+++ b/llm/test/test_openai.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import openai +from fastdeploy_llm.Client import grpcClient + +model = "llama-ptuning" +port = 2001 +url = "0.0.0.0:8812" + +client = grpcClient(base_url=url, model_name=model, openai_port=port) + +# Modify OpenAI's API key and API base. +openai.api_key = "EMPTY" +openai.api_base = "http://0.0.0.0:" + str(port) + "/v1" + +# Completion API +completion = openai.Completion.create( + model=model, prompt="A robot may not injure a human being") + +print("Completion results:") +print(completion) + +# ChatCompletion API +chat_completion = openai.ChatCompletion.create( + model=model, + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Who won the world series in 2020?" + }, { + "role": + "assistant", + "content": + "The Los Angeles Dodgers won the World Series in 2020." + }, { + "role": "user", + "content": "Where was it played?" + }]) +print("Chat completion results:") +print(chat_completion)
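The tests above only exercise the non-streaming path. As a supplementary illustration, here is a minimal sketch of generate()'s stream mode, assuming the same Triton endpoint and model name as test_client.py: with stream=True the call still blocks until generation finishes, but it returns the per-step partial results as a list instead of one concatenated string.

from fastdeploy_llm.Client import grpcClient

client = grpcClient(base_url="0.0.0.0:8812",
                    model_name="llama-ptuning",
                    timeout=100)

# stream=True collects each step's data["result"] into a list.
pieces = client.generate("Hello, how are you",
                         request_id="1",
                         max_dec_len=256,
                         stream=True)
print("".join(pieces))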