From bf290be8ab430d6a819c250d5c516edeb36eb7f7 Mon Sep 17 00:00:00 2001 From: valentinfrlch Date: Tue, 11 Feb 2025 14:47:41 +0100 Subject: [PATCH] Use OpenAI provider as endpoint is compatible --- custom_components/llmvision/config_flow.py | 19 ++--- custom_components/llmvision/providers.py | 93 +++------------------- 2 files changed, 21 insertions(+), 91 deletions(-) diff --git a/custom_components/llmvision/config_flow.py b/custom_components/llmvision/config_flow.py index 96ace19..7c1a740 100644 --- a/custom_components/llmvision/config_flow.py +++ b/custom_components/llmvision/config_flow.py @@ -9,7 +9,6 @@ Groq, LocalAI, Ollama, - OpenWebUI, AWSBedrock ) from .const import ( @@ -42,6 +41,7 @@ CONF_OPENWEBUI_HTTPS, CONF_OPENWEBUI_API_KEY, CONF_OPENWEBUI_DEFAULT_MODEL, + ENDPOINT_OPENWEBUI, ) import voluptuous as vol @@ -214,14 +214,15 @@ async def async_step_openwebui(self, user_input=None): # save provider to user_input user_input["provider"] = self.init_info["provider"] try: - openwebui = OpenWebUI(hass=self.hass, - api_key=user_input[CONF_OPENWEBUI_API_KEY], - model=user_input[CONF_OPENWEBUI_DEFAULT_MODEL], - endpoint={ - 'ip_address': user_input[CONF_OPENWEBUI_IP_ADDRESS], - 'port': user_input[CONF_OPENWEBUI_PORT], - 'https': user_input[CONF_OPENWEBUI_HTTPS] - }) + endpoint = ENDPOINT_OPENWEBUI.format( + ip_address=user_input[CONF_OPENWEBUI_IP_ADDRESS], + port=user_input[CONF_OPENWEBUI_PORT], + protocol="https" if user_input[CONF_OPENWEBUI_HTTPS] else "http" + ) + openwebui = OpenAI(hass=self.hass, + api_key=user_input[CONF_OPENWEBUI_API_KEY], + default_model=user_input[CONF_OPENWEBUI_DEFAULT_MODEL], + endpoint={'base_url': endpoint}) await openwebui.validate() # add the mode to user_input if self.source == config_entries.SOURCE_RECONFIGURE: diff --git a/custom_components/llmvision/providers.py b/custom_components/llmvision/providers.py index 4529f58..e2f4126 100644 --- a/custom_components/llmvision/providers.py +++ b/custom_components/llmvision/providers.py @@ -229,11 +229,16 @@ async def call(self, call): api_key = config.get(CONF_OPENWEBUI_API_KEY) default_model = config.get(CONF_OPENWEBUI_DEFAULT_MODEL) - provider_instance = OpenWebUI(hass=self.hass, api_key=api_key, model=default_model, endpoint={ - 'ip_address': ip_address, - 'port': port, - 'https': https - }) + + endpoint = ENDPOINT_OPENWEBUI.format( + ip_address=ip_address, + port=port, + protocol="https" if https else "http" + ) + + provider_instance = OpenAI( + self.hass, api_key=api_key, endpoint={'base_url': endpoint}, default_model=default_model) + else: raise ServiceValidationError("invalid_provider") @@ -884,80 +889,4 @@ async def validate(self) -> None | ServiceValidationError: "messages": [{"role": "user", "content": [{"text": "Hi"}]}], "inferenceConfig": {"maxTokens": 10, "temperature": 0.5} } - await self.invoke_bedrock(model=self.default_model, data=data) - - -class OpenWebUI(Provider): - def __init__(self, hass, api_key, model, endpoint={'ip_address': "0.0.0.0", 'port': "3000", 'https': False}): - super().__init__(hass, api_key, endpoint=endpoint) - self.default_model = model - - def _generate_headers(self) -> dict: - return {'Content-type': 'application/json', - 'Authorization': 'Bearer ' + self.api_key} - - async def _make_request(self, data) -> str: - headers = self._generate_headers() - https = self.endpoint.get("https") - ip_address = self.endpoint.get("ip_address") - port = self.endpoint.get("port") - protocol = "https" if https else "http" - endpoint = ENDPOINT_OPENWEBUI.format( - ip_address=ip_address, - port=port, - protocol=protocol - ) - - response = await self._post(url=endpoint, headers=headers, data=data) - response_text = response.get( - "choices")[0].get("message").get("content") - return response_text - - def _prepare_vision_data(self, call) -> list: - payload = {"model": call.model, - "messages": [{"role": "user", "content": []}], - "max_tokens": call.max_tokens, - "temperature": call.temperature - } - - for image, filename in zip(call.base64_images, call.filenames): - tag = ("Image " + str(call.base64_images.index(image) + 1) - ) if filename == "" else filename - payload["messages"][0]["content"].append( - {"type": "text", "text": tag + ":"}) - payload["messages"][0]["content"].append({"type": "image_url", "image_url": { - "url": f"data:image/jpeg;base64,{image}"}}) - payload["messages"][0]["content"].append( - {"type": "text", "text": call.message}) - return payload - - def _prepare_text_data(self, call) -> list: - return { - "model": call.model, - "messages": [{"role": "user", "content": [{"type": "text", "text": call.message}]}], - "max_tokens": call.max_tokens, - "temperature": call.temperature - } - - async def validate(self) -> None | ServiceValidationError: - if self.api_key: - headers = self._generate_headers() - https = self.endpoint.get("https") - ip_address = self.endpoint.get("ip_address") - port = self.endpoint.get("port") - protocol = "https" if https else "http" - endpoint = ENDPOINT_OPENWEBUI.format( - ip_address=ip_address, - port=port, - protocol=protocol - ) - data = { - "model": self.default_model, - "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}], - "max_tokens": 1, - "temperature": 0.5 - } - - await self._post(url=endpoint, headers=headers, data=data) - else: - raise ServiceValidationError("empty_api_key") + await self.invoke_bedrock(model=self.default_model, data=data) \ No newline at end of file