Commit 228a19e: add preprocessing

tybalex committed Jun 10, 2024
1 parent cb77ad8

Showing 3 changed files with 14 additions and 1 deletion.
requirements-common.txt (1 addition, 0 deletions)

@@ -20,3 +20,4 @@ lm-format-enforcer == 0.10.1
 outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
+llama_tools >= 0.1.22
vllm/entrypoints/openai/protocol.py (2 additions, 1 deletion)

@@ -129,7 +129,8 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
-    messages: List[ChatCompletionMessageParam]
+    # messages: List[ChatCompletionMessageParam]  # TODO: figure out why
+    messages: List[dict]
     model: str
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
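The TODO suggests the typed message union was rejecting or reshaping the preprocessed messages. A minimal sketch of the difference (hypothetical, not from the commit; SystemMessage, TypedRequest, and LooseRequest are illustrative stand-ins for the real models) shows why List[dict] acts as a permissive escape hatch:

```python
# Hypothetical sketch: why loosening the field type lets arbitrary
# message dicts through Pydantic validation untouched.
from typing import List, Literal
from pydantic import BaseModel, ValidationError

class SystemMessage(BaseModel):  # stand-in for one union member
    role: Literal["system"]
    content: str

class TypedRequest(BaseModel):
    messages: List[SystemMessage]  # strict: unknown roles/shapes fail

class LooseRequest(BaseModel):
    messages: List[dict]  # permissive: any dict passes through as-is

tool_msg = {"role": "tool", "content": "42", "tool_call_id": "abc"}

try:
    TypedRequest(messages=[tool_msg])
except ValidationError:
    print("typed model rejects the tool-role message")

print(LooseRequest(messages=[tool_msg]).messages[0])  # dict preserved
```

Loosening to List[dict] sidesteps validation entirely, which keeps whatever shape preprocess_input expects or produces intact, at the cost of schema checking.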
vllm/entrypoints/openai/serving_chat.py (11 additions, 0 deletions)

@@ -32,6 +32,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid
+from llama_tools import preprocess_input, postprocess_output

 logger = init_logger(__name__)

@@ -210,11 +211,18 @@ async def create_chat_completion(
         conversation: List[ConversationMessage] = []
         image_futures: List[Awaitable[ImagePixelData]] = []

+        raw_msgs = request.messages
+        if request.tools:
+            print("==================tools====================")
+            raw_msgs = preprocess_input(msgs=raw_msgs, tools=request.tools)
+
         for msg in request.messages:
             chat_parsed_result = self._parse_chat_message_content(msg)

             conversation.extend(chat_parsed_result.messages)
             image_futures.extend(chat_parsed_result.image_futures)

+        conversation = raw_msgs
+
         prompt = self.tokenizer.apply_chat_template(
             conversation=conversation,
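In effect, when tools are present the request messages are rewritten by llama_tools before templating, and the parsed conversation is then replaced wholesale by the (possibly preprocessed) raw messages. A standalone sketch of that control flow (preprocess_input below is a stub standing in for llama_tools, whose actual logic is not shown in the commit):

```python
# Control-flow sketch of the new preprocessing path. The stub only
# stands in for llama_tools.preprocess_input; its real behavior is assumed.
from typing import List

def preprocess_input(msgs: List[dict], tools: List[dict]) -> List[dict]:
    # Assumption: tool schemas get folded into the conversation,
    # e.g. as a leading system message the chat template can render.
    names = ", ".join(t["function"]["name"] for t in tools)
    return [{"role": "system", "content": f"Available tools: {names}"}] + msgs

messages = [{"role": "user", "content": "What's the weather in Berlin?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]

raw_msgs = messages
if tools:  # mirrors `if request.tools:` in the diff
    raw_msgs = preprocess_input(msgs=raw_msgs, tools=tools)

conversation = raw_msgs  # parsed ConversationMessage list is discarded
print(conversation)
```

Note that `conversation = raw_msgs` runs unconditionally, so even requests without tools bypass the output of `_parse_chat_message_content`; only the image futures collected in the loop are still used.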
@@ -488,6 +496,7 @@ async def chat_completion_full_generator(
         choices = []

         role = self.get_chat_request_role(request)
+        print("========================output========================")
         for output in final_res.outputs:
             token_ids = output.token_ids
             top_logprobs = output.logprobs

@@ -501,6 +510,8 @@ async def chat_completion_full_generator(
             else:
                 logprobs = None

+            # TODO: use llama_tools to parse the output.text
+            print(output)
             if request.tool_choice and type(
                     request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                 message = ChatMessage(
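The import at the top of the file already brings in postprocess_output, which this TODO presumably targets. A hypothetical sketch of the missing parsing step follows; the signature and return shape of postprocess_output are assumptions, not something the commit confirms:

```python
# Hypothetical completion of the TODO. Assumed API: postprocess_output
# takes the raw generated text and returns parsed tool calls (or a falsy
# value when the model produced plain text). Plain dicts are used here
# instead of vLLM's ChatMessage to avoid guessing at its fields.
from llama_tools import postprocess_output

def build_message(role: str, text: str) -> dict:
    tool_calls = postprocess_output(text)  # assumed signature
    if tool_calls:
        return {"role": role, "content": None, "tool_calls": tool_calls}
    return {"role": role, "content": text}
```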
