Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rai_cli): improved UX #189

Merged
merged 7 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[vendor]
name = "openai" # openai, aws, ollama

[aws]
simple_model = "anthropic.claude-3-haiku-20240307-v1:0"
complex_model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
embeddings_model = "amazon.titan-embed-text-v1"
region_name = "us-east-1"

[openai]
simple_model = "gpt-4o-mini"
complex_model = "gpt-4o"
embeddings_model = "text-embedding-ada-002"

[ollama]
simple_model = "llama3.1"
complex_model = "llama3.1:70b"
embeddings_model = "llama3.1"
base_url = "http://localhost:11434"
33 changes: 5 additions & 28 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ langchain-ollama = "^0.1.1"

streamlit = "^1.37.1"
deprecated = "^1.2.14"
tomli = "^2.0.1"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.4"

Expand Down
103 changes: 71 additions & 32 deletions src/rai/rai/cli/rai_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@

import argparse
import glob
import logging
import subprocess
from pathlib import Path

import coloredlogs
from langchain_community.vectorstores import FAISS
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from rai.apps.talk_to_docs import ingest_documentation
from rai.messages import preprocess_image
from rai.messages.multimodal import HumanMultimodalMessage
from rai.utils.model_initialization import get_embeddings_model, get_llm_model

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
coloredlogs.install(level="INFO") # type: ignore


def parse_whoami_package():
Expand All @@ -35,42 +41,79 @@ def parse_whoami_package():
"documentation_root", type=str, help="Path to the root of the documentation"
)
parser.add_argument(
"output",
"--output",
type=str,
required=False,
default=None,
help="Path to the output directory",
)
args = parser.parse_args()
save_dir = args.output if args.output is not None else args.documentation_root

llm = get_llm_model(model_type="simple_model")
embeddings_model = get_embeddings_model()

def build_docs_vector_store():
logger.info("Building the robot docs vector store...")
faiss_index = FAISS.from_documents(docs, embeddings_model)
faiss_index.add_documents(docs)
faiss_index.save_local(save_dir)

def build_robot_identity():
logger.info("Building the robot identity...")
prompt = (
"You will be given a robot's documentation. "
"Your task is to identify the robot's identity. "
"The description should cover the most important aspects of the robot with respect to human interaction, "
"as well as the robot's capabilities and limitations including sensor and actuator information. "
"If there are any images provided, make sure to take them into account by thoroughly analyzing them. "
"Your description should be thorough and detailed."
"Your reply should start with I am a ..."
)

images = glob.glob(args.documentation_root + "/images/*")

messages = [SystemMessage(content=prompt)] + [
HumanMultimodalMessage(
content=documentation,
images=[preprocess_image(image) for image in images],
)
]
output = llm.invoke(messages)
assert isinstance(output.content, str), "Malformed output"

with open(save_dir + "/robot_identity.txt", "w") as f:
f.write(output.content)
logger.info("Done")

docs = ingest_documentation(
documentation_root=args.documentation_root + "/documentation"
)
faiss_index = FAISS.from_documents(docs, OpenAIEmbeddings())
faiss_index.add_documents(docs)
save_dir = args.output
faiss_index.save_local(save_dir)

prompt = (
"You will be given a robot's documentation. "
"Your task is to identify the robot's identity. "
"The description should cover the most important aspects of the robot with respect to human interaction, "
"as well as the robot's capabilities and limitations including sensor and actuator information. "
"If there are any images provided, make sure to take them into account by thoroughly analyzing them. "
"Your description should be thorough and detailed."
"Your reply should start with I am a ..."
documentation = str([doc.page_content for doc in docs])
n_tokens = len(documentation) // 4.0
logger.info(
"Building the robot docs vector store... "
f"The documentation's length is {len(documentation)} chars, "
f"approximately {n_tokens} tokens"
)
llm = ChatOpenAI(model="gpt-4o-mini")

images = glob.glob(args.documentation_root + "/images/*")

messages = [SystemMessage(content=prompt)] + [
HumanMultimodalMessage(
content=str([doc.page_content for doc in docs]),
images=[preprocess_image(image) for image in images],
logger.warn("Do you want to continue? (y/n)")
if input() == "y":
build_docs_vector_store()
else:
logger.info("Skipping the robot docs vector store creation.")

logger.info(
"Building the robot identity... "
f"You can do it manually by creating {save_dir}/robot_identity.txt "
)
logger.warn("Do you want to continue? (y/n)")
if input() == "y":
build_robot_identity()
else:
logger.info(
f"Skipping the robot identity creation. "
f"You can do it manually by creating {save_dir}/robot_identity.txt"
)
]
output = llm.invoke(messages)
with open(save_dir + "/robot_identity.txt", "w") as f:
f.write(output.content)


def create_rai_ws():
Expand Down Expand Up @@ -125,7 +168,7 @@ def create_rai_ws():

# TODO: Refactor this hacky solution
# NOTE (mkotynia) fast solution, worth refactor in the future and testing if it generic for all setup.py file confgurations
def modify_setup_py(setup_py_path):
def modify_setup_py(setup_py_path: Path):
with open(setup_py_path, "r") as file:
setup_content = file.read()

Expand Down Expand Up @@ -157,7 +200,3 @@ def modify_setup_py(setup_py_path):

with open(setup_py_path, "w") as file:
file.write(modified_script)


if __name__ == "__main__":
create_rai_ws()
13 changes: 13 additions & 0 deletions src/rai/rai/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
110 changes: 110 additions & 0 deletions src/rai/rai/utils/model_initialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Literal

import tomli


@dataclass
class VendorConfig:
name: str


@dataclass
class ModelConfig:
simple_model: str
complex_model: str
embeddings_model: str


@dataclass
class AWSConfig(ModelConfig):
region_name: str


@dataclass
class OllamaConfig(ModelConfig):
base_url: str


@dataclass
class RAIConfig:
vendor: VendorConfig
aws: AWSConfig
openai: ModelConfig
ollama: OllamaConfig


def load_config() -> RAIConfig:
with open("config.toml", "rb") as f:
config_dict = tomli.load(f)
return RAIConfig(
vendor=VendorConfig(**config_dict["vendor"]),
aws=AWSConfig(**config_dict["aws"]),
openai=ModelConfig(**config_dict["openai"]),
ollama=OllamaConfig(**config_dict["ollama"]),
)


def get_llm_model(model_type: Literal["simple_model", "complex_model"]):
config = load_config()
vendor = config.vendor.name
model_config = getattr(config, vendor)

if vendor == "openai":
from langchain_openai import ChatOpenAI

return ChatOpenAI(model=getattr(model_config, model_type))
elif vendor == "aws":
from langchain_aws import ChatBedrock

return ChatBedrock(
model_id=getattr(model_config, model_type),
region_name=model_config.region_name,
)
elif vendor == "ollama":
from langchain_ollama import ChatOllama

return ChatOllama(
model=getattr(model_config, model_type), base_url=model_config.base_url
)
else:
raise ValueError(f"Unknown LLM vendor: {vendor}")


def get_embeddings_model():
config = load_config()
vendor = config.vendor.name
model_config = getattr(config, vendor)

if vendor == "openai":
from langchain_openai import OpenAIEmbeddings

return OpenAIEmbeddings(model=model_config.embeddings_model)
elif vendor == "aws":
from langchain_aws import BedrockEmbeddings

return BedrockEmbeddings(
model_id=model_config.embeddings_model, region_name=model_config.region_name
)
elif vendor == "ollama":
from langchain_ollama import OllamaEmbeddings

return OllamaEmbeddings(
model=model_config.embeddings_model, base_url=model_config.base_url
)
else:
raise ValueError(f"Unknown embeddings vendor: {vendor}")