Support batching #5

Merged · 2 commits · Sep 4, 2024
61 changes: 43 additions & 18 deletions app/services/llama_client.py
@@ -1,5 +1,8 @@
import asyncio
import aiohttp
import requests
import json
from typing import List

from ..config.settings import settings, logger

@@ -21,30 +24,37 @@ class LlamaCppClient:
"mirostat": 0,
"mirostat_tau": 5.0,
"mirostat_eta": 0.1,
"stream": False
"stream": False,
}

def __init__(self):
self.base_url = settings.host
self.system_prompt = settings.system_prompt
self.prompt = settings.model_prompt

def chat(self, context, message, stream, params=None):
params = params or self.DEFAULT_PARAMS.copy()
params["stream"] = stream
async def _call_model_parallel(self, url, payloads) -> List[dict]:
"""This function calls the model at url in parallel with multiple payloads"""
async with aiohttp.ClientSession() as session:

async def fetch(url, data):
async with session.post(url, json=data) as response:
return await response.json()

return await asyncio.gather(*[fetch(url, data) for data in payloads])

prompt = self.prompt.format(
system_prompt=self.system_prompt,
context=context,
message=message
)
def _build_payload(self, context, message, params):
prompt = self.prompt.format(system_prompt=self.system_prompt, context=context, message=message)

data = {
"prompt": prompt,
**params
}
data = {"prompt": prompt, **params}

return data

def chat(self, context, message, stream, params=None) -> dict:
params = params or self.DEFAULT_PARAMS.copy()
params["stream"] = stream

logger.info(f"Sending request to llama.cpp server with prompt: {prompt}")
data = self._build_payload(context, message, params)
# logger.info(f"Sending request to llama.cpp server with prompt: {prompt}")

try:
if params["stream"]:
@@ -55,18 +65,33 @@ def chat(self, context, message, stream, params=None):
logger.error(f"Request failed: {e}")
raise

def chat_parallel(self, contexts, messages, params=None) -> List[dict]:
params = params or self.DEFAULT_PARAMS.copy()

payloads = []
for context, message in zip(contexts, messages):
payloads.append(self._build_payload(context, message, params))

logger.info(f"Sending parallel requests to llama.cpp server with {len(payloads)} payloads")
url = f"{self.base_url}/completion"
try:
return asyncio.run(self._call_model_parallel(url, payloads))
except requests.RequestException as e:
logger.error(f"Parallel requests failed: {e}")
raise

def _stream_response(self, data):
response = requests.post(f"{self.base_url}/completion", json=data, stream=True)
response.raise_for_status()

for line in response.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data: '):
decoded_line = line.decode("utf-8")
if decoded_line.startswith("data: "):
content = json.loads(decoded_line[6:])
if content.get('stop'):
if content.get("stop"):
break
chunk = content.get('content', '')
chunk = content.get("content", "")
yield chunk

def _get_response(self, data):
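
For reference, here is a minimal usage sketch of the new chat_parallel path. The contexts and messages are hypothetical placeholders, and the client is assumed to be configured through settings as in this file.

client = LlamaCppClient()

contexts = ["First retrieved document ...", "Second retrieved document ..."]
messages = ["Summarize the first document.", "Summarize the second document."]

# chat_parallel builds one /completion payload per (context, message) pair
# and posts them concurrently via asyncio.gather.
responses = client.chat_parallel(contexts, messages)
for response in responses:
    print(response.get("content", ""))

Since _call_model_parallel issues the requests with aiohttp, connection failures surface as aiohttp.ClientError rather than the requests.RequestException caught in chat_parallel.
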
@@ -2,7 +2,7 @@
from tqdm import tqdm


def generate_responses(params, contexts, questions, llm_client):
def generate_responses(params, contexts, questions, llm_client, batch_size=4):
output_file_name = params["output_file"]

with open(f"./data/{output_file_name}", "w", newline="") as output_file:
@@ -30,31 +30,36 @@ def generate_responses(params, contexts, questions, llm_client):
dict_writer = csv.DictWriter(output_file, fieldnames=fieldnames)
dict_writer.writeheader()

for context_size, context in tqdm(contexts.items()):
question = questions[context_size]
response = llm_client.chat(user_prompt=context, question=question)
full_response = response.json()
context_keys = list(contexts.keys())
for i in tqdm(range(0, len(contexts), batch_size)):
batch_questions = [questions[context_size] for context_size in context_keys[i : i + batch_size]]
context_texts = [contexts[context_size] for context_size in context_keys[i : i + batch_size]]
responses = llm_client.chat(user_prompts=context_texts, questions=batch_questions)
for j, full_response in enumerate(responses):
context_size = context_keys[i + j]
context = context_texts[j]
question = batch_questions[j]

result = {
"model": full_response["model"],
"context_size": context_size,
"total_cores": params["total_cores"],
"prompt": context,
"question": question,
"response": full_response["content"],
"temperature": params["temperature"],
"n_predict": params["tokens_to_predict"],
"tokens_predicted": full_response["tokens_predicted"],
"tokens_evaluated": full_response["tokens_evaluated"],
"prompt_n": full_response["timings"]["prompt_n"],
"prompt_ms": full_response["timings"]["prompt_ms"],
"prompt_per_token_ms": full_response["timings"]["prompt_per_token_ms"],
"prompt_per_second": full_response["timings"]["prompt_per_second"],
"predicted_n": full_response["timings"]["predicted_n"],
"predicted_ms": full_response["timings"]["predicted_ms"],
"predicted_per_token_ms": full_response["timings"]["predicted_per_token_ms"],
"predicted_per_second": full_response["timings"]["predicted_per_second"],
}
result = {
"model": full_response["model"],
"context_size": context_size,
"total_cores": params["total_cores"],
"prompt": context,
"question": question,
"response": full_response["content"],
"temperature": params["temperature"],
"n_predict": params["tokens_to_predict"],
"tokens_predicted": full_response["tokens_predicted"],
"tokens_evaluated": full_response["tokens_evaluated"],
"prompt_n": full_response["timings"]["prompt_n"],
"prompt_ms": full_response["timings"]["prompt_ms"],
"prompt_per_token_ms": full_response["timings"]["prompt_per_token_ms"],
"prompt_per_second": full_response["timings"]["prompt_per_second"],
"predicted_n": full_response["timings"]["predicted_n"],
"predicted_ms": full_response["timings"]["predicted_ms"],
"predicted_per_token_ms": full_response["timings"]["predicted_per_token_ms"],
"predicted_per_second": full_response["timings"]["predicted_per_second"],
}

dict_writer.writerow(result)
output_file.flush()
dict_writer.writerow(result)
output_file.flush()
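
A standalone sketch of the batching pattern introduced above, using hypothetical placeholder data; the real loop passes the two lists to llm_client.chat and matches each response back to context_keys[i + j] by position.

contexts = {1024: "context for 1k tokens", 2048: "context for 2k tokens", 4096: "context for 4k tokens"}
questions = {size: f"question for {size}" for size in contexts}
batch_size = 4

context_keys = list(contexts.keys())
for i in range(0, len(contexts), batch_size):
    batch_keys = context_keys[i : i + batch_size]
    batch_questions = [questions[size] for size in batch_keys]
    context_texts = [contexts[size] for size in batch_keys]
    # In the benchmark this is where llm_client.chat(user_prompts=context_texts,
    # questions=batch_questions) is called; each returned dict lines up with batch_keys[j].
    print(batch_keys, len(batch_questions), len(context_texts))
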
@@ -1,4 +1,6 @@
import requests
from typing import List
import aiohttp
import asyncio


class LlamaCppClient:
@@ -12,18 +14,32 @@ def __init__(self, settings: dict):
self.stop = settings.get("stop")
self.stream = settings.get("stream")

def chat(self, user_prompt: str, question: str):
async def _call_model(self, url, pl):
async with aiohttp.ClientSession() as session:

async def fetch(url, data, i):
async with session.post(url, json=data) as response:
j = await response.json()
return j

return await asyncio.gather(*[fetch(url, data, i) for i, data in enumerate(pl)])

def chat(self, user_prompts: List[str], questions: List[str]):
params = {
"n_predict": self.n_predict,
"temperature": self.temperature,
"stop": self.stop,
"stream": self.stream,
"cache_prompt": False,
}
pl = []
for user_prompt, question in zip(user_prompts, questions):
model_prompt = self.model_prompt.format(system_prompt=self.system_prompt, user_prompt=user_prompt, question=question)

model_prompt = self.model_prompt.format(system_prompt=self.system_prompt, user_prompt=user_prompt, question=question)
data = {"prompt": model_prompt, **params}
pl.append(data)

data = {"prompt": model_prompt, **params}
url = f"{self.base_url}/completion"
responses = asyncio.run(self._call_model(url, pl))

response = requests.post(f"{self.base_url}/completion", json=data)
return response
return responses
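
The core of the change is the aiohttp fan-out inside _call_model. Below is a self-contained sketch of that pattern; the server URL and payloads are assumptions, not part of the project.

import asyncio
import aiohttp

async def post_all(url, payloads):
    # One shared session; every payload is POSTed concurrently and
    # asyncio.gather returns the parsed JSON bodies in input order.
    async with aiohttp.ClientSession() as session:
        async def fetch(data):
            async with session.post(url, json=data) as response:
                return await response.json()

        return await asyncio.gather(*[fetch(data) for data in payloads])

payloads = [{"prompt": "Hello", "n_predict": 16}, {"prompt": "Tell me a joke", "n_predict": 32}]
# Requires a running llama.cpp server; the URL is an assumption.
# results = asyncio.run(post_all("http://localhost:8080/completion", payloads))

Note that the requests are only issued concurrently on the client side; whether the server evaluates them in parallel depends on how llama.cpp itself is launched.
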
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,5 +1,7 @@
elasticsearch==8.15.0
fastapi==0.112.2
sentence-transformers==3.0.1
python-multipart
requests==2.32.3
ruff==0.6.3
uvicorn