async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
    """
    Measure the performance of a chat-completions endpoint for one prompt.

    Sends the prompt twice: once to the local tokenizer endpoint to record
    the prompt length, then to *api_endpoint* with ``stream=True`` to time
    the generated tokens.

    Args:
        api_endpoint (str): The chat-completions URL to benchmark.
        prompt (str): The user prompt to send.

    Returns:
        Dict[str, Any]: Run metadata (model, run_id, configuration,
        prompt_len) plus timing metrics: ``ttft``, ``prompt_tps``,
        ``generation_tps``, ``response_len`` and ``total_time``.

    Raises:
        RuntimeError: If the token-count request or the streamed
        completion request fails.
    """
    model = os.environ.get('model', 'llama-3.2-1b')

    results: Dict[str, Any] = {
        'model': model,
        'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
        'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
    }

    # 'async with' guarantees the session is closed on every exit path,
    # replacing the manual session.close() calls in each error handler.
    async with aiohttp.ClientSession() as session:
        # --- Get prompt length in tokens ---
        try:
            response = await session.post(
                "http://localhost:52415/v1/chat/token/encode",
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": prompt}]
                }
            )
            response.raise_for_status()
            token_data = await response.json()
            results['prompt_len'] = token_data['num_tokens']
        except Exception as e:
            raise RuntimeError(f"Failed to get token count: {str(e)}") from e

        # --- Measure completion performance ---
        try:
            start_time = time.time()
            response = await session.post(
                api_endpoint,
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0,
                    "stream": True
                }
            )
            response.raise_for_status()

            first_token_time = None
            total_tokens = 0

            # Iterate line-by-line: iter_chunks() yields arbitrary byte
            # chunks that can split one SSE event across chunks or pack
            # several events into one, whereas iterating the StreamReader
            # itself yields complete newline-terminated lines.
            async for raw_line in response.content:
                line = raw_line.decode('utf-8').strip()
                if not line.startswith('data: '):
                    continue

                line_content = line[6:]  # Skip the 'data: ' prefix
                if line_content == '[DONE]':
                    # End-of-stream sentinel is not JSON; feeding it to
                    # json.loads would raise and abort the measurement.
                    break

                try:
                    data = json.loads(line_content)
                except json.JSONDecodeError:
                    # Tolerate keep-alive / malformed lines instead of
                    # failing the whole benchmark run.
                    continue

                if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
                    print(f"Received content: {content}", flush=True)
                    if first_token_time is None:
                        first_token_time = time.time()
                        ttft = first_token_time - start_time
                        results.update({
                            'ttft': ttft,
                            'prompt_tps': results['prompt_len'] / ttft
                        })
                    total_tokens += 1

            total_time = time.time() - start_time
            results.update({
                'generation_tps': total_tokens / total_time,
                'response_len': total_tokens,
                'total_time': total_time
            })

        except Exception as e:
            raise RuntimeError(f"Performance measurement failed: {str(e)}") from e

    return results