add better request handling to openai-compatible server #1549

Merged 5 commits on Dec 10, 2024
interpreter/interpreter.py (2 changes: 1 addition & 1 deletion)
@@ -215,7 +215,7 @@ def default_system_message(self):
):
system_message = system_message.replace(
"</SYSTEM_CAPABILITY>",
"* For fast web searches (like up-to-date docs) curl https://api.openinterpreter.com/v0/browser/search?query=your+search+query\n</SYSTEM_CAPABILITY>",
"* For any web search requests, curl https://api.openinterpreter.com/v0/browser/search?query=your+search+query\n</SYSTEM_CAPABILITY>",
)

# Update system prompt for Mac OS, if computer tool is enabled
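For context (not part of the diff): the updated capability line tells the model to curl Open Interpreter's hosted search endpoint for web search requests. A rough Python equivalent of that call is sketched below; the example query and the assumption that the endpoint answers a plain GET with a text/JSON body are mine, not the repository's.

```python
# Illustrative only (not part of the diff): a Python equivalent of the curl
# call the new capability line asks the model to run. The example query and
# the assumption that the endpoint answers a plain GET are assumptions.
import requests

resp = requests.get(
    "https://api.openinterpreter.com/v0/browser/search",
    params={"query": "open interpreter documentation"},  # hypothetical query
    timeout=30,
)
resp.raise_for_status()
print(resp.text)
```
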
interpreter/server.py (31 changes: 14 additions & 17 deletions)
@@ -8,6 +8,7 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from asyncio import CancelledError, Task


class ChatCompletionRequest(BaseModel):
@@ -35,14 +36,17 @@ def __init__(self, interpreter):
# Setup routes
self.app.post("/chat/completions")(self.chat_completion)


async def chat_completion(self, request: Request):
"""Main chat completion endpoint"""
body = await request.json()
if self.interpreter.debug:
print("Request body:", body)
try:
req = ChatCompletionRequest(**body)
except Exception as e:
print("Validation error:", str(e)) # Debug print
print("Request body:", body) # Print the request body
print("Validation error:", str(e))
print("Request body:", body)
raise

# Filter out system message
@@ -75,18 +79,6 @@ async def _stream_response(self):
delta["function_call"] = choice.delta.function_call
if choice.delta.tool_calls is not None:
pass
# Convert tool_calls to dict representation
# delta["tool_calls"] = [
# {
# "index": tool_call.index,
# "id": tool_call.id,
# "type": tool_call.type,
# "function": {
# "name": tool_call.function.name,
# "arguments": tool_call.function.arguments
# }
# } for tool_call in choice.delta.tool_calls
# ]

choices.append(
{
@@ -108,11 +100,16 @@ async def _stream_response(self):
data["system_fingerprint"] = chunk.system_fingerprint

yield f"data: {json.dumps(data)}\n\n"
except asyncio.CancelledError:
# Set stop flag when stream is cancelled
self.interpreter._stop_flag = True

except CancelledError:
# Handle cancellation gracefully
print("Request cancelled - cleaning up...")

raise
except Exception as e:
print(f"Error in stream: {str(e)}")
finally:
# Always send DONE message and cleanup
yield "data: [DONE]\n\n"

def run(self):
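The server.py hunks converge on one pattern: catch CancelledError separately from other errors, re-raise it so the framework can finish tearing the request down, and emit the SSE terminator from a finally block so clients always see the end of the stream. Below is a minimal, self-contained sketch of that pattern outside the Open Interpreter codebase; fake_chunks, stream_response, and the driver in main are illustrative names, and only the try/except/finally shape mirrors the diff.

```python
# Standalone sketch of the cancellation-aware streaming pattern shown in
# interpreter/server.py. Everything here (fake_chunks, main) is illustrative;
# only the try/except/finally structure mirrors the diff.
import asyncio
import json
from asyncio import CancelledError


async def fake_chunks():
    # Stand-in for the interpreter's streaming completion chunks.
    for text in ["Hello", ", ", "world"]:
        await asyncio.sleep(0.1)
        yield {"choices": [{"delta": {"content": text}}]}


async def stream_response():
    try:
        async for chunk in fake_chunks():
            yield f"data: {json.dumps(chunk)}\n\n"
    except CancelledError:
        # Client disconnected or the request was cancelled: log, then re-raise
        # so the surrounding framework can clean up the connection.
        print("Request cancelled - cleaning up...")
        raise
    except Exception as e:
        print(f"Error in stream: {str(e)}")
    finally:
        # Always terminate the SSE stream so well-behaved clients stop reading.
        yield "data: [DONE]\n\n"


async def main():
    # Drive the generator the way a StreamingResponse would consume it.
    async for event in stream_response():
        print(event, end="")


if __name__ == "__main__":
    asyncio.run(main())
```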