
Merge pull request #5 from fastenhealth/speed_batching
Support batching
dgbaenar authored Sep 4, 2024
2 parents 6ed90b3 + 2c47611 commit 3ddf7e4
Showing 4 changed files with 99 additions and 51 deletions.
61 changes: 43 additions & 18 deletions app/services/llama_client.py
@@ -1,5 +1,8 @@
+import asyncio
+import aiohttp
 import requests
 import json
+from typing import List

 from ..config.settings import settings, logger

@@ -21,30 +24,37 @@ class LlamaCppClient:
"mirostat": 0,
"mirostat_tau": 5.0,
"mirostat_eta": 0.1,
"stream": False
"stream": False,
}

def __init__(self):
self.base_url = settings.host
self.system_prompt = settings.system_prompt
self.prompt = settings.model_prompt

def chat(self, context, message, stream, params=None):
params = params or self.DEFAULT_PARAMS.copy()
params["stream"] = stream
async def _call_model_parallel(self, url, payloads) -> List[dict]:
"""This function calls the model at url in parallel with multiple payloads"""
async with aiohttp.ClientSession() as session:

async def fetch(url, data):
async with session.post(url, json=data) as response:
return await response.json()

return await asyncio.gather(*[fetch(url, data) for data in payloads])

prompt = self.prompt.format(
system_prompt=self.system_prompt,
context=context,
message=message
)
def _build_payload(self, context, message, params):
prompt = self.prompt.format(system_prompt=self.system_prompt, context=context, message=message)

data = {
"prompt": prompt,
**params
}
data = {"prompt": prompt, **params}

return data

def chat(self, context, message, stream, params=None) -> dict:
params = params or self.DEFAULT_PARAMS.copy()
params["stream"] = stream

logger.info(f"Sending request to llama.cpp server with prompt: {prompt}")
data = self._build_payload(context, message, params)
# logger.info(f"Sending request to llama.cpp server with prompt: {prompt}")

try:
if params["stream"]:
@@ -55,18 +65,33 @@ def chat(self, context, message, stream, params=None):
logger.error(f"Request failed: {e}")
raise

def chat_parallel(self, contexts, messages, params=None) -> List[dict]:
params = params or self.DEFAULT_PARAMS.copy()

payloads = []
for context, message in zip(contexts, messages):
payloads.append(self._build_payload(context, message, params))

logger.info(f"Sending parallel requests to llama.cpp server with {len(payloads)} payloads")
url = f"{self.base_url}/completion"
try:
return asyncio.run(self._call_model_parallel(url, payloads))
except requests.RequestException as e:
logger.error(f"Parallel requests failed: {e}")
raise

def _stream_response(self, data):
response = requests.post(f"{self.base_url}/completion", json=data, stream=True)
response.raise_for_status()

for line in response.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data: '):
decoded_line = line.decode("utf-8")
if decoded_line.startswith("data: "):
content = json.loads(decoded_line[6:])
if content.get('stop'):
if content.get("stop"):
break
chunk = content.get('content', '')
chunk = content.get("content", "")
yield chunk

def _get_response(self, data):
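For orientation, a minimal usage sketch of the new batched client API added above. The import path, the example inputs, and reading the "content" field of the llama.cpp completion response are assumptions based on this diff and the repository layout, not part of the commit itself:

# Hedged usage sketch (not part of this commit): exercises chat_parallel()
# from app/services/llama_client.py; import path assumed from the file layout.
from app.services.llama_client import LlamaCppClient

client = LlamaCppClient()  # reads host and prompt templates from settings

contexts = ["context for request A", "context for request B"]
messages = ["question for request A", "question for request B"]

# One payload is built per (context, message) pair, and all payloads are POSTed
# to the llama.cpp /completion endpoint concurrently via asyncio + aiohttp.
responses = client.chat_parallel(contexts, messages)

for response in responses:
    print(response.get("content"))

Since chat_parallel drives the requests with asyncio.run, it is meant to be called from synchronous code; calling it from inside an already-running event loop would raise a RuntimeError.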
@@ -2,7 +2,7 @@
 from tqdm import tqdm


-def generate_responses(params, contexts, questions, llm_client):
+def generate_responses(params, contexts, questions, llm_client, batch_size=4):
     output_file_name = params["output_file"]

     with open(f"./data/{output_file_name}", "w", newline="") as output_file:
@@ -30,31 +30,36 @@ def generate_responses(params, contexts, questions, llm_client):
         dict_writer = csv.DictWriter(output_file, fieldnames=fieldnames)
         dict_writer.writeheader()

-        for context_size, context in tqdm(contexts.items()):
-            question = questions[context_size]
-            response = llm_client.chat(user_prompt=context, question=question)
-            full_response = response.json()
+        context_keys = list(contexts.keys())
+        for i in tqdm(range(0, len(contexts), batch_size)):
+            batch_questions = [questions[context_size] for context_size in context_keys[i : i + batch_size]]
+            context_texts = [contexts[context_size] for context_size in context_keys[i : i + batch_size]]
+            responses = llm_client.chat(user_prompts=context_texts, questions=batch_questions)
+            for j, full_response in enumerate(responses):
+                context_size = context_keys[i + j]
+                context = context_texts[j]
+                question = batch_questions[j]

-            result = {
-                "model": full_response["model"],
-                "context_size": context_size,
-                "total_cores": params["total_cores"],
-                "prompt": context,
-                "question": question,
-                "response": full_response["content"],
-                "temperature": params["temperature"],
-                "n_predict": params["tokens_to_predict"],
-                "tokens_predicted": full_response["tokens_predicted"],
-                "tokens_evaluated": full_response["tokens_evaluated"],
-                "prompt_n": full_response["timings"]["prompt_n"],
-                "prompt_ms": full_response["timings"]["prompt_ms"],
-                "prompt_per_token_ms": full_response["timings"]["prompt_per_token_ms"],
-                "prompt_per_second": full_response["timings"]["prompt_per_second"],
-                "predicted_n": full_response["timings"]["predicted_n"],
-                "predicted_ms": full_response["timings"]["predicted_ms"],
-                "predicted_per_token_ms": full_response["timings"]["predicted_per_token_ms"],
-                "predicted_per_second": full_response["timings"]["predicted_per_second"],
-            }
+                result = {
+                    "model": full_response["model"],
+                    "context_size": context_size,
+                    "total_cores": params["total_cores"],
+                    "prompt": context,
+                    "question": question,
+                    "response": full_response["content"],
+                    "temperature": params["temperature"],
+                    "n_predict": params["tokens_to_predict"],
+                    "tokens_predicted": full_response["tokens_predicted"],
+                    "tokens_evaluated": full_response["tokens_evaluated"],
+                    "prompt_n": full_response["timings"]["prompt_n"],
+                    "prompt_ms": full_response["timings"]["prompt_ms"],
+                    "prompt_per_token_ms": full_response["timings"]["prompt_per_token_ms"],
+                    "prompt_per_second": full_response["timings"]["prompt_per_second"],
+                    "predicted_n": full_response["timings"]["predicted_n"],
+                    "predicted_ms": full_response["timings"]["predicted_ms"],
+                    "predicted_per_token_ms": full_response["timings"]["predicted_per_token_ms"],
+                    "predicted_per_second": full_response["timings"]["predicted_per_second"],
+                }

-            dict_writer.writerow(result)
-            output_file.flush()
+                dict_writer.writerow(result)
+                output_file.flush()
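As a toy illustration (not part of the commit), the index bookkeeping in the new loop works because each batch is a contiguous slice of context_keys, so enumerate(responses) maps the j-th response back to context_keys[i + j]:

# Illustrative only: shows how the batch slicing and index arithmetic line up.
context_keys = [128, 256, 512, 1024, 2048]
batch_size = 4

for i in range(0, len(context_keys), batch_size):
    batch = context_keys[i : i + batch_size]
    for j, key in enumerate(batch):
        assert key == context_keys[i + j]
    print(i, batch)
# 0 [128, 256, 512, 1024]
# 4 [2048]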
@@ -1,4 +1,6 @@
-import requests
+from typing import List
+import aiohttp
+import asyncio


 class LlamaCppClient:
@@ -12,18 +14,32 @@ def __init__(self, settings: dict):
         self.stop = settings.get("stop")
         self.stream = settings.get("stream")

-    def chat(self, user_prompt: str, question: str):
+    async def _call_model(self, url, pl):
+        async with aiohttp.ClientSession() as session:
+
+            async def fetch(url, data, i):
+                async with session.post(url, json=data) as response:
+                    j = await response.json()
+                    return j
+
+            return await asyncio.gather(*[fetch(url, data, i) for i, data in enumerate(pl)])
+
+    def chat(self, user_prompts: List[str], questions: List[str]):
         params = {
             "n_predict": self.n_predict,
             "temperature": self.temperature,
             "stop": self.stop,
             "stream": self.stream,
             "cache_prompt": False,
         }
+        pl = []
+        for user_prompt, question in zip(user_prompts, questions):
+            model_prompt = self.model_prompt.format(system_prompt=self.system_prompt, user_prompt=user_prompt, question=question)

-        model_prompt = self.model_prompt.format(system_prompt=self.system_prompt, user_prompt=user_prompt, question=question)
+            data = {"prompt": model_prompt, **params}
+            pl.append(data)

-        data = {"prompt": model_prompt, **params}
+        url = f"{self.base_url}/completion"
+        responses = asyncio.run(self._call_model(url, pl))

-        response = requests.post(f"{self.base_url}/completion", json=data)
-        return response
+        return responses
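One property the batched chat above relies on, worth making explicit: asyncio.gather returns results in the order its awaitables are passed, even when they complete out of order, so responses stays aligned with user_prompts and questions. A small standalone sketch (illustrative, not from the commit):

import asyncio
import random

async def fake_request(i):
    # Simulate variable per-request latency.
    await asyncio.sleep(random.random() / 100)
    return i

async def main():
    # gather preserves argument order regardless of completion order.
    results = await asyncio.gather(*[fake_request(i) for i in range(5)])
    assert results == [0, 1, 2, 3, 4]
    print(results)

asyncio.run(main())

Note also that the concurrent requests only overlap on the server side if the llama.cpp server is configured to process multiple requests in parallel; with a single processing slot they are effectively queued.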
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,5 +1,7 @@
elasticsearch==8.15.0
fastapi==0.112.2
sentence-transformers==3.0.1
python-multipart
requests==2.32.3
ruff==0.6.3
uvicorn
