Commit

add client
luotingdan committed Oct 16, 2023
1 parent 17a8ab9 commit 57c3715
Showing 9 changed files with 211 additions and 1,085 deletions.
1,044 changes: 0 additions & 1,044 deletions llm/client/conversation.py

This file was deleted.

7 changes: 0 additions & 7 deletions llm/client/test_client.py

This file was deleted.

47 changes: 38 additions & 9 deletions llm/client/Client.py → llm/fastdeploy_llm/Client.py
@@ -1,12 +1,30 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import queue
import json
import sys
from functools import partial

import os
import time
import numpy as np
import subprocess
import tritonclient.grpc as grpcclient
from tritonclient.utils import *

import api_client

class UserData:
def __init__(self):
self._completed_requests = queue.Queue()
@@ -26,13 +44,15 @@ def __init__(
model_name: str,
model_version: str = "1",
timeout: int = 100,
openai_port: int = None
):
"""
Args:
base_url (`str`): inference server grpc url
model_name (`str`): name of the model served by the inference server
model_version (`str`): default "1"
timeout (`int`): inference timeout in seconds
openai_port (`int`): optional; when set, an OpenAI-compatible HTTP server (api_client.py) is started on this port
"""
self._model_name = model_name
self._model_version = model_version
@@ -48,7 +68,13 @@ def __init__(
self.inputs = [grpcclient.InferInput("IN", [1], np_to_triton_dtype(np.object_))]
self.outputs = [grpcclient.InferRequestedOutput("OUT")]
self.has_init = False
self.user_data = UserData()

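# When openai_port is set, launch api_client.py (the OpenAI-compatible HTTP server) as a background process.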
if openai_port is not None:
pd_cmd = "python3 api_client.py --url {0} --port {1} --model {2}".format(base_url, openai_port, model_name)
subprocess.Popen(pd_cmd, shell=True, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, preexec_fn=os.setsid)
time.sleep(5)

def _verify_triton_state(self, triton_client):
if not triton_client.is_server_live():
@@ -71,10 +97,10 @@ def generate(
penalty_score: float = 1.0,
frequency_score: float = 0.99,
eos_token_id: int = 2,
presence_score: float = 0.0,
stream: bool = False
):

#text = data_process(prompt)
req_dict = {
"text": prompt,
"topp": top_p,
@@ -89,7 +115,6 @@ def generate(
}

try:
# Establish stream
if not self.has_init:
self._client.start_stream(callback=partial(callback, self.user_data))
self.has_init = True
@@ -105,20 +130,24 @@ def generate(
inputs=self.inputs,
request_id=request_id,
outputs=self.outputs)
# Retrieve results...
completion = ""
if stream:
completion = []
else:
completion = ""
while True:
data_item = self.user_data._completed_requests.get(timeout=self.timeout)
if type(data_item) == InferenceServerException:
print('Exception:', 'status', data_item.status(), 'msg', data_item.message())
else:
results = data_item.as_numpy("OUT")[0]
data = json.loads(results)
completion += data["result"]
if stream:
completion.append(data["result"])
else:
completion += data["result"]
if data.get("is_end", False):
break
return completion
except Exception as e:
print(f"Client infer error: {e}")
raise e

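For reference, a minimal usage sketch of the updated client based on this diff; the endpoint, model name, and prompt below are illustrative. With the default `stream=False`, `generate()` returns the full completion as one string; with `stream=True`, it returns the partial results collected from the Triton stream as a list of chunks.

```python
from fastdeploy_llm.Client import grpcClient

# Illustrative endpoint and model name; adjust to your deployment.
client = grpcClient(base_url="0.0.0.0:8812",
                    model_name="llama-ptuning",
                    timeout=100)

# Non-streaming: the whole completion is returned as a single string.
text = client.generate("Hello, how are you?")
print(text)

# Streaming: each partial result received from the server is appended to a list.
chunks = client.generate("Hello, how are you?", stream=True)
print("".join(chunks))
```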
1 change: 1 addition & 0 deletions llm/fastdeploy_llm/__init__.py
@@ -17,3 +17,4 @@
from .task import Task, BatchTask
from .config import Config
from . import utils
from . import Client
30 changes: 22 additions & 8 deletions llm/client/api_client.py → llm/fastdeploy_llm/api_client.py
@@ -1,3 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from logging.handlers import TimedRotatingFileHandler
@@ -12,11 +26,10 @@
from tornado.concurrent import run_on_executor
from concurrent.futures import ThreadPoolExecutor

from conversation import *
from utils.conversation import *
from Client import *



parse = argparse.ArgumentParser()
parse.add_argument(
'--url', type=str, help='grpc server url')
@@ -25,9 +38,6 @@
parse.add_argument(
'--model', type=str, help='model name', default="model")




def parse_parameters(parameters_config, name, default_value):
if name not in parameters_config:
return default_value
@@ -90,7 +100,7 @@ def post(self):
out_json = create_error_response(4000102,"result is empty")
else:
out_json = {"outputs" : [data],
"status": 0}
"status": 0}
result_str = json.dumps(out_json, ensure_ascii=False)
else:
result_str = json.dumps(err, ensure_ascii=False)
@@ -160,11 +170,14 @@ def run_req(self, body):
temperature = parse_parameters(body, 'temperature', 1.0),
max_dec_len = parse_parameters(body, 'max_tokens', 1024),
frequency_score= parse_parameters(body, 'frequency_penalty', 0.99),
presence_score= parse_parameters(body, 'presence_penalty', 0.0),
stream= parse_parameters(body, 'stream', False)
)
return result




class CompletionApiHandler(web.RequestHandler):
"""
This handler provides OpenAI's Completion API.
@@ -269,7 +282,8 @@ def run_req(self, body):
temperature = parse_parameters(body, 'temperature', 1.0),
max_dec_len = parse_parameters(body, 'max_tokens', 1024),
frequency_score= parse_parameters(body, 'frequency_penalty', 0.99),
presence_score= parse_parameters(body, 'presence_penalty', 0.0),
stream= parse_parameters(body, 'stream', False)
)
return result

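A minimal sketch of how this server gets started through the client, based on the `openai_port` branch added to Client.py in this commit (address, ports, and model name are illustrative). Passing `openai_port` makes `grpcClient` spawn `python3 api_client.py --url <base_url> --port <openai_port> --model <model_name>` as a background process.

```python
from fastdeploy_llm.Client import grpcClient

# Illustrative values; with openai_port set, the constructor launches
#   python3 api_client.py --url 0.0.0.0:8812 --port 8813 --model llama-ptuning
# in the background alongside the gRPC client.
client = grpcClient(base_url="0.0.0.0:8812",
                    model_name="llama-ptuning",
                    timeout=100,
                    openai_port=8813)
```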
103 changes: 103 additions & 0 deletions llm/fastdeploy_llm/utils/conversation.py
@@ -0,0 +1,103 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses


@dataclasses.dataclass
class Conversation:
name: str
system_template: str = "{system_message}"
system_message: str = ""
roles: list[str] = (("USER", "ASSISTANT"),)
messages: list[list[str]] = ()
sep_style: str = ""
sep: str = "\n"
sep2: str = None
stop_token_ids: list[int] = None

def get_prompt(self) -> str:
system_prompt = self.system_template.format(system_message=self.system_message)
if self.name == "llama-ptuning":
seps = [self.sep, self.sep2]
if self.system_message:
ret = system_prompt
else:
ret = "[INST] "
for i, (role, message) in enumerate(self.messages):
if message:
if i == 0:
ret += message + " "
else:
ret += role + " " + message + seps[i % 2]
else:
ret += role
return ret

def set_system_message(self, system_message: str):
self.system_message = system_message

def append_message(self, role: str, message: str):
self.messages.append([role, message])

def copy(self):
return Conversation(
name=self.name,
system_template=self.system_template,
system_message=self.system_message,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
stop_token_ids=self.stop_token_ids,
)


conv_templates: dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
if not override:
assert (
template.name not in conv_templates
), f"{template.name} has been registered."

conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
return conv_templates[name].copy()

register_conv_template(
Conversation(
name="llama-ptuning",
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
sep=" ",
sep2=" </s><s>",
)
)



if __name__ == "__main__":
print("llama-ptuning template:")
conv = get_conv_template("llama-ptuning")
conv.set_system_message("You are a helpful, respectful and honest assistant.")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], "Hi!")
conv.append_message(conv.roles[0], "How are you?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
9 changes: 9 additions & 0 deletions llm/test/README.md
@@ -0,0 +1,9 @@
## Client

- Supports the two OpenAI-style APIs: ChatCompletion and Completion

### Usage
- A demo of the FastDeploy client is provided; see test_client.py in the test directory.
- Example code for the two OpenAI-style endpoints is in test_openai.py in the test directory (a hedged request sketch follows below).


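A rough sketch of calling the OpenAI-style server with Python's requests library. The route path (`/v1/chat/completions`) and the `messages` field are assumptions modeled on OpenAI's API, since test_openai.py is not part of this diff; the parameter names `temperature`, `max_tokens`, `frequency_penalty`, `presence_penalty`, and `stream` match what the handlers' `run_req` parses from the request body.

```python
import requests

# Assumed route and message schema (not shown in this commit); the port is
# illustrative and should match the openai_port passed to grpcClient.
resp = requests.post(
    "http://0.0.0.0:8813/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello, how are you?"}],  # assumed field
        "temperature": 1.0,
        "max_tokens": 1024,
        "frequency_penalty": 0.99,
        "presence_penalty": 0.0,
        "stream": False,
    },
)
print(resp.json())
```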
21 changes: 21 additions & 0 deletions llm/test/test_client.py
@@ -0,0 +1,21 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fastdeploy_llm.Client import grpcClient

client = grpcClient(base_url="0.0.0.0:8812",
model_name="llama-ptuning",
timeout= 100)
result = client.generate("Hello, how are you")
print(result)
