diff --git a/config.toml b/config.toml new file mode 100644 index 000000000..579bbe5e2 --- /dev/null +++ b/config.toml @@ -0,0 +1,19 @@ +[vendor] +name = "openai" # openai, aws, ollama + +[aws] +simple_model = "anthropic.claude-3-haiku-20240307-v1:0" +complex_model = "anthropic.claude-3-5-sonnet-20240620-v1:0" +embeddings_model = "amazon.titan-embed-text-v1" +region_name = "us-east-1" + +[openai] +simple_model = "gpt-4o-mini" +complex_model = "gpt-4o" +embeddings_model = "text-embedding-ada-002" + +[ollama] +simple_model = "llama3.1" +complex_model = "llama3.1:70b" +embeddings_model = "llama3.1" +base_url = "http://localhost:11434" diff --git a/poetry.lock b/poetry.lock index c3fd44c59..2f0380e40 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "addict" @@ -2851,9 +2851,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -2875,9 +2875,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -3025,8 +3025,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" @@ -4442,50 +4442,27 @@ description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0c9045ecc2e4db59bfc97b20516dfdf8e41d910ac6fb667ebd3a79ea54084619"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1467940318e4a860afd546ef61fefb98a14d935cd6817ed07a228c7f7c62f389"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5954463675cb15db8d4b521f3566a017c8789222b8316b1e6934c811018ee08b"}, {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:167e7497035c303ae50651b351c28dc22a40bb98fbdb8468cdc971821b1ae533"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b27dfb676ac02529fb6e343b3a482303f16e6bc3a4d868b73935b8792edb52d0"}, {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bf2360a5e0f7bd75fa80431bf8ebcfb920c9f885e7956c7efde89031695cafb8"}, 
{file = "SQLAlchemy-2.0.32-cp310-cp310-win32.whl", hash = "sha256:306fe44e754a91cd9d600a6b070c1f2fadbb4a1a257b8781ccf33c7067fd3e4d"}, {file = "SQLAlchemy-2.0.32-cp310-cp310-win_amd64.whl", hash = "sha256:99db65e6f3ab42e06c318f15c98f59a436f1c78179e6a6f40f529c8cc7100b22"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:21b053be28a8a414f2ddd401f1be8361e41032d2ef5884b2f31d31cb723e559f"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b178e875a7a25b5938b53b006598ee7645172fccafe1c291a706e93f48499ff5"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723a40ee2cc7ea653645bd4cf024326dea2076673fc9d3d33f20f6c81db83e1d"}, {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:295ff8689544f7ee7e819529633d058bd458c1fd7f7e3eebd0f9268ebc56c2a0"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:49496b68cd190a147118af585173ee624114dfb2e0297558c460ad7495f9dfe2"}, {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:acd9b73c5c15f0ec5ce18128b1fe9157ddd0044abc373e6ecd5ba376a7e5d961"}, {file = "SQLAlchemy-2.0.32-cp311-cp311-win32.whl", hash = "sha256:9365a3da32dabd3e69e06b972b1ffb0c89668994c7e8e75ce21d3e5e69ddef28"}, {file = "SQLAlchemy-2.0.32-cp311-cp311-win_amd64.whl", hash = "sha256:8bd63d051f4f313b102a2af1cbc8b80f061bf78f3d5bd0843ff70b5859e27924"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6bab3db192a0c35e3c9d1560eb8332463e29e5507dbd822e29a0a3c48c0a8d92"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:19d98f4f58b13900d8dec4ed09dd09ef292208ee44cc9c2fe01c1f0a2fe440e9"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd33c61513cb1b7371fd40cf221256456d26a56284e7d19d1f0b9f1eb7dd7e8"}, {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d6ba0497c1d066dd004e0f02a92426ca2df20fac08728d03f67f6960271feec"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2b6be53e4fde0065524f1a0a7929b10e9280987b320716c1509478b712a7688c"}, {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:916a798f62f410c0b80b63683c8061f5ebe237b0f4ad778739304253353bc1cb"}, {file = "SQLAlchemy-2.0.32-cp312-cp312-win32.whl", hash = "sha256:31983018b74908ebc6c996a16ad3690301a23befb643093fcfe85efd292e384d"}, {file = "SQLAlchemy-2.0.32-cp312-cp312-win_amd64.whl", hash = "sha256:4363ed245a6231f2e2957cccdda3c776265a75851f4753c60f3004b90e69bfeb"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b8afd5b26570bf41c35c0121801479958b4446751a3971fb9a480c1afd85558e"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c750987fc876813f27b60d619b987b057eb4896b81117f73bb8d9918c14f1cad"}, {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada0102afff4890f651ed91120c1120065663506b760da4e7823913ebd3258be"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:78c03d0f8a5ab4f3034c0e8482cfcc415a3ec6193491cfa1c643ed707d476f16"}, {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:3bd1cae7519283ff525e64645ebd7a3e0283f3c038f461ecc1c7b040a0c932a1"}, {file = "SQLAlchemy-2.0.32-cp37-cp37m-win32.whl", hash = 
"sha256:01438ebcdc566d58c93af0171c74ec28efe6a29184b773e378a385e6215389da"}, {file = "SQLAlchemy-2.0.32-cp37-cp37m-win_amd64.whl", hash = "sha256:4979dc80fbbc9d2ef569e71e0896990bc94df2b9fdbd878290bd129b65ab579c"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c742be912f57586ac43af38b3848f7688863a403dfb220193a882ea60e1ec3a"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:62e23d0ac103bcf1c5555b6c88c114089587bc64d048fef5bbdb58dfd26f96da"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:251f0d1108aab8ea7b9aadbd07fb47fb8e3a5838dde34aa95a3349876b5a1f1d"}, {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ef18a84e5116340e38eca3e7f9eeaaef62738891422e7c2a0b80feab165905f"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3eb6a97a1d39976f360b10ff208c73afb6a4de86dd2a6212ddf65c4a6a2347d5"}, {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0c1c9b673d21477cec17ab10bc4decb1322843ba35b481585facd88203754fc5"}, {file = "SQLAlchemy-2.0.32-cp38-cp38-win32.whl", hash = "sha256:c41a2b9ca80ee555decc605bd3c4520cc6fef9abde8fd66b1cf65126a6922d65"}, {file = "SQLAlchemy-2.0.32-cp38-cp38-win_amd64.whl", hash = "sha256:8a37e4d265033c897892279e8adf505c8b6b4075f2b40d77afb31f7185cd6ecd"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:52fec964fba2ef46476312a03ec8c425956b05c20220a1a03703537824b5e8e1"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:328429aecaba2aee3d71e11f2477c14eec5990fb6d0e884107935f7fb6001632"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85a01b5599e790e76ac3fe3aa2f26e1feba56270023d6afd5550ed63c68552b3"}, {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aaf04784797dcdf4c0aa952c8d234fa01974c4729db55c45732520ce12dd95b4"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4488120becf9b71b3ac718f4138269a6be99a42fe023ec457896ba4f80749525"}, {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:14e09e083a5796d513918a66f3d6aedbc131e39e80875afe81d98a03312889e6"}, {file = "SQLAlchemy-2.0.32-cp39-cp39-win32.whl", hash = "sha256:0d322cc9c9b2154ba7e82f7bf25ecc7c36fbe2d82e2933b3642fc095a52cfc78"}, {file = "SQLAlchemy-2.0.32-cp39-cp39-win_amd64.whl", hash = "sha256:7dd8583df2f98dea28b5cd53a1beac963f4f9d087888d75f22fcc93a07cf8d84"}, @@ -5757,4 +5734,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "d1bc3be63a03d89358a0b71085cc08bbd303f04a154e5ce4588e35269b9d82fc" +content-hash = "8ebb3e24e7c48e61973a50c638ce35d9821bb56f52541c5c734be5b405f1f0cf" diff --git a/pyproject.toml b/pyproject.toml index ce932693f..55e64751a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ langchain-ollama = "^0.1.1" streamlit = "^1.37.1" deprecated = "^1.2.14" +tomli = "^2.0.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.4" diff --git a/src/rai/rai/cli/rai_cli.py b/src/rai/rai/cli/rai_cli.py index 02624a650..a185fab12 100644 --- a/src/rai/rai/cli/rai_cli.py +++ b/src/rai/rai/cli/rai_cli.py @@ -15,16 +15,22 @@ import argparse import glob +import logging import subprocess from pathlib import Path +import coloredlogs from langchain_community.vectorstores import FAISS from langchain_core.messages 
import SystemMessage
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 
 from rai.apps.talk_to_docs import ingest_documentation
 from rai.messages import preprocess_image
 from rai.messages.multimodal import HumanMultimodalMessage
+from rai.utils.model_initialization import get_embeddings_model, get_llm_model
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+coloredlogs.install(level="INFO")  # type: ignore
 
 
 def parse_whoami_package():
@@ -35,42 +41,79 @@ def parse_whoami_package():
         "documentation_root", type=str, help="Path to the root of the documentation"
     )
     parser.add_argument(
-        "output",
+        "--output",
         type=str,
+        required=False,
+        default=None,
         help="Path to the output directory",
     )
     args = parser.parse_args()
+    save_dir = args.output if args.output is not None else args.documentation_root
+
+    llm = get_llm_model(model_type="simple_model")
+    embeddings_model = get_embeddings_model()
+
+    def build_docs_vector_store():
+        logger.info("Building the robot docs vector store...")
+        faiss_index = FAISS.from_documents(docs, embeddings_model)
+        faiss_index.add_documents(docs)
+        faiss_index.save_local(save_dir)
+
+    def build_robot_identity():
+        logger.info("Building the robot identity...")
+        prompt = (
+            "You will be given a robot's documentation. "
+            "Your task is to identify the robot's identity. "
+            "The description should cover the most important aspects of the robot with respect to human interaction, "
+            "as well as the robot's capabilities and limitations including sensor and actuator information. "
+            "If there are any images provided, make sure to take them into account by thoroughly analyzing them. "
+            "Your description should be thorough and detailed."
+            "Your reply should start with I am a ..."
+        )
+
+        images = glob.glob(args.documentation_root + "/images/*")
+
+        messages = [SystemMessage(content=prompt)] + [
+            HumanMultimodalMessage(
+                content=documentation,
+                images=[preprocess_image(image) for image in images],
+            )
+        ]
+        output = llm.invoke(messages)
+        assert isinstance(output.content, str), "Malformed output"
+
+        with open(save_dir + "/robot_identity.txt", "w") as f:
+            f.write(output.content)
+        logger.info("Done")
 
     docs = ingest_documentation(
         documentation_root=args.documentation_root + "/documentation"
     )
-    faiss_index = FAISS.from_documents(docs, OpenAIEmbeddings())
-    faiss_index.add_documents(docs)
-    save_dir = args.output
-    faiss_index.save_local(save_dir)
-
-    prompt = (
-        "You will be given a robot's documentation. "
-        "Your task is to identify the robot's identity. "
-        "The description should cover the most important aspects of the robot with respect to human interaction, "
-        "as well as the robot's capabilities and limitations including sensor and actuator information. "
-        "If there are any images provided, make sure to take them into account by thoroughly analyzing them. "
-        "Your description should be thorough and detailed."
-        "Your reply should start with I am a ..."
+    documentation = str([doc.page_content for doc in docs])
+    n_tokens = len(documentation) // 4.0
+    logger.info(
+        "Building the robot docs vector store... "
+        f"The documentation's length is {len(documentation)} chars, "
+        f"approximately {n_tokens} tokens"
     )
-    llm = ChatOpenAI(model="gpt-4o-mini")
-
-    images = glob.glob(args.documentation_root + "/images/*")
-
-    messages = [SystemMessage(content=prompt)] + [
-        HumanMultimodalMessage(
-            content=str([doc.page_content for doc in docs]),
-            images=[preprocess_image(image) for image in images],
+    logger.warn("Do you want to continue? (y/n)")
+    if input() == "y":
+        build_docs_vector_store()
+    else:
+        logger.info("Skipping the robot docs vector store creation.")
+
+    logger.info(
+        "Building the robot identity... "
+        f"You can do it manually by creating {save_dir}/robot_identity.txt "
+    )
+    logger.warn("Do you want to continue? (y/n)")
+    if input() == "y":
+        build_robot_identity()
+    else:
+        logger.info(
+            f"Skipping the robot identity creation. "
+            f"You can do it manually by creating {save_dir}/robot_identity.txt"
         )
-    ]
-    output = llm.invoke(messages)
-    with open(save_dir + "/robot_identity.txt", "w") as f:
-        f.write(output.content)
 
 
 def create_rai_ws():
@@ -125,7 +168,7 @@ def create_rai_ws():
 
 # TODO: Refactor this hacky solution
 # NOTE (mkotynia) fast solution, worth refactor in the future and testing if it generic for all setup.py file confgurations
-def modify_setup_py(setup_py_path):
+def modify_setup_py(setup_py_path: Path):
     with open(setup_py_path, "r") as file:
         setup_content = file.read()
 
@@ -157,7 +200,3 @@ def modify_setup_py(setup_py_path):
 
     with open(setup_py_path, "w") as file:
         file.write(modified_script)
-
-
-if __name__ == "__main__":
-    create_rai_ws()
diff --git a/src/rai/rai/utils/__init__.py b/src/rai/rai/utils/__init__.py
new file mode 100644
index 000000000..ef74fc891
--- /dev/null
+++ b/src/rai/rai/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/rai/rai/utils/model_initialization.py b/src/rai/rai/utils/model_initialization.py
new file mode 100644
index 000000000..bdeede0ca
--- /dev/null
+++ b/src/rai/rai/utils/model_initialization.py
@@ -0,0 +1,110 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Literal
+
+import tomli
+
+
+@dataclass
+class VendorConfig:
+    name: str
+
+
+@dataclass
+class ModelConfig:
+    simple_model: str
+    complex_model: str
+    embeddings_model: str
+
+
+@dataclass
+class AWSConfig(ModelConfig):
+    region_name: str
+
+
+@dataclass
+class OllamaConfig(ModelConfig):
+    base_url: str
+
+
+@dataclass
+class RAIConfig:
+    vendor: VendorConfig
+    aws: AWSConfig
+    openai: ModelConfig
+    ollama: OllamaConfig
+
+
+def load_config() -> RAIConfig:
+    with open("config.toml", "rb") as f:
+        config_dict = tomli.load(f)
+    return RAIConfig(
+        vendor=VendorConfig(**config_dict["vendor"]),
+        aws=AWSConfig(**config_dict["aws"]),
+        openai=ModelConfig(**config_dict["openai"]),
+        ollama=OllamaConfig(**config_dict["ollama"]),
+    )
+
+
+def get_llm_model(model_type: Literal["simple_model", "complex_model"]):
+    config = load_config()
+    vendor = config.vendor.name
+    model_config = getattr(config, vendor)
+
+    if vendor == "openai":
+        from langchain_openai import ChatOpenAI
+
+        return ChatOpenAI(model=getattr(model_config, model_type))
+    elif vendor == "aws":
+        from langchain_aws import ChatBedrock
+
+        return ChatBedrock(
+            model_id=getattr(model_config, model_type),
+            region_name=model_config.region_name,
+        )
+    elif vendor == "ollama":
+        from langchain_ollama import ChatOllama
+
+        return ChatOllama(
+            model=getattr(model_config, model_type), base_url=model_config.base_url
+        )
+    else:
+        raise ValueError(f"Unknown LLM vendor: {vendor}")
+
+
+def get_embeddings_model():
+    config = load_config()
+    vendor = config.vendor.name
+    model_config = getattr(config, vendor)
+
+    if vendor == "openai":
+        from langchain_openai import OpenAIEmbeddings
+
+        return OpenAIEmbeddings(model=model_config.embeddings_model)
+    elif vendor == "aws":
+        from langchain_aws import BedrockEmbeddings
+
+        return BedrockEmbeddings(
+            model_id=model_config.embeddings_model, region_name=model_config.region_name
+        )
+    elif vendor == "ollama":
+        from langchain_ollama import OllamaEmbeddings
+
+        return OllamaEmbeddings(
+            model=model_config.embeddings_model, base_url=model_config.base_url
+        )
+    else:
+        raise ValueError(f"Unknown embeddings vendor: {vendor}")
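
Usage note (not part of the diff): a minimal sketch of how downstream RAI code can consume the new config-driven helpers. The active vendor comes from config.toml ([vendor] name = "openai", "aws", or "ollama"); the message text and variable names below are illustrative placeholders, only the two imported helpers are introduced by this changeset.

    # Sketch: vendor-agnostic model access via rai.utils.model_initialization.
    from langchain_core.messages import HumanMessage

    from rai.utils.model_initialization import get_embeddings_model, get_llm_model

    # Returns ChatOpenAI, ChatBedrock, or ChatOllama depending on config.toml.
    llm = get_llm_model(model_type="complex_model")
    # Returns the matching embeddings backend for the same vendor.
    embeddings = get_embeddings_model()

    reply = llm.invoke([HumanMessage(content="Describe this robot in one sentence.")])
    print(reply.content)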
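
Usage note (not part of the diff): with --output now optional, the whoami artifacts default to the documentation root, and each build step asks for confirmation before running. A sketch of driving the reworked parser directly; the argv values are placeholders, and the installed console-script name is not shown in this changeset.

    # Sketch: running parse_whoami_package() without --output.
    import sys

    from rai.cli.rai_cli import parse_whoami_package

    sys.argv = ["rai-whoami", "path/to/robot_docs"]  # no --output: save_dir falls back to the docs root
    parse_whoami_package()  # logs "Do you want to continue? (y/n)" and reads input() before each step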