set default environment variables
AnniePacheco committed Oct 10, 2024
1 parent 5250502 · commit 42c7935
Showing 3 changed files with 104 additions and 22 deletions.
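Each of the three managers applies the same pattern: load .env from base_dir, read a parameter from the environment, fall back to a hard-coded default when the variable is unset, and persist that default back to .env with python-dotenv's set_key so later runs see the same value. A minimal sketch of the pattern; the helper name and the default value are illustrative, not part of the commit:

import os
from dotenv import load_dotenv, set_key
from common.paths import base_dir

load_dotenv(base_dir / '.env')

def get_or_persist(name, default):
    # Illustrative helper (not in this commit): return the environment value,
    # or write the default back to .env and return it.
    value = os.environ.get(name)
    if not value:
        value = default
        set_key(base_dir / '.env', name, str(value))
    return value

max_tokens = int(get_or_persist('MAX_TOKENS', 1000000))  # mirrors the MAX_TOKENS default below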
52 changes: 42 additions & 10 deletions backend/managers/MessagesManager.py
@@ -9,6 +9,8 @@
from backend.managers.RagManager import RagManager
from langchain_community.llms import Ollama
import os
from dotenv import load_dotenv, set_key
from common.paths import base_dir

class MessagesManager:
_instance = None
@@ -21,16 +23,47 @@ def __new__(cls, *args, **kwargs):
cls._instance = super(MessagesManager, cls).__new__(cls, *args, **kwargs)
return cls._instance

def __init__(self):
def __init__(self, max_tokens=None, temperature=None, top_k=None, top_p=None):
if not hasattr(self, '_initialized'):
with self._lock:
if not hasattr(self, '_initialized'):
self._initialized = True
self.max_tokens = int(os.environ.get('MAX_TOKENS'))
self.temperature = float(os.environ.get('TEMPERATURE'))
self.top_k = int(os.environ.get('TOP_K'))
self.top_p = float(os.environ.get('TOP_P'))
load_dotenv(base_dir / '.env')
self.max_tokens = max_tokens if max_tokens else self.get_model_params('MAX_TOKENS')
self.temperature = temperature if temperature else self.get_model_params('TEMPERATURE')
self.top_k = top_k if top_k else self.get_model_params('TOP_K')
self.top_p = top_p if top_p else self.get_model_params('TOP_P')

def get_model_params(self, param_name: str):
if param_name == 'MAX_TOKENS':
max_tokens=os.environ.get('MAX_TOKENS')
if not max_tokens:
max_tokens = 1000000
set_key(base_dir / '.env', 'MAX_TOKENS', str(max_tokens))
return max_tokens

if param_name == 'TEMPERATURE':
temperature=os.environ.get('TEMPERATURE')
if not temperature:
temperature = 0.2
set_key(base_dir / '.env', 'TEMPERATURE', str(temperature))
return temperature

if param_name == 'TOP_K':
top_k=os.environ.get('TOP_K')
if not top_k:
top_k = 10
set_key(base_dir / '.env', 'TOP_K', str(top_k))
return top_k

if param_name == 'TOP_P':
top_p=os.environ.get('TOP_P')
if not top_p:
top_p = 0.5
print("top_p", str(top_p))
set_key(base_dir / '.env', 'TOP_P', str(top_p))
return top_p

async def __get_llm_name__(self, assistant_id) -> Tuple[Optional[str], Optional[str]]:
async with db_session_context() as session:
result = await session.execute(select(Resource).filter(Resource.id == assistant_id))
@@ -43,7 +76,6 @@ async def __get_llm_name__(self, assistant_id) -> Tuple[Optional[str], Optional[
llm = result.scalar_one_or_none()
if not llm:
return None, "LLM resource not found"

return self.extract_names_from_uri(llm.uri.split('/')[-1]), None

def set_max_tokens(self, max_tokens: int):
@@ -69,10 +101,10 @@ async def create_message(self, message_data: MessageCreateSchema) -> Tuple[Optio
return None, error_message

llm = Ollama(model=model_name,
num_predict=self.max_tokens,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p)
num_predict=int(self.max_tokens),
temperature=float(self.temperature),
top_k=int(self.top_k),
top_p=float(self.top_p))

assistant_id = message_data['assistant_id']
query = message_data['prompt']
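Since os.environ and set_key only handle strings, the values resolved by get_model_params come back as strings whenever they were read from the environment, which is why the Ollama call above casts each parameter. A hedged usage sketch; the model name is an assumption, not something fixed by this commit:

from langchain_community.llms import Ollama
from backend.managers.MessagesManager import MessagesManager

manager = MessagesManager()              # resolves MAX_TOKENS, TEMPERATURE, TOP_K, TOP_P
llm = Ollama(model="llama3:latest",      # illustrative model name
             num_predict=int(manager.max_tokens),
             temperature=float(manager.temperature),
             top_k=int(manager.top_k),
             top_p=float(manager.top_p))
# MessagesManager is a singleton: constructor arguments only take effect on the
# first instantiation, so later adjustments should go through setters such as
# set_max_tokens rather than a new MessagesManager(...) call.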
57 changes: 48 additions & 9 deletions backend/managers/RagManager.py
@@ -24,6 +24,8 @@
from enum import Enum
import aiofiles
import asyncio
from dotenv import load_dotenv, set_key
from common.paths import base_dir

logger = logging.getLogger(__name__)

@@ -47,22 +49,59 @@ def __new__(cls, *args, **kwargs):
cls._instance = super(RagManager, cls).__new__(cls, *args, **kwargs)
return cls._instance

def __init__(self):
def __init__(self, chunk_size=None, chunk_overlap=None, add_start_index=None, embedder_model=None, system_prompt=None):
if not hasattr(self, '_initialized'):
with self._lock:
if not hasattr(self, '_initialized'):
self._initialized = True
load_dotenv(base_dir / '.env')
self.chunk_size = chunk_size if chunk_size else self.get_params('CHUNK_SIZE')
self.chunk_overlap = chunk_overlap if chunk_overlap else self.get_params('CHUNK_OVERLAP')
self.add_start_index = add_start_index if add_start_index else self.get_params('ADD_START_INDEX')
self.embedder_model = embedder_model if embedder_model else self.get_params('EMBEDDER_MODEL')
self.system_prompt = system_prompt if system_prompt else self.get_params('SYSTEM_PROMPT')

def get_params(self, param_name: str):
if param_name == 'CHUNK_SIZE':
chunk_size=os.environ.get('CHUNK_SIZE')
if not chunk_size:
chunk_size = 10000
set_key(base_dir / '.env', 'CHUNK_SIZE', str(chunk_size))
return chunk_size
if param_name == 'CHUNK_OVERLAP':
chunk_overlap=os.environ.get('CHUNK_OVERLAP')
if not chunk_overlap:
chunk_overlap = 200
set_key(base_dir / '.env', 'CHUNK_OVERLAP', str(chunk_overlap))
return chunk_overlap
if param_name == 'ADD_START_INDEX':
add_start_index=os.environ.get('ADD_START_INDEX')
if not add_start_index:
add_start_index = 'True'
set_key(base_dir / '.env', 'ADD_START_INDEX', str(add_start_index))
return add_start_index
if param_name == 'EMBEDDER_MODEL':
embedder_model=os.environ.get('EMBEDDER_MODEL')
if not embedder_model:
embedder_model = 'llama3:latest'
set_key(base_dir / '.env', 'EMBEDDER_MODEL', str(embedder_model))
return embedder_model
if param_name == 'SYSTEM_PROMPT':
system_prompt=os.environ.get('SYSTEM_PROMPT')
if not system_prompt:
system_prompt = "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, say that you don't know."
set_key(base_dir / '.env', 'SYSTEM_PROMPT', str(system_prompt))
return system_prompt

async def create_index(self, resource_id: str, path_files: List[str], files_ids:List[str]) -> List[dict]:
loop = asyncio.get_running_loop()
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=int(os.environ.get('CHUNK_SIZE')),
chunk_overlap=int(os.environ.get('CHUNK_OVERLAP')),
add_start_index=bool(strtobool(os.environ.get('ADD_START_INDEX')))
chunk_size=int(self.chunk_size),
chunk_overlap=int(self.chunk_overlap),
add_start_index=bool(strtobool(self.add_start_index))
)

file_info_list = []
vectorstore = await self.initialize_chroma(resource_id)

for path, file_id in zip(path_files, files_ids):
file_name = Path(path).name
@@ -149,7 +188,7 @@ async def create_chunk(self, chunk_id:str, page_id: str, file_id: str, assistan


async def initialize_chroma(self, collection_name: str):
embed = OllamaEmbeddings(model=os.environ.get('EMBEDDER_MODEL'))
embed = OllamaEmbeddings(model=self.embedder_model)

path = Path(chroma_db_path)
vectorstore = Chroma(persist_directory=str(path),
@@ -166,10 +205,10 @@ async def retrieve_and_generate(self, collection_name, query, llm) -> str:
personality_prompt = persona.description
# Combine the system prompt and context
system_prompt = (os.environ.get('SYSTEM_PROMPT') + "\n\n{context}" +
system_prompt = (self.system_prompt + "\n\n{context}" +
"\n\nHere is some information about the assistant expertise to help you answer your questions: " +
personality_prompt)
# system_prompt = (os.environ.get('SYSTEM_PROMPT') +
# system_prompt = (self.system_prompt +
# "\n\n{context}" +
# "\n\nHere is some information about the assistant expertise to help you answer your questions: " + personality_prompt +
# ".\n\nIf the user asks you a question about the assistant information, example: 'What can you tell me about the assistant?', 'What is the name of the assistant?', 'Who is the assistant?'. "+
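Every value persisted through set_key is stored as text, so the ADD_START_INDEX flag round-trips as the string 'True' and is converted with strtobool before it reaches the splitter. A small sketch under the defaults this commit writes, assuming imports of strtobool (distutils.util) and RecursiveCharacterTextSplitter (langchain_text_splitters) that sit above the hunks shown here:

import os
from distutils.util import strtobool
from langchain_text_splitters import RecursiveCharacterTextSplitter  # assumed import path

chunk_size = os.environ.get('CHUNK_SIZE', '10000')            # values come back as strings
chunk_overlap = os.environ.get('CHUNK_OVERLAP', '200')
add_start_index = os.environ.get('ADD_START_INDEX', 'True')

splitter = RecursiveCharacterTextSplitter(
    chunk_size=int(chunk_size),
    chunk_overlap=int(chunk_overlap),
    add_start_index=bool(strtobool(add_start_index)),          # 'True' -> 1 -> True
)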
17 changes: 14 additions & 3 deletions backend/managers/VoicesFacesManager.py
@@ -10,6 +10,8 @@
from backend.schemas import VoiceSchema
from pathlib import Path
from backend.models import Resource, Persona
from dotenv import load_dotenv, set_key
from common.paths import base_dir

XI_API_KEY = os.environ.get('XI_API_KEY')

@@ -24,11 +26,21 @@ def __new__(cls, *args, **kwargs):
cls._instance = super(VoicesFacesManager, cls).__new__(cls, *args, **kwargs)
return cls._instance

def __init__(self):
def __init__(self, xi_chunk_size=None):
if not hasattr(self, '_initialized'):
with self._lock:
if not hasattr(self, '_initialized'):
self._initialized = True
load_dotenv(base_dir / '.env')
self.xi_chunk_size = xi_chunk_size if xi_chunk_size else self.get_xi_labs_params('XI_CHUNK_SIZE')

def get_xi_labs_params(self, param_name: str):
if param_name == 'XI_CHUNK_SIZE':
xi_chunk_size=os.environ.get('XI_CHUNK_SIZE')
if not xi_chunk_size:
xi_chunk_size = 1024
set_key(base_dir / '.env', 'XI_CHUNK_SIZE', str(xi_chunk_size))
return xi_chunk_size

async def map_xi_to_voice(self):
xi_voices = []
@@ -51,7 +63,6 @@ async def map_xi_to_voice(self):
}
xi_voices.append(voice_data)
return xi_voices


async def create_voice(self, voice_data: VoiceCreateSchema) -> str:
async with db_session_context() as session:
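The ElevenLabs chunk size follows the same fallback; a short sketch of what it resolves to (the int cast is added here because a value read back from .env is a string):

from backend.managers.VoicesFacesManager import VoicesFacesManager

vfm = VoicesFacesManager()
chunk_size = int(vfm.xi_chunk_size)   # 1024 unless XI_CHUNK_SIZE was already set in .env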
