From c6d84d226ba3e5d010b1d810c1b288379b73fc4d Mon Sep 17 00:00:00 2001 From: Pavel Tisnovsky Date: Fri, 4 Oct 2024 11:18:26 +0200 Subject: [PATCH] config: add api_version required for azure --- ols/app/models/config.py | 4 ++++ ols/constants.py | 2 ++ ols/src/llms/providers/azure_openai.py | 3 ++- .../valid_config_with_azure_openai_api_version.yaml | 2 +- tests/unit/app/models/test_config.py | 9 +++++++++ tests/unit/utils/test_config.py | 8 ++++++++ 6 files changed, 26 insertions(+), 2 deletions(-) diff --git a/ols/app/models/config.py b/ols/app/models/config.py index 750e43dd..9ecd83d8 100644 --- a/ols/app/models/config.py +++ b/ols/app/models/config.py @@ -276,6 +276,7 @@ class ProviderConfig(BaseModel): credentials: Optional[str] = None project_id: Optional[str] = None models: dict[str, ModelConfig] = {} + api_version: Optional[str] = None deployment_name: Optional[str] = None openai_config: Optional[OpenAIConfig] = None azure_config: Optional[AzureOpenAIConfig] = None @@ -321,6 +322,9 @@ def __init__( self.setup_models_config(data) if self.type == constants.PROVIDER_AZURE_OPENAI: + self.api_version = data.get( + "api_version", constants.DEFAULT_AZURE_API_VERSION + ) # deployment_name only required when using Azure OpenAI self.deployment_name = data.get("deployment_name", None) # note: it can be overwritten in azure_config diff --git a/ols/constants.py b/ols/constants.py index 824a63c6..4780826a 100644 --- a/ols/constants.py +++ b/ols/constants.py @@ -45,6 +45,8 @@ class QueryValidationMethod(StrEnum): } ) +DEFAULT_AZURE_API_VERSION = "2024-02-15-preview" + # models diff --git a/ols/src/llms/providers/azure_openai.py b/ols/src/llms/providers/azure_openai.py index 467f0ae6..e1b25b46 100644 --- a/ols/src/llms/providers/azure_openai.py +++ b/ols/src/llms/providers/azure_openai.py @@ -48,6 +48,7 @@ def default_params(self) -> dict[str, Any]: """Construct and return structure with default LLM params.""" self.url = str(self.provider_config.url or self.url) self.credentials = self.provider_config.credentials + api_version = self.provider_config.api_version deployment_name = self.provider_config.deployment_name azure_config = self.provider_config.azure_config @@ -60,7 +61,7 @@ def default_params(self) -> dict[str, Any]: default_parameters = { "azure_endpoint": self.url, - "api_version": "2024-02-15-preview", + "api_version": api_version, "deployment_name": deployment_name, "model": self.model, "model_kwargs": { diff --git a/tests/config/valid_config_with_azure_openai_api_version.yaml b/tests/config/valid_config_with_azure_openai_api_version.yaml index 70085d97..52385809 100644 --- a/tests/config/valid_config_with_azure_openai_api_version.yaml +++ b/tests/config/valid_config_with_azure_openai_api_version.yaml @@ -5,7 +5,7 @@ llm_providers: url: "https://url1" deployment_name: "test" credentials_path": tests/config/secret/apitoken - api_version: 2024-12-31 + api_version: "2024-12-31" azure_openai_config: url: "http://localhost:1234" deployment_name: "*deployment name*" diff --git a/tests/unit/app/models/test_config.py b/tests/unit/app/models/test_config.py index a4841d5d..42eefc16 100644 --- a/tests/unit/app/models/test_config.py +++ b/tests/unit/app/models/test_config.py @@ -448,6 +448,9 @@ def test_provider_config_azure_openai_specific(): ], } ) + # Default azure api version + assert provider_config.api_version == constants.DEFAULT_AZURE_API_VERSION + # Azure OpenAI-specific configuration must be present assert provider_config.azure_config is not None assert str(provider_config.azure_config.url) == "http://localhost/" @@ -477,6 +480,7 @@ def test_provider_config_apitoken_only(): "name": "test_name", "type": "azure_openai", "url": "test_url", + "api_version": "2024-02-15", "azure_openai_config": { "url": "http://localhost", "credentials_path": "tests/config/secret/apitoken", @@ -490,6 +494,9 @@ def test_provider_config_apitoken_only(): ], } ) + # Azure version is set from config + assert provider_config.api_version == "2024-02-15" + # Azure OpenAI-specific configuration must be present assert provider_config.azure_config is not None assert str(provider_config.azure_config.url) == "http://localhost/" @@ -778,6 +785,8 @@ def test_provider_config_watsonx_specific(): assert provider_config.openai_config is None assert provider_config.bam_config is None + assert provider_config.api_version is None + def test_provider_config_watsonx_unknown_parameters(): """Test if unknown Watsonx parameters are detected.""" diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py index 5aa120bb..7214dc61 100644 --- a/tests/unit/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -1103,6 +1103,10 @@ def test_valid_config_with_azure_openai(): } ) assert config.config == expected_config + assert ( + config.config.llm_providers.providers.get("p1").api_version + == constants.DEFAULT_AZURE_API_VERSION + ) except Exception as e: print(traceback.format_exc()) pytest.fail(f"loading valid configuration failed: {e}") @@ -1210,6 +1214,9 @@ def test_valid_config_with_azure_openai_api_version(): } ) assert config.config == expected_config + assert ( + config.config.llm_providers.providers.get("p1").api_version == "2024-12-31" + ) except Exception as e: print(traceback.format_exc()) pytest.fail(f"loading valid configuration failed: {e}") @@ -1260,6 +1267,7 @@ def test_valid_config_with_bam(): } ) assert config.config == expected_config + assert config.config.llm_providers.providers.get("p1").api_version is None except Exception as e: print(traceback.format_exc()) pytest.fail(f"loading valid configuration failed: {e}")