Skip to content

Commit

Permalink
simplify bench
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Dec 8, 2024
1 parent fe80749 commit fb44eb0
Showing 1 changed file with 69 additions and 81 deletions.
150 changes: 69 additions & 81 deletions .github/bench.py
Original file line number Diff line number Diff line change
async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
    """Measure streaming-completion performance of a local inference endpoint.

    Posts *prompt* to the token-encode endpoint to record prompt length, then
    streams a chat completion from *api_endpoint* and records latency metrics.

    Args:
        api_endpoint: URL of the chat-completions endpoint (SSE streaming).
        prompt: User prompt to send.

    Returns:
        Dict[str, Any]: Metrics: model, run_id, configuration, prompt_len,
        ttft (time to first token), prompt_tps, generation_tps,
        response_len (generated token/chunk count), total_time.

    Raises:
        RuntimeError: If token counting or the completion request fails.
    """
    model = os.environ.get('model', 'llama-3.2-1b')

    results: Dict[str, Any] = {
        'model': model,
        'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
        'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
    }

    # One session for both requests; `async with` guarantees the connector and
    # any open response bodies are released on every exit path.
    async with aiohttp.ClientSession() as session:
        # --- Get prompt token count ---
        try:
            async with session.post(
                "http://localhost:52415/v1/chat/token/encode",
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": prompt}]
                }
            ) as response:
                response.raise_for_status()
                token_data = await response.json()
                results['prompt_len'] = token_data['num_tokens']
        except Exception as e:
            raise RuntimeError(f"Failed to get token count: {str(e)}")

        # --- Measure completion performance ---
        try:
            start_time = time.time()
            async with session.post(
                api_endpoint,
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0,
                    "stream": True
                }
            ) as response:
                response.raise_for_status()

                first_token_time = None
                total_tokens = 0

                # Iterate by line, not raw chunk: a network chunk can split an
                # SSE event (and its JSON payload) in half.
                async for raw_line in response.content:
                    line = raw_line.decode('utf-8').strip()
                    if not line.startswith('data: '):
                        continue

                    payload = line[6:]  # Skip 'data: ' prefix
                    if payload == '[DONE]':
                        # End-of-stream sentinel is not JSON; without this
                        # check json.loads raises and the run aborts.
                        break
                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        # Skip malformed/partial events rather than failing.
                        continue

                    if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
                        print(f"Received content: {content}", flush=True)
                        if first_token_time is None:
                            first_token_time = time.time()
                            ttft = first_token_time - start_time
                            results.update({
                                'ttft': ttft,
                                'prompt_tps': results['prompt_len'] / ttft
                            })
                        total_tokens += 1

            total_time = time.time() - start_time
            results.update({
                'generation_tps': total_tokens / total_time,
                'response_len': total_tokens,
                'total_time': total_time
            })
        except Exception as e:
            raise RuntimeError(f"Performance measurement failed: {str(e)}")

    return results

Expand All @@ -122,13 +110,13 @@ async def main() -> None:
aws_secret_access_key=os.environ.get('aws_secret_key')
)
job_name = os.environ.get('GITHUB_JOB')

# Create S3 key with timestamp and commit info
now = datetime.utcnow()
timestamp = now.strftime('%H-%M-%S')
commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
s3_key = f"{job_name}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"

# Upload to S3
s3_client.put_object(
Bucket='exo-benchmarks',
Expand All @@ -146,4 +134,4 @@ async def main() -> None:


if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

0 comments on commit fb44eb0

Please sign in to comment.