Skip to content

Commit

Permalink
Merge pull request #23 from tisnik/azure-openai-version
Browse files Browse the repository at this point in the history
config: add api_version required for azure
  • Loading branch information
tisnik authored Oct 4, 2024
2 parents 59c0ec5 + c6d84d2 commit d638aca
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 2 deletions.
4 changes: 4 additions & 0 deletions ols/app/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ class ProviderConfig(BaseModel):
credentials: Optional[str] = None
project_id: Optional[str] = None
models: dict[str, ModelConfig] = {}
api_version: Optional[str] = None
deployment_name: Optional[str] = None
openai_config: Optional[OpenAIConfig] = None
azure_config: Optional[AzureOpenAIConfig] = None
Expand Down Expand Up @@ -321,6 +322,9 @@ def __init__(
self.setup_models_config(data)

if self.type == constants.PROVIDER_AZURE_OPENAI:
self.api_version = data.get(
"api_version", constants.DEFAULT_AZURE_API_VERSION
)
# deployment_name only required when using Azure OpenAI
self.deployment_name = data.get("deployment_name", None)
# note: it can be overwritten in azure_config
Expand Down
2 changes: 2 additions & 0 deletions ols/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class QueryValidationMethod(StrEnum):
}
)

DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"

# models


Expand Down
3 changes: 2 additions & 1 deletion ols/src/llms/providers/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def default_params(self) -> dict[str, Any]:
"""Construct and return structure with default LLM params."""
self.url = str(self.provider_config.url or self.url)
self.credentials = self.provider_config.credentials
api_version = self.provider_config.api_version
deployment_name = self.provider_config.deployment_name
azure_config = self.provider_config.azure_config

Expand All @@ -60,7 +61,7 @@ def default_params(self) -> dict[str, Any]:

default_parameters = {
"azure_endpoint": self.url,
"api_version": "2024-02-15-preview",
"api_version": api_version,
"deployment_name": deployment_name,
"model": self.model,
"model_kwargs": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ llm_providers:
url: "https://url1"
deployment_name: "test"
credentials_path": tests/config/secret/apitoken
api_version: 2024-12-31
api_version: "2024-12-31"
azure_openai_config:
url: "http://localhost:1234"
deployment_name: "*deployment name*"
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/app/models/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,9 @@ def test_provider_config_azure_openai_specific():
],
}
)
# Default azure api version
assert provider_config.api_version == constants.DEFAULT_AZURE_API_VERSION

# Azure OpenAI-specific configuration must be present
assert provider_config.azure_config is not None
assert str(provider_config.azure_config.url) == "http://localhost/"
Expand Down Expand Up @@ -477,6 +480,7 @@ def test_provider_config_apitoken_only():
"name": "test_name",
"type": "azure_openai",
"url": "test_url",
"api_version": "2024-02-15",
"azure_openai_config": {
"url": "http://localhost",
"credentials_path": "tests/config/secret/apitoken",
Expand All @@ -490,6 +494,9 @@ def test_provider_config_apitoken_only():
],
}
)
# Azure version is set from config
assert provider_config.api_version == "2024-02-15"

# Azure OpenAI-specific configuration must be present
assert provider_config.azure_config is not None
assert str(provider_config.azure_config.url) == "http://localhost/"
Expand Down Expand Up @@ -778,6 +785,8 @@ def test_provider_config_watsonx_specific():
assert provider_config.openai_config is None
assert provider_config.bam_config is None

assert provider_config.api_version is None


def test_provider_config_watsonx_unknown_parameters():
"""Test if unknown Watsonx parameters are detected."""
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,10 @@ def test_valid_config_with_azure_openai():
}
)
assert config.config == expected_config
assert (
config.config.llm_providers.providers.get("p1").api_version
== constants.DEFAULT_AZURE_API_VERSION
)
except Exception as e:
print(traceback.format_exc())
pytest.fail(f"loading valid configuration failed: {e}")
Expand Down Expand Up @@ -1210,6 +1214,9 @@ def test_valid_config_with_azure_openai_api_version():
}
)
assert config.config == expected_config
assert (
config.config.llm_providers.providers.get("p1").api_version == "2024-12-31"
)
except Exception as e:
print(traceback.format_exc())
pytest.fail(f"loading valid configuration failed: {e}")
Expand Down Expand Up @@ -1260,6 +1267,7 @@ def test_valid_config_with_bam():
}
)
assert config.config == expected_config
assert config.config.llm_providers.providers.get("p1").api_version is None
except Exception as e:
print(traceback.format_exc())
pytest.fail(f"loading valid configuration failed: {e}")
Expand Down

0 comments on commit d638aca

Please sign in to comment.