
Commit

Adapt code to serve local models through FastChat. For local models, non-streaming is used due to strange errors.
Odrec committed Jul 4, 2024
1 parent 11757d4 commit e437cd8
Showing 2 changed files with 88 additions and 85 deletions.
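For reference, FastChat serves an OpenAI-compatible REST API, which is why the existing openai client can simply be pointed at a local endpoint. Below is a minimal sketch of the non-streaming call this commit switches to for local models; the endpoint URL and model name are assumptions, adjust them to your deployment:

import openai

openai.base_url = "http://localhost:8000/v1/"  # assumed FastChat endpoint (FASTCHAT_LLM_ENDPOINT)
openai.api_key = "EMPTY"  # FastChat does not validate the key

response = openai.chat.completions.create(
    model="vicuna-7b-v1.5",  # hypothetical model name; list the served models via openai.models.list()
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,  # non-streaming, matching this commit's choice for local models
)
print(response.choices[0].message.content)
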
104 changes: 50 additions & 54 deletions src/ai_client.py
@@ -2,6 +2,7 @@
import os
import re

import openai
from openai import OpenAI

import streamlit as st
@@ -18,14 +19,15 @@ def load_openai_data():
return OpenAI()

self.client = load_openai_data()
self.is_openai = True
else:
pass
# Add here client initialization for local model
# For example using OpenLLM as inference server
# session_state["inference_server_url"] = "http://localhost:3000/v1"
# self.client = OpenAI(base_url=session_state['inference_server_url'])
# models = session_state["client"].models.list()
# self.model = models.data[0].id
@st.cache_resource
def load_fastchat_endpoint():
return os.getenv("FASTCHAT_LLM_ENDPOINT")

openai.base_url = load_fastchat_endpoint()
self.client = openai
self.is_openai = False

@staticmethod
def get_accessible_models() -> list:
@@ -37,7 +39,6 @@ def get_accessible_models() -> list:
if 'accessible_models' not in session_state and {'USER_ROLES', 'MODELS_PER_ROLE'} <= os.environ.keys():
user_roles = json.loads(os.environ['USER_ROLES'])
user_role = user_roles.get(session_state.get('username'))

models_per_role = json.loads(os.environ['MODELS_PER_ROLE'])
session_state['accessible_models'] = models_per_role.get(user_role, default_models['models'])

@@ -47,6 +48,12 @@

return session_state['accessible_models']

@staticmethod
def get_fastchat_models() -> list:
openai.base_url = os.getenv("FASTCHAT_LLM_ENDPOINT")
models = openai.models.list()
return [m.id for m in models]

@staticmethod
def _generate_response(stream):
"""
@@ -60,54 +67,27 @@ def _generate_response(stream):
chunk_content = chunk.choices[0].delta.content
if isinstance(chunk_content, str):
yield chunk_content
else:
continue

def _concatenate_partial_response(self):
"""
Concatenates the partial response into a single string.
"""
# The concatenated response.
str_response = ""
for i in self.partial_response:
if isinstance(i, str):
str_response += i

str_response = "".join(i for i in self.partial_response if isinstance(i, str))
replacements = {
r'\\\s*\(': r'$',
r'\\\s*\)': r'$',
r'\\\s*\[': r'$$',
r'\\\s*\]': r'$$'
}

# Perform the replacements
for pattern, replacement in replacements.items():
str_response = re.sub(pattern, replacement, str_response)

st.markdown(str_response)

self.partial_response = []
self.response += str_response

def get_response(self, prompt, description_to_use):
"""
Sends a prompt to the OpenAI API and returns the API's response.
Parameters:
prompt (str): The user's message or question.
description_to_use (str): Additional context or instructions to provide to the model.
Returns:
str: The response from the chatbot.
"""
#try:
# Prepare the full prompt and messages with context or instructions
messages = self._prepare_full_prompt_and_messages(prompt, description_to_use)

# Send the request to the OpenAI API
# Display assistant response in chat message container
def _handle_openai_response(self, messages):
self.response = ""
# True if the response contains special text, like a code block or math expression
self.special_text = False
with st.chat_message("assistant"):
with st.spinner(session_state['_']("Generating response...")):
@@ -117,11 +97,8 @@ def get_response(self, prompt, description_to_use):
stream=True,
)
self.partial_response = []

gen_stream = self._generate_response(stream)
for chunk_content in gen_stream:
# Check if the chunk is a code block
if chunk_content == '```':
self._concatenate_partial_response()
self.partial_response.append(chunk_content)
@@ -131,29 +108,48 @@ def get_response(self, prompt, description_to_use):
chunk_content = next(gen_stream)
self.partial_response.append(chunk_content)
if chunk_content == "`\n\n":
# show partial response to the user and keep it for later use
self._concatenate_partial_response()
self.special_text = False
except StopIteration:
break

else:
# If the chunk is not a code or math block, append it to the partial response
self.partial_response.append(chunk_content)
if chunk_content:
if '\n' in chunk_content:
self._concatenate_partial_response()
if chunk_content and '\n' in chunk_content:
self._concatenate_partial_response()
if self.partial_response:
self._concatenate_partial_response()

# If there is a partial response left, concatenate it and render it
if self.partial_response:
self._concatenate_partial_response()
def _handle_non_openai_response(self, messages):
self.response = ""
self.special_text = False
with st.chat_message("assistant"):
with st.spinner(session_state['_']("Generating response...")):
response = self.client.chat.completions.create(
model=session_state['selected_model'],
messages=messages,
)

return self.response
# The FastChat response may differ in shape; handle both attribute- and dict-style payloads
if hasattr(response, 'choices') and response.choices:
if hasattr(response.choices[0], 'message'):
self.response = response.choices[0].message.content  # chat-style response
else:
self.response = response.choices[0].text  # completion-style response
elif isinstance(response, dict):  # in case the response is a plain dict
self.response = response.get('choices', [])[0].get('message', {}).get('content', '')

st.markdown(self.response)

def get_response(self, prompt, description_to_use):
messages = self._prepare_full_prompt_and_messages(prompt, description_to_use)

if self.is_openai:
self._handle_openai_response(messages)
else:
self._handle_non_openai_response(messages)

#except Exception as e:
# print(f"An error occurred while fetching the OpenAI response: {e}")
# Return a default error message
return session_state['_']("Sorry, I couldn't process that request.")
return self.response

@staticmethod
def _prepare_full_prompt_and_messages(user_prompt, description_to_use):
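A note on the delimiter rewriting in _concatenate_partial_response above: the replacements table converts LaTeX-style \( \) and \[ \] delimiters into the $ and $$ forms that Streamlit's markdown renderer recognizes. A self-contained sketch of the same substitution:

import re

replacements = {
    r'\\\s*\(': r'$',
    r'\\\s*\)': r'$',
    r'\\\s*\[': r'$$',
    r'\\\s*\]': r'$$',
}

text = r"Euler's identity: \( e^{i\pi} + 1 = 0 \)"
for pattern, replacement in replacements.items():
    text = re.sub(pattern, replacement, text)

print(text)  # Euler's identity: $ e^{i\pi} + 1 = 0 $
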
69 changes: 38 additions & 31 deletions src/sidebar_manager.py
@@ -265,37 +265,44 @@ def _display_model_information(self):
identify the active model configuration.
"""
with st.sidebar:
if session_state['model_selection'] == 'OpenAI':
accessible_models = AIClient.get_accessible_models()

# Show dropdown when multiple models are available
if accessible_models:
model_label = session_state['_']("Model:")
index = 0
if self.advanced_model in accessible_models: # Use most advanced model as default
index = accessible_models.index(self.advanced_model)
st.selectbox(model_label,
accessible_models,
index=index,
key='selected_model')

# If the model is changed to one that doesn't support images
# we have to clear the widgets from images. For the upload widget,
# the way to do this is by generating a new widget with a new key
# as described in:
# https://discuss.streamlit.io/t/
# are-there-any-ways-to-clear-file-uploader-values-without-using-streamlit-form/40903
if session_state['selected_model'] != self.advanced_model and (
session_state['image_urls'] or
session_state['uploaded_images'] or
session_state['photo_to_use']):
session_state['images_key'] += 1
session_state['image_urls'] = []
session_state['uploaded_images'] = []
session_state['image_content'] = []
session_state['photo_to_use'] = []
session_state['activate_camera'] = False
st.rerun()
accessible_models = AIClient.get_accessible_models()
fastchat_models = AIClient.get_fastchat_models()

all_accessible_models = accessible_models + fastchat_models

# Show dropdown when multiple models are available
if all_accessible_models:
model_label = session_state['_']("Model:")
index = 0
if self.advanced_model in all_accessible_models: # Use most advanced model as default
index = all_accessible_models.index(self.advanced_model)
st.selectbox(model_label,
all_accessible_models,
index=index,
key='selected_model')

if session_state['selected_model'] in fastchat_models:
session_state['model_selection'] = "OTHER"
else:
session_state['model_selection'] = "OpenAI"

# If the model is changed to one that doesn't support images
# we have to clear the widgets from images. For the upload widget,
# the way to do this is by generating a new widget with a new key
# as described in:
# https://discuss.streamlit.io/t/
# are-there-any-ways-to-clear-file-uploader-values-without-using-streamlit-form/40903
if session_state['selected_model'] != self.advanced_model and (
session_state['image_urls'] or
session_state['uploaded_images'] or
session_state['photo_to_use']):
session_state['images_key'] += 1
session_state['image_urls'] = []
session_state['uploaded_images'] = []
session_state['image_content'] = []
session_state['photo_to_use'] = []
session_state['activate_camera'] = False
st.rerun()

def _display_delete_conversation_button(self, container):
"""
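A note on the images_key increment above: a Streamlit file uploader cannot be cleared programmatically, so the code resets it by bumping its key, which makes Streamlit build a brand-new, empty widget on the next rerun (the trick from the linked discuss.streamlit.io thread). A minimal sketch with hypothetical widget names:

import streamlit as st

if 'images_key' not in st.session_state:
    st.session_state['images_key'] = 0

# A new key value discards any previously selected files.
uploaded = st.file_uploader("Upload images", key=f"uploader_{st.session_state['images_key']}")

if st.button("Clear uploads"):
    st.session_state['images_key'] += 1  # next rerun creates a fresh, empty uploader
    st.rerun()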
