From dc54913efb326a26fcba21cdc1d09cc1818f719c Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 3 Sep 2024 19:29:00 +0200 Subject: [PATCH 1/7] feat: inform user about the costs, use logging, allow skipping --- src/rai/rai/cli/rai_cli.py | 102 ++++++++++++++++++++++++++----------- 1 file changed, 71 insertions(+), 31 deletions(-) diff --git a/src/rai/rai/cli/rai_cli.py b/src/rai/rai/cli/rai_cli.py index 02624a650..37b5df9d1 100644 --- a/src/rai/rai/cli/rai_cli.py +++ b/src/rai/rai/cli/rai_cli.py @@ -15,9 +15,11 @@ import argparse import glob +import logging import subprocess from pathlib import Path +import coloredlogs from langchain_community.vectorstores import FAISS from langchain_core.messages import SystemMessage from langchain_openai import ChatOpenAI, OpenAIEmbeddings @@ -26,6 +28,10 @@ from rai.messages import preprocess_image from rai.messages.multimodal import HumanMultimodalMessage +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +coloredlogs.install(level="INFO") # type: ignore + def parse_whoami_package(): parser = argparse.ArgumentParser( @@ -35,42 +41,80 @@ def parse_whoami_package(): "documentation_root", type=str, help="Path to the root of the documentation" ) parser.add_argument( - "output", + "--output", type=str, + required=False, + default=None, help="Path to the output directory", ) args = parser.parse_args() + save_dir = args.output if args.output is not None else args.documentation_root + + def build_docs_vector_store(): + logger.info("Building the robot docs vector store...") + faiss_index = FAISS.from_documents(docs, OpenAIEmbeddings()) + faiss_index.add_documents(docs) + faiss_index.save_local(save_dir) + + def build_robot_identity(): + logger.info("Building the robot identity...") + prompt = ( + "You will be given a robot's documentation. " + "Your task is to identify the robot's identity. " + "The description should cover the most important aspects of the robot with respect to human interaction, " + "as well as the robot's capabilities and limitations including sensor and actuator information. " + "If there are any images provided, make sure to take them into account by thoroughly analyzing them. " + "Your description should be thorough and detailed." + "Your reply should start with I am a ..." + ) + llm = ChatOpenAI(model="gpt-4o-mini") + + images = glob.glob(args.documentation_root + "/images/*") + + messages = [SystemMessage(content=prompt)] + [ + HumanMultimodalMessage( + content=documentation, + images=[preprocess_image(image) for image in images], + ) + ] + output = llm.invoke(messages) + assert isinstance(output.content, str), "Malformed output" + + with open(save_dir + "/robot_identity.txt", "w") as f: + f.write(output.content) + logger.info("Done") docs = ingest_documentation( documentation_root=args.documentation_root + "/documentation" ) - faiss_index = FAISS.from_documents(docs, OpenAIEmbeddings()) - faiss_index.add_documents(docs) - save_dir = args.output - faiss_index.save_local(save_dir) - - prompt = ( - "You will be given a robot's documentation. " - "Your task is to identify the robot's identity. " - "The description should cover the most important aspects of the robot with respect to human interaction, " - "as well as the robot's capabilities and limitations including sensor and actuator information. " - "If there are any images provided, make sure to take them into account by thoroughly analyzing them. " - "Your description should be thorough and detailed." - "Your reply should start with I am a ..." + documentation = str([doc.page_content for doc in docs]) + n_tokens = len(documentation) // 4.0 + logger.info( + f"The documentation's length is {len(documentation)} chars, " + f"approximately {n_tokens} tokens" ) - llm = ChatOpenAI(model="gpt-4o-mini") - - images = glob.glob(args.documentation_root + "/images/*") - - messages = [SystemMessage(content=prompt)] + [ - HumanMultimodalMessage( - content=str([doc.page_content for doc in docs]), - images=[preprocess_image(image) for image in images], + logger.warn( + f"Building the robot docs vector store will cost " + f"approximately {n_tokens / 1_000_000 * 0.1:.4f}$. " + "Do you want to continue? (y/n)" + ) + if input() == "y": + build_docs_vector_store() + else: + logger.info("Skipping the robot docs vector store creation.") + + logger.warn( + f"Building the robot identity will cost " + f"approximately {n_tokens / 1_000_000 * 0.15:.4f}$. " + "Do you want to continue? (y/n)" + ) + if input() == "y": + build_robot_identity() + else: + logger.info( + f"Skipping the robot identity creation. " + f"You can do it manually by creating {save_dir}/robot_identity.txt" ) - ] - output = llm.invoke(messages) - with open(save_dir + "/robot_identity.txt", "w") as f: - f.write(output.content) def create_rai_ws(): @@ -125,7 +169,7 @@ def create_rai_ws(): # TODO: Refactor this hacky solution # NOTE (mkotynia) fast solution, worth refactor in the future and testing if it generic for all setup.py file confgurations -def modify_setup_py(setup_py_path): +def modify_setup_py(setup_py_path: Path): with open(setup_py_path, "r") as file: setup_content = file.read() @@ -157,7 +201,3 @@ def modify_setup_py(setup_py_path): with open(setup_py_path, "w") as file: file.write(modified_script) - - -if __name__ == "__main__": - create_rai_ws() From 619106c8ec07f21c3d745d54e4fbbfcf58d988bf Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Wed, 4 Sep 2024 15:01:21 +0200 Subject: [PATCH 2/7] feat(config): add config.toml file with LLM and embeddings configurations feat(model_initialization): add functions to initialize LLM and embeddings models based on config file --- config.toml | 14 ++++++ src/rai/rai/cli/rai_cli.py | 8 +-- src/rai/rai/utils/model_initialization.py | 59 +++++++++++++++++++++++ 3 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 config.toml create mode 100644 src/rai/rai/utils/model_initialization.py diff --git a/config.toml b/config.toml new file mode 100644 index 000000000..cddf8b14f --- /dev/null +++ b/config.toml @@ -0,0 +1,14 @@ +[llm] +# Choose the LLM provider: openai, aws, or ollama +vendor = "openai" +simple_model = "gpt-4o-mini" # Faster, less expensive model for straightforward tasks +complex_model = "gpt-4o" # More powerful model for complex reasoning or generation +region_name = "us-east-1" # AWS region (only applicable when vendor is "aws") +base_url = "http://localhost:11434" # Ollama API endpoint (only applicable when vendor is "ollama") + +[embeddings] +# Choose the embeddings provider: openai, aws, or ollama +vendor = "openai" +model = "text-embedding-ada-002" # Embedding model to use +region_name = "us-east-1" # AWS region (only applicable when vendor is "aws") +base_url = "http://localhost:11434" # Ollama API endpoint (only applicable when vendor is "ollama") diff --git a/src/rai/rai/cli/rai_cli.py b/src/rai/rai/cli/rai_cli.py index 37b5df9d1..8f77a8235 100644 --- a/src/rai/rai/cli/rai_cli.py +++ b/src/rai/rai/cli/rai_cli.py @@ -22,11 +22,11 @@ import coloredlogs from langchain_community.vectorstores import FAISS from langchain_core.messages import SystemMessage -from langchain_openai import ChatOpenAI, OpenAIEmbeddings from rai.apps.talk_to_docs import ingest_documentation from rai.messages import preprocess_image from rai.messages.multimodal import HumanMultimodalMessage +from rai.utils.model_initialization import get_embeddings_model, get_llm_model logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -50,9 +50,12 @@ def parse_whoami_package(): args = parser.parse_args() save_dir = args.output if args.output is not None else args.documentation_root + llm = get_llm_model(model_type="simple_model") + embeddings_model = get_embeddings_model() + def build_docs_vector_store(): logger.info("Building the robot docs vector store...") - faiss_index = FAISS.from_documents(docs, OpenAIEmbeddings()) + faiss_index = FAISS.from_documents(docs, embeddings_model) faiss_index.add_documents(docs) faiss_index.save_local(save_dir) @@ -67,7 +70,6 @@ def build_robot_identity(): "Your description should be thorough and detailed." "Your reply should start with I am a ..." ) - llm = ChatOpenAI(model="gpt-4o-mini") images = glob.glob(args.documentation_root + "/images/*") diff --git a/src/rai/rai/utils/model_initialization.py b/src/rai/rai/utils/model_initialization.py new file mode 100644 index 000000000..36d008c14 --- /dev/null +++ b/src/rai/rai/utils/model_initialization.py @@ -0,0 +1,59 @@ +from typing import Literal, TypedDict + +import tomli + + +class LLMConfig(TypedDict): + vendor: str + simple_model: str + complex_model: str + region_name: str + base_url: str + + +class EmbeddingsConfig(TypedDict): + vendor: str + model: str + region_name: str + base_url: str + + +def get_llm_model(model_type: Literal["simple_model", "complex_model"]): + rai_config = tomli.load(open("config.toml", "rb")) + llm_config = LLMConfig(**rai_config["llm"]) + if llm_config["vendor"] == "openai": + from langchain_openai import ChatOpenAI + + return ChatOpenAI(model=llm_config[model_type]) + elif llm_config["vendor"] == "aws": + from langchain_aws import ChatBedrock + + return ChatBedrock(model_id=llm_config[model_type], region_name=llm_config["region_name"]) # type: ignore + elif llm_config["vendor"] == "ollama": + from langchain_ollama import ChatOllama + + return ChatOllama(model=llm_config[model_type], base_url=llm_config["base_url"]) + else: + raise ValueError(f"Unknown LLM vendor: {llm_config['vendor']}") + + +def get_embeddings_model(): + rai_config = tomli.load(open("config.toml", "rb")) + embeddings_config = EmbeddingsConfig(**rai_config["embeddings"]) + + if embeddings_config["vendor"] == "openai": + from langchain_openai import OpenAIEmbeddings + + return OpenAIEmbeddings() + elif embeddings_config["vendor"] == "aws": + from langchain_aws import BedrockEmbeddings + + return BedrockEmbeddings(model_id=embeddings_config["model"], region_name=embeddings_config["region_name"]) # type: ignore + elif embeddings_config["vendor"] == "ollama": + from langchain_ollama import OllamaEmbeddings + + return OllamaEmbeddings( + model=embeddings_config["model"], base_url=embeddings_config["base_url"] + ) + else: + raise ValueError(f"Unknown embeddings vendor: {embeddings_config['vendor']}") From 949f33e547dcc23c5f4437480430feb16ed44645 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Wed, 4 Sep 2024 15:09:39 +0200 Subject: [PATCH 3/7] chore: remove cost calculation, enahnce logging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Bartłomiej Boczek --- src/rai/rai/cli/rai_cli.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/rai/rai/cli/rai_cli.py b/src/rai/rai/cli/rai_cli.py index 8f77a8235..e3b6ca796 100644 --- a/src/rai/rai/cli/rai_cli.py +++ b/src/rai/rai/cli/rai_cli.py @@ -92,24 +92,21 @@ def build_robot_identity(): documentation = str([doc.page_content for doc in docs]) n_tokens = len(documentation) // 4.0 logger.info( + "Building the robot docs vector store... " f"The documentation's length is {len(documentation)} chars, " f"approximately {n_tokens} tokens" ) - logger.warn( - f"Building the robot docs vector store will cost " - f"approximately {n_tokens / 1_000_000 * 0.1:.4f}$. " - "Do you want to continue? (y/n)" - ) + logger.warn("Do you want to continue? (y/n)") if input() == "y": build_docs_vector_store() else: logger.info("Skipping the robot docs vector store creation.") - logger.warn( - f"Building the robot identity will cost " - f"approximately {n_tokens / 1_000_000 * 0.15:.4f}$. " - "Do you want to continue? (y/n)" + logger.info( + "Building the robot identity... " + "You can do it manually by creating {save_dir}/robot_identity.txt " ) + logger.warn("Do you want to continue? (y/n)") if input() == "y": build_robot_identity() else: From 3c830a5190f16ad94fc1272d42a5122e28c97109 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Wed, 4 Sep 2024 15:35:53 +0200 Subject: [PATCH 4/7] docs(utils): add Apache License 2.0 header to new __init__.py and model_initialization.py files for legal compliance and clarity --- src/rai/rai/utils/__init__.py | 13 +++++++++++++ src/rai/rai/utils/model_initialization.py | 14 ++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 src/rai/rai/utils/__init__.py diff --git a/src/rai/rai/utils/__init__.py b/src/rai/rai/utils/__init__.py new file mode 100644 index 000000000..ef74fc891 --- /dev/null +++ b/src/rai/rai/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/rai/rai/utils/model_initialization.py b/src/rai/rai/utils/model_initialization.py index 36d008c14..355a0e6a0 100644 --- a/src/rai/rai/utils/model_initialization.py +++ b/src/rai/rai/utils/model_initialization.py @@ -1,3 +1,17 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Literal, TypedDict import tomli From cd9f1b1cf5047c088068dc50f1cc11a257e819a5 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Wed, 4 Sep 2024 20:36:18 +0200 Subject: [PATCH 5/7] build: add tomli --- poetry.lock | 33 +++++---------------------------- pyproject.toml | 1 + 2 files changed, 6 insertions(+), 28 deletions(-) diff --git a/poetry.lock b/poetry.lock index c3fd44c59..2f0380e40 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "addict" @@ -2851,9 +2851,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -2875,9 +2875,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -3025,8 +3025,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" @@ -4442,50 +4442,27 @@ description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0c9045ecc2e4db59bfc97b20516dfdf8e41d910ac6fb667ebd3a79ea54084619"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1467940318e4a860afd546ef61fefb98a14d935cd6817ed07a228c7f7c62f389"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5954463675cb15db8d4b521f3566a017c8789222b8316b1e6934c811018ee08b"}, {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:167e7497035c303ae50651b351c28dc22a40bb98fbdb8468cdc971821b1ae533"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b27dfb676ac02529fb6e343b3a482303f16e6bc3a4d868b73935b8792edb52d0"}, {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bf2360a5e0f7bd75fa80431bf8ebcfb920c9f885e7956c7efde89031695cafb8"}, {file = "SQLAlchemy-2.0.32-cp310-cp310-win32.whl", hash = "sha256:306fe44e754a91cd9d600a6b070c1f2fadbb4a1a257b8781ccf33c7067fd3e4d"}, {file = "SQLAlchemy-2.0.32-cp310-cp310-win_amd64.whl", hash = "sha256:99db65e6f3ab42e06c318f15c98f59a436f1c78179e6a6f40f529c8cc7100b22"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:21b053be28a8a414f2ddd401f1be8361e41032d2ef5884b2f31d31cb723e559f"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b178e875a7a25b5938b53b006598ee7645172fccafe1c291a706e93f48499ff5"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723a40ee2cc7ea653645bd4cf024326dea2076673fc9d3d33f20f6c81db83e1d"}, {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:295ff8689544f7ee7e819529633d058bd458c1fd7f7e3eebd0f9268ebc56c2a0"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:49496b68cd190a147118af585173ee624114dfb2e0297558c460ad7495f9dfe2"}, {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:acd9b73c5c15f0ec5ce18128b1fe9157ddd0044abc373e6ecd5ba376a7e5d961"}, {file = "SQLAlchemy-2.0.32-cp311-cp311-win32.whl", hash = "sha256:9365a3da32dabd3e69e06b972b1ffb0c89668994c7e8e75ce21d3e5e69ddef28"}, {file = "SQLAlchemy-2.0.32-cp311-cp311-win_amd64.whl", hash = "sha256:8bd63d051f4f313b102a2af1cbc8b80f061bf78f3d5bd0843ff70b5859e27924"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6bab3db192a0c35e3c9d1560eb8332463e29e5507dbd822e29a0a3c48c0a8d92"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:19d98f4f58b13900d8dec4ed09dd09ef292208ee44cc9c2fe01c1f0a2fe440e9"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd33c61513cb1b7371fd40cf221256456d26a56284e7d19d1f0b9f1eb7dd7e8"}, {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d6ba0497c1d066dd004e0f02a92426ca2df20fac08728d03f67f6960271feec"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2b6be53e4fde0065524f1a0a7929b10e9280987b320716c1509478b712a7688c"}, {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:916a798f62f410c0b80b63683c8061f5ebe237b0f4ad778739304253353bc1cb"}, {file = "SQLAlchemy-2.0.32-cp312-cp312-win32.whl", hash = "sha256:31983018b74908ebc6c996a16ad3690301a23befb643093fcfe85efd292e384d"}, {file = "SQLAlchemy-2.0.32-cp312-cp312-win_amd64.whl", hash = "sha256:4363ed245a6231f2e2957cccdda3c776265a75851f4753c60f3004b90e69bfeb"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b8afd5b26570bf41c35c0121801479958b4446751a3971fb9a480c1afd85558e"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c750987fc876813f27b60d619b987b057eb4896b81117f73bb8d9918c14f1cad"}, {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada0102afff4890f651ed91120c1120065663506b760da4e7823913ebd3258be"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:78c03d0f8a5ab4f3034c0e8482cfcc415a3ec6193491cfa1c643ed707d476f16"}, {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:3bd1cae7519283ff525e64645ebd7a3e0283f3c038f461ecc1c7b040a0c932a1"}, {file = "SQLAlchemy-2.0.32-cp37-cp37m-win32.whl", hash = "sha256:01438ebcdc566d58c93af0171c74ec28efe6a29184b773e378a385e6215389da"}, {file = "SQLAlchemy-2.0.32-cp37-cp37m-win_amd64.whl", hash = "sha256:4979dc80fbbc9d2ef569e71e0896990bc94df2b9fdbd878290bd129b65ab579c"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c742be912f57586ac43af38b3848f7688863a403dfb220193a882ea60e1ec3a"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:62e23d0ac103bcf1c5555b6c88c114089587bc64d048fef5bbdb58dfd26f96da"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:251f0d1108aab8ea7b9aadbd07fb47fb8e3a5838dde34aa95a3349876b5a1f1d"}, {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ef18a84e5116340e38eca3e7f9eeaaef62738891422e7c2a0b80feab165905f"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3eb6a97a1d39976f360b10ff208c73afb6a4de86dd2a6212ddf65c4a6a2347d5"}, {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0c1c9b673d21477cec17ab10bc4decb1322843ba35b481585facd88203754fc5"}, {file = "SQLAlchemy-2.0.32-cp38-cp38-win32.whl", hash = "sha256:c41a2b9ca80ee555decc605bd3c4520cc6fef9abde8fd66b1cf65126a6922d65"}, {file = "SQLAlchemy-2.0.32-cp38-cp38-win_amd64.whl", hash = "sha256:8a37e4d265033c897892279e8adf505c8b6b4075f2b40d77afb31f7185cd6ecd"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:52fec964fba2ef46476312a03ec8c425956b05c20220a1a03703537824b5e8e1"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:328429aecaba2aee3d71e11f2477c14eec5990fb6d0e884107935f7fb6001632"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85a01b5599e790e76ac3fe3aa2f26e1feba56270023d6afd5550ed63c68552b3"}, {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aaf04784797dcdf4c0aa952c8d234fa01974c4729db55c45732520ce12dd95b4"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4488120becf9b71b3ac718f4138269a6be99a42fe023ec457896ba4f80749525"}, {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:14e09e083a5796d513918a66f3d6aedbc131e39e80875afe81d98a03312889e6"}, {file = "SQLAlchemy-2.0.32-cp39-cp39-win32.whl", hash = "sha256:0d322cc9c9b2154ba7e82f7bf25ecc7c36fbe2d82e2933b3642fc095a52cfc78"}, {file = "SQLAlchemy-2.0.32-cp39-cp39-win_amd64.whl", hash = "sha256:7dd8583df2f98dea28b5cd53a1beac963f4f9d087888d75f22fcc93a07cf8d84"}, @@ -5757,4 +5734,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "d1bc3be63a03d89358a0b71085cc08bbd303f04a154e5ce4588e35269b9d82fc" +content-hash = "8ebb3e24e7c48e61973a50c638ce35d9821bb56f52541c5c734be5b405f1f0cf" diff --git a/pyproject.toml b/pyproject.toml index ce932693f..55e64751a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ langchain-ollama = "^0.1.1" streamlit = "^1.37.1" deprecated = "^1.2.14" +tomli = "^2.0.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.4" From 75b65a1ed2564c1e8f200f8622e80f2ffdb3b7eb Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Wed, 4 Sep 2024 20:47:07 +0200 Subject: [PATCH 6/7] refactor: simplify config, use dataclasses --- config.toml | 31 ++++---- src/rai/rai/utils/model_initialization.py | 89 ++++++++++++++++------- 2 files changed, 81 insertions(+), 39 deletions(-) diff --git a/config.toml b/config.toml index cddf8b14f..579bbe5e2 100644 --- a/config.toml +++ b/config.toml @@ -1,14 +1,19 @@ -[llm] -# Choose the LLM provider: openai, aws, or ollama -vendor = "openai" -simple_model = "gpt-4o-mini" # Faster, less expensive model for straightforward tasks -complex_model = "gpt-4o" # More powerful model for complex reasoning or generation -region_name = "us-east-1" # AWS region (only applicable when vendor is "aws") -base_url = "http://localhost:11434" # Ollama API endpoint (only applicable when vendor is "ollama") +[vendor] +name = "openai" # openai, aws, ollama -[embeddings] -# Choose the embeddings provider: openai, aws, or ollama -vendor = "openai" -model = "text-embedding-ada-002" # Embedding model to use -region_name = "us-east-1" # AWS region (only applicable when vendor is "aws") -base_url = "http://localhost:11434" # Ollama API endpoint (only applicable when vendor is "ollama") +[aws] +simple_model = "anthropic.claude-3-haiku-20240307-v1:0" +complex_model = "anthropic.claude-3-5-sonnet-20240620-v1:0" +embeddings_model = "amazon.titan-embed-text-v1" +region_name = "us-east-1" + +[openai] +simple_model = "gpt-4o-mini" +complex_model = "gpt-4o" +embeddings_model = "text-embedding-ada-002" + +[ollama] +simple_model = "llama3.1" +complex_model = "llama3.1:70b" +embeddings_model = "llama3.1" +base_url = "http://localhost:11434" diff --git a/src/rai/rai/utils/model_initialization.py b/src/rai/rai/utils/model_initialization.py index 355a0e6a0..bdeede0ca 100644 --- a/src/rai/rai/utils/model_initialization.py +++ b/src/rai/rai/utils/model_initialization.py @@ -12,62 +12,99 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, TypedDict +from dataclasses import dataclass +from typing import Literal import tomli -class LLMConfig(TypedDict): - vendor: str +@dataclass +class VendorConfig: + name: str + + +@dataclass +class ModelConfig: simple_model: str complex_model: str - region_name: str - base_url: str + embeddings_model: str -class EmbeddingsConfig(TypedDict): - vendor: str - model: str +@dataclass +class AWSConfig(ModelConfig): region_name: str + + +@dataclass +class OllamaConfig(ModelConfig): base_url: str +@dataclass +class RAIConfig: + vendor: VendorConfig + aws: AWSConfig + openai: ModelConfig + ollama: OllamaConfig + + +def load_config() -> RAIConfig: + with open("config.toml", "rb") as f: + config_dict = tomli.load(f) + return RAIConfig( + vendor=VendorConfig(**config_dict["vendor"]), + aws=AWSConfig(**config_dict["aws"]), + openai=ModelConfig(**config_dict["openai"]), + ollama=OllamaConfig(**config_dict["ollama"]), + ) + + def get_llm_model(model_type: Literal["simple_model", "complex_model"]): - rai_config = tomli.load(open("config.toml", "rb")) - llm_config = LLMConfig(**rai_config["llm"]) - if llm_config["vendor"] == "openai": + config = load_config() + vendor = config.vendor.name + model_config = getattr(config, vendor) + + if vendor == "openai": from langchain_openai import ChatOpenAI - return ChatOpenAI(model=llm_config[model_type]) - elif llm_config["vendor"] == "aws": + return ChatOpenAI(model=getattr(model_config, model_type)) + elif vendor == "aws": from langchain_aws import ChatBedrock - return ChatBedrock(model_id=llm_config[model_type], region_name=llm_config["region_name"]) # type: ignore - elif llm_config["vendor"] == "ollama": + return ChatBedrock( + model_id=getattr(model_config, model_type), + region_name=model_config.region_name, + ) + elif vendor == "ollama": from langchain_ollama import ChatOllama - return ChatOllama(model=llm_config[model_type], base_url=llm_config["base_url"]) + return ChatOllama( + model=getattr(model_config, model_type), base_url=model_config.base_url + ) else: - raise ValueError(f"Unknown LLM vendor: {llm_config['vendor']}") + raise ValueError(f"Unknown LLM vendor: {vendor}") def get_embeddings_model(): - rai_config = tomli.load(open("config.toml", "rb")) - embeddings_config = EmbeddingsConfig(**rai_config["embeddings"]) + config = load_config() + vendor = config.vendor.name + model_config = getattr(config, vendor) - if embeddings_config["vendor"] == "openai": + if vendor == "openai": from langchain_openai import OpenAIEmbeddings - return OpenAIEmbeddings() - elif embeddings_config["vendor"] == "aws": + return OpenAIEmbeddings(model=model_config.embeddings_model) + elif vendor == "aws": from langchain_aws import BedrockEmbeddings - return BedrockEmbeddings(model_id=embeddings_config["model"], region_name=embeddings_config["region_name"]) # type: ignore - elif embeddings_config["vendor"] == "ollama": + return BedrockEmbeddings( + model_id=model_config.embeddings_model, region_name=model_config.region_name + ) + elif vendor == "ollama": from langchain_ollama import OllamaEmbeddings return OllamaEmbeddings( - model=embeddings_config["model"], base_url=embeddings_config["base_url"] + model=model_config.embeddings_model, base_url=model_config.base_url ) else: - raise ValueError(f"Unknown embeddings vendor: {embeddings_config['vendor']}") + raise ValueError(f"Unknown embeddings vendor: {vendor}") From d86f67ba396e4aadc239417d3a452da987611775 Mon Sep 17 00:00:00 2001 From: Maciej Majek <46171033+maciejmajek@users.noreply.github.com> Date: Thu, 5 Sep 2024 14:03:03 +0200 Subject: [PATCH 7/7] fix: f string Co-authored-by: Bartek Boczek <22739059+boczekbartek@users.noreply.github.com> --- src/rai/rai/cli/rai_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rai/rai/cli/rai_cli.py b/src/rai/rai/cli/rai_cli.py index e3b6ca796..a185fab12 100644 --- a/src/rai/rai/cli/rai_cli.py +++ b/src/rai/rai/cli/rai_cli.py @@ -104,7 +104,7 @@ def build_robot_identity(): logger.info( "Building the robot identity... " - "You can do it manually by creating {save_dir}/robot_identity.txt " + f"You can do it manually by creating {save_dir}/robot_identity.txt " ) logger.warn("Do you want to continue? (y/n)") if input() == "y":