Use OpenAI provider as endpoint is compatible
valentinfrlch committed Feb 11, 2025
1 parent 1890b7c commit bf290be
Showing 2 changed files with 21 additions and 91 deletions.
19 changes: 10 additions & 9 deletions custom_components/llmvision/config_flow.py
@@ -9,7 +9,6 @@
 Groq,
 LocalAI,
 Ollama,
-OpenWebUI,
 AWSBedrock
 )
 from .const import (
@@ -42,6 +41,7 @@
 CONF_OPENWEBUI_HTTPS,
 CONF_OPENWEBUI_API_KEY,
 CONF_OPENWEBUI_DEFAULT_MODEL,
+ENDPOINT_OPENWEBUI,
 
 )
 import voluptuous as vol
@@ -214,14 +214,15 @@ async def async_step_openwebui(self, user_input=None):
 # save provider to user_input
 user_input["provider"] = self.init_info["provider"]
 try:
-openwebui = OpenWebUI(hass=self.hass,
-api_key=user_input[CONF_OPENWEBUI_API_KEY],
-model=user_input[CONF_OPENWEBUI_DEFAULT_MODEL],
-endpoint={
-'ip_address': user_input[CONF_OPENWEBUI_IP_ADDRESS],
-'port': user_input[CONF_OPENWEBUI_PORT],
-'https': user_input[CONF_OPENWEBUI_HTTPS]
-})
+endpoint = ENDPOINT_OPENWEBUI.format(
+ip_address=user_input[CONF_OPENWEBUI_IP_ADDRESS],
+port=user_input[CONF_OPENWEBUI_PORT],
+protocol="https" if user_input[CONF_OPENWEBUI_HTTPS] else "http"
+)
+openwebui = OpenAI(hass=self.hass,
+api_key=user_input[CONF_OPENWEBUI_API_KEY],
+default_model=user_input[CONF_OPENWEBUI_DEFAULT_MODEL],
+endpoint={'base_url': endpoint})
 await openwebui.validate()
 # add the mode to user_input
 if self.source == config_entries.SOURCE_RECONFIGURE:
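Only the Open WebUI step of the config flow changes: instead of building a dedicated OpenWebUI client, it now formats the user's host, port, and HTTPS choice into a base URL via ENDPOINT_OPENWEBUI and validates the connection through the shared OpenAI provider, which works because Open WebUI exposes an OpenAI-compatible chat-completions API. A minimal sketch of that URL-building step follows; the template assigned to ENDPOINT_OPENWEBUI here is an assumed stand-in for illustration (the real constant lives in const.py) and only the format arguments are taken from the diff.

# Sketch only: the real ENDPOINT_OPENWEBUI is defined in const.py;
# this "{protocol}://{ip_address}:{port}" template is an assumption.
ENDPOINT_OPENWEBUI = "{protocol}://{ip_address}:{port}"

def build_openwebui_base_url(ip_address: str, port: str, https: bool) -> str:
    # Mirror the diff: pick the scheme from the HTTPS toggle, then fill the template.
    return ENDPOINT_OPENWEBUI.format(
        ip_address=ip_address,
        port=port,
        protocol="https" if https else "http",
    )

# e.g. build_openwebui_base_url("192.168.1.50", "3000", False) -> "http://192.168.1.50:3000"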
93 changes: 11 additions & 82 deletions custom_components/llmvision/providers.py
@@ -229,11 +229,16 @@ async def call(self, call):
 api_key = config.get(CONF_OPENWEBUI_API_KEY)
 default_model = config.get(CONF_OPENWEBUI_DEFAULT_MODEL)
 
-provider_instance = OpenWebUI(hass=self.hass, api_key=api_key, model=default_model, endpoint={
-'ip_address': ip_address,
-'port': port,
-'https': https
-})
+
+endpoint = ENDPOINT_OPENWEBUI.format(
+ip_address=ip_address,
+port=port,
+protocol="https" if https else "http"
+)
+
+provider_instance = OpenAI(
+self.hass, api_key=api_key, endpoint={'base_url': endpoint}, default_model=default_model)
+
 
 else:
 raise ServiceValidationError("invalid_provider")
@@ -884,80 +889,4 @@ async def validate(self) -> None | ServiceValidationError:
 "messages": [{"role": "user", "content": [{"text": "Hi"}]}],
 "inferenceConfig": {"maxTokens": 10, "temperature": 0.5}
 }
-await self.invoke_bedrock(model=self.default_model, data=data)
-
-
-class OpenWebUI(Provider):
-def __init__(self, hass, api_key, model, endpoint={'ip_address': "0.0.0.0", 'port': "3000", 'https': False}):
-super().__init__(hass, api_key, endpoint=endpoint)
-self.default_model = model
-
-def _generate_headers(self) -> dict:
-return {'Content-type': 'application/json',
-'Authorization': 'Bearer ' + self.api_key}
-
-async def _make_request(self, data) -> str:
-headers = self._generate_headers()
-https = self.endpoint.get("https")
-ip_address = self.endpoint.get("ip_address")
-port = self.endpoint.get("port")
-protocol = "https" if https else "http"
-endpoint = ENDPOINT_OPENWEBUI.format(
-ip_address=ip_address,
-port=port,
-protocol=protocol
-)
-
-response = await self._post(url=endpoint, headers=headers, data=data)
-response_text = response.get(
-"choices")[0].get("message").get("content")
-return response_text
-
-def _prepare_vision_data(self, call) -> list:
-payload = {"model": call.model,
-"messages": [{"role": "user", "content": []}],
-"max_tokens": call.max_tokens,
-"temperature": call.temperature
-}
-
-for image, filename in zip(call.base64_images, call.filenames):
-tag = ("Image " + str(call.base64_images.index(image) + 1)
-) if filename == "" else filename
-payload["messages"][0]["content"].append(
-{"type": "text", "text": tag + ":"})
-payload["messages"][0]["content"].append({"type": "image_url", "image_url": {
-"url": f"data:image/jpeg;base64,{image}"}})
-payload["messages"][0]["content"].append(
-{"type": "text", "text": call.message})
-return payload
-
-def _prepare_text_data(self, call) -> list:
-return {
-"model": call.model,
-"messages": [{"role": "user", "content": [{"type": "text", "text": call.message}]}],
-"max_tokens": call.max_tokens,
-"temperature": call.temperature
-}
-
-async def validate(self) -> None | ServiceValidationError:
-if self.api_key:
-headers = self._generate_headers()
-https = self.endpoint.get("https")
-ip_address = self.endpoint.get("ip_address")
-port = self.endpoint.get("port")
-protocol = "https" if https else "http"
-endpoint = ENDPOINT_OPENWEBUI.format(
-ip_address=ip_address,
-port=port,
-protocol=protocol
-)
-data = {
-"model": self.default_model,
-"messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
-"max_tokens": 1,
-"temperature": 0.5
-}
-
-await self._post(url=endpoint, headers=headers, data=data)
-else:
-raise ServiceValidationError("empty_api_key")
+await self.invoke_bedrock(model=self.default_model, data=data)
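The removed OpenWebUI class shows why the consolidation is safe: it already sent an OpenAI-style payload (model, messages, max_tokens, temperature) with Bearer authorization, which is exactly what the shared OpenAI provider produces. Below is a standalone sketch of the same probe request the deleted validate() performed; the aiohttp client and the /api/chat/completions path are assumptions for illustration, not code from this repository.

import aiohttp

async def probe_openwebui(base_url: str, api_key: str, model: str) -> str:
    # Bearer auth, as in the deleted _generate_headers().
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    # Same minimal payload the deleted validate() sent.
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
        "max_tokens": 1,
        "temperature": 0.5,
    }
    async with aiohttp.ClientSession() as session:
        # Assumed OpenAI-compatible route on the Open WebUI server.
        async with session.post(f"{base_url}/api/chat/completions",
                                headers=headers, json=payload) as resp:
            resp.raise_for_status()
            data = await resp.json()
    return data["choices"][0]["message"]["content"]

# e.g. asyncio.run(probe_openwebui("http://192.168.1.50:3000", "sk-...", "your-model"))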
