From 16e501d39382404dddbd563f946207dc32a92e98 Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Fri, 20 Dec 2024 10:10:04 +0000
Subject: [PATCH] fix: Improve Gemini client error handling and add tests (#530)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add better error messages for API key configuration
- Add comprehensive test coverage
- Update google-generativeai version requirement
- Add proper logging for debugging

Fixes #530

Co-Authored-By: Erkin Alp Güney
---
 requirements.txt            |  4 +-
 src/llm/gemini_client.py    | 70 ++++++++++++++++++++++-----------
 tests/test_gemini_client.py | 78 +++++++++++++++++++++++++++++++++++++
 3 files changed, 128 insertions(+), 24 deletions(-)
 create mode 100644 tests/test_gemini_client.py

diff --git a/requirements.txt b/requirements.txt
index 91666960..311e6d57 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,8 +14,8 @@ pytest-playwright
 tiktoken
 ollama
 openai
-anthropic
-google-generativeai
+anthropic>=0.8.0
+google-generativeai>=0.3.0
 sqlmodel
 keybert
 GitPython
diff --git a/src/llm/gemini_client.py b/src/llm/gemini_client.py
index 0d566673..c148c3fd 100644
--- a/src/llm/gemini_client.py
+++ b/src/llm/gemini_client.py
@@ -2,32 +2,58 @@
 from google.generativeai.types import HarmCategory, HarmBlockThreshold
 
 from src.config import Config
+from src.logger import Logger
+
+logger = Logger()
+config = Config()
 
 class Gemini:
     def __init__(self):
-        config = Config()
         api_key = config.get_gemini_api_key()
-        genai.configure(api_key=api_key)
+        if not api_key:
+            error_msg = ("Gemini API key not found in configuration. "
+                         "Please add your Gemini API key to config.toml under [API_KEYS] "
+                         "section as GEMINI = 'your-api-key'")
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        try:
+            genai.configure(api_key=api_key)
+            logger.info("Successfully initialized Gemini client")
+        except Exception as e:
+            error_msg = f"Failed to configure Gemini client: {str(e)}"
+            logger.error(error_msg)
+            raise ValueError(error_msg) from e
 
     def inference(self, model_id: str, prompt: str) -> str:
-        config = genai.GenerationConfig(temperature=0)
-        model = genai.GenerativeModel(model_id, generation_config=config)
-        # Set safety settings for the request
-        safety_settings = {
-            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
-            # You can adjust other categories as needed
-        }
-        response = model.generate_content(prompt, safety_settings=safety_settings)
         try:
-            # Check if the response contains text
-            return response.text
-        except ValueError:
-            # If the response doesn't contain text, check if the prompt was blocked
-            print("Prompt feedback:", response.prompt_feedback)
-            # Also check the finish reason to see if the response was blocked
-            print("Finish reason:", response.candidates[0].finish_reason)
-            # If the finish reason was SAFETY, the safety ratings have more details
-            print("Safety ratings:", response.candidates[0].safety_ratings)
-            # Handle the error or return an appropriate message
-            return "Error: Unable to generate content Gemini API"
+            logger.info(f"Initializing Gemini model: {model_id}")
+            generation_config = genai.GenerationConfig(temperature=0)
+            model = genai.GenerativeModel(model_id, generation_config=generation_config)
+
+            safety_settings = {
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+            }
+
logger.info("Generating response with Gemini") + response = model.generate_content(prompt, safety_settings=safety_settings) + + try: + if response.text: + logger.info("Successfully generated response") + return response.text + else: + error_msg = f"Empty response from Gemini model {model_id}" + logger.error(error_msg) + raise ValueError(error_msg) + except ValueError: + logger.error("Failed to get response text") + logger.error(f"Prompt feedback: {response.prompt_feedback}") + logger.error(f"Finish reason: {response.candidates[0].finish_reason}") + logger.error(f"Safety ratings: {response.candidates[0].safety_ratings}") + return "Error: Unable to generate content with Gemini API" + + except Exception as e: + error_msg = f"Error during Gemini inference: {str(e)}" + logger.error(error_msg) + raise ValueError(error_msg) diff --git a/tests/test_gemini_client.py b/tests/test_gemini_client.py new file mode 100644 index 00000000..d4da231e --- /dev/null +++ b/tests/test_gemini_client.py @@ -0,0 +1,77 @@ +""" +Tests for Gemini client implementation. +""" +import pytest +from unittest.mock import MagicMock, patch +from src.llm.gemini_client import Gemini + +@pytest.fixture +def mock_config(): + with patch('src.llm.gemini_client.config') as mock: + mock.get_gemini_api_key.return_value = "test-api-key" + yield mock + +@pytest.fixture +def mock_genai(): + with patch('src.llm.gemini_client.genai') as mock: + yield mock + +@pytest.fixture +def gemini_client(mock_config, mock_genai): + return Gemini() + +def test_init_with_api_key(mock_config, mock_genai): + """Test client initialization with API key.""" + client = Gemini() + mock_genai.configure.assert_called_once_with(api_key="test-api-key") + +def test_init_without_api_key(mock_config, mock_genai): + """Test client initialization without API key.""" + mock_config.get_gemini_api_key.return_value = None + with pytest.raises(ValueError, match="Gemini API key not found in configuration"): + Gemini() + +def test_init_config_failure(mock_config, mock_genai): + """Test handling of configuration failure.""" + mock_genai.configure.side_effect = Exception("Test error") + with pytest.raises(ValueError, match="Failed to configure Gemini client: Test error"): + Gemini() + +def test_inference_success(mock_genai, gemini_client): + """Test successful text generation.""" + mock_model = MagicMock() + mock_response = MagicMock() + mock_response.text = "Generated response" + mock_model.generate_content.return_value = mock_response + mock_genai.GenerativeModel.return_value = mock_model + + response = gemini_client.inference("gemini-pro", "Test prompt") + assert response == "Generated response" + mock_model.generate_content.assert_called_once_with("Test prompt", safety_settings={ + mock_genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: mock_genai.types.HarmBlockThreshold.BLOCK_NONE, + mock_genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: mock_genai.types.HarmBlockThreshold.BLOCK_NONE, + }) + +def test_inference_empty_response(mock_genai, gemini_client): + """Test handling of empty response.""" + mock_model = MagicMock() + mock_response = MagicMock() + mock_response.text = None + mock_model.generate_content.return_value = mock_response + mock_genai.GenerativeModel.return_value = mock_model + + with pytest.raises(ValueError, match="Error: Unable to generate content Gemini API"): + gemini_client.inference("gemini-pro", "Test prompt") + +def test_inference_error(mock_genai, gemini_client): + """Test handling of inference error.""" + mock_model = MagicMock() + 
+    mock_model.generate_content.side_effect = Exception("Test error")
+    mock_genai.GenerativeModel.return_value = mock_model
+
+    with pytest.raises(ValueError, match="Error during Gemini inference: Test error"):
+        gemini_client.inference("gemini-pro", "Test prompt")
+
+def test_str_representation(gemini_client):
+    """Test that the string representation includes the class name."""
+    assert "Gemini" in str(gemini_client)
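
A minimal usage sketch, assuming the GEMINI key is present in config.toml under [API_KEYS] as described in the error message above; the module path and inference() signature come from the diff, while the model id and prompt are only illustrative:

    from src.llm.gemini_client import Gemini

    client = Gemini()  # raises ValueError with setup guidance if the GEMINI key is missing
    print(client.inference("gemini-pro", "Say hello"))

The new tests can be run with: pytest tests/test_gemini_client.py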