Skip to content

Commit

Permalink
feat(rai_cli): improved UX (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek authored Sep 5, 2024
1 parent eb64596 commit bcc9e3b
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 60 deletions.
19 changes: 19 additions & 0 deletions config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[vendor]
name = "openai" # openai, aws, ollama

[aws]
simple_model = "anthropic.claude-3-haiku-20240307-v1:0"
complex_model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
embeddings_model = "amazon.titan-embed-text-v1"
region_name = "us-east-1"

[openai]
simple_model = "gpt-4o-mini"
complex_model = "gpt-4o"
embeddings_model = "text-embedding-ada-002"

[ollama]
simple_model = "llama3.1"
complex_model = "llama3.1:70b"
embeddings_model = "llama3.1"
base_url = "http://localhost:11434"
33 changes: 5 additions & 28 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ langchain-ollama = "^0.1.1"

streamlit = "^1.37.1"
deprecated = "^1.2.14"
tomli = "^2.0.1"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.4"

Expand Down
103 changes: 71 additions & 32 deletions src/rai/rai/cli/rai_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@

import argparse
import glob
import logging
import subprocess
from pathlib import Path

import coloredlogs
from langchain_community.vectorstores import FAISS
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from rai.apps.talk_to_docs import ingest_documentation
from rai.messages import preprocess_image
from rai.messages.multimodal import HumanMultimodalMessage
from rai.utils.model_initialization import get_embeddings_model, get_llm_model

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
coloredlogs.install(level="INFO") # type: ignore


def parse_whoami_package():
Expand All @@ -35,42 +41,79 @@ def parse_whoami_package():
"documentation_root", type=str, help="Path to the root of the documentation"
)
parser.add_argument(
"output",
"--output",
type=str,
required=False,
default=None,
help="Path to the output directory",
)
args = parser.parse_args()
save_dir = args.output if args.output is not None else args.documentation_root

llm = get_llm_model(model_type="simple_model")
embeddings_model = get_embeddings_model()

def build_docs_vector_store():
logger.info("Building the robot docs vector store...")
faiss_index = FAISS.from_documents(docs, embeddings_model)
faiss_index.add_documents(docs)
faiss_index.save_local(save_dir)

def build_robot_identity():
logger.info("Building the robot identity...")
prompt = (
"You will be given a robot's documentation. "
"Your task is to identify the robot's identity. "
"The description should cover the most important aspects of the robot with respect to human interaction, "
"as well as the robot's capabilities and limitations including sensor and actuator information. "
"If there are any images provided, make sure to take them into account by thoroughly analyzing them. "
"Your description should be thorough and detailed."
"Your reply should start with I am a ..."
)

images = glob.glob(args.documentation_root + "/images/*")

messages = [SystemMessage(content=prompt)] + [
HumanMultimodalMessage(
content=documentation,
images=[preprocess_image(image) for image in images],
)
]
output = llm.invoke(messages)
assert isinstance(output.content, str), "Malformed output"

with open(save_dir + "/robot_identity.txt", "w") as f:
f.write(output.content)
logger.info("Done")

docs = ingest_documentation(
documentation_root=args.documentation_root + "/documentation"
)
faiss_index = FAISS.from_documents(docs, OpenAIEmbeddings())
faiss_index.add_documents(docs)
save_dir = args.output
faiss_index.save_local(save_dir)

prompt = (
"You will be given a robot's documentation. "
"Your task is to identify the robot's identity. "
"The description should cover the most important aspects of the robot with respect to human interaction, "
"as well as the robot's capabilities and limitations including sensor and actuator information. "
"If there are any images provided, make sure to take them into account by thoroughly analyzing them. "
"Your description should be thorough and detailed."
"Your reply should start with I am a ..."
documentation = str([doc.page_content for doc in docs])
n_tokens = len(documentation) // 4.0
logger.info(
"Building the robot docs vector store... "
f"The documentation's length is {len(documentation)} chars, "
f"approximately {n_tokens} tokens"
)
llm = ChatOpenAI(model="gpt-4o-mini")

images = glob.glob(args.documentation_root + "/images/*")

messages = [SystemMessage(content=prompt)] + [
HumanMultimodalMessage(
content=str([doc.page_content for doc in docs]),
images=[preprocess_image(image) for image in images],
logger.warn("Do you want to continue? (y/n)")
if input() == "y":
build_docs_vector_store()
else:
logger.info("Skipping the robot docs vector store creation.")

logger.info(
"Building the robot identity... "
f"You can do it manually by creating {save_dir}/robot_identity.txt "
)
logger.warn("Do you want to continue? (y/n)")
if input() == "y":
build_robot_identity()
else:
logger.info(
f"Skipping the robot identity creation. "
f"You can do it manually by creating {save_dir}/robot_identity.txt"
)
]
output = llm.invoke(messages)
with open(save_dir + "/robot_identity.txt", "w") as f:
f.write(output.content)


def create_rai_ws():
Expand Down Expand Up @@ -125,7 +168,7 @@ def create_rai_ws():

# TODO: Refactor this hacky solution
# NOTE (mkotynia) fast solution, worth refactor in the future and testing if it generic for all setup.py file confgurations
def modify_setup_py(setup_py_path):
def modify_setup_py(setup_py_path: Path):
with open(setup_py_path, "r") as file:
setup_content = file.read()

Expand Down Expand Up @@ -157,7 +200,3 @@ def modify_setup_py(setup_py_path):

with open(setup_py_path, "w") as file:
file.write(modified_script)


if __name__ == "__main__":
create_rai_ws()
13 changes: 13 additions & 0 deletions src/rai/rai/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
110 changes: 110 additions & 0 deletions src/rai/rai/utils/model_initialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Literal

import tomli


@dataclass
class VendorConfig:
name: str


@dataclass
class ModelConfig:
simple_model: str
complex_model: str
embeddings_model: str


@dataclass
class AWSConfig(ModelConfig):
region_name: str


@dataclass
class OllamaConfig(ModelConfig):
base_url: str


@dataclass
class RAIConfig:
vendor: VendorConfig
aws: AWSConfig
openai: ModelConfig
ollama: OllamaConfig


def load_config() -> RAIConfig:
with open("config.toml", "rb") as f:
config_dict = tomli.load(f)
return RAIConfig(
vendor=VendorConfig(**config_dict["vendor"]),
aws=AWSConfig(**config_dict["aws"]),
openai=ModelConfig(**config_dict["openai"]),
ollama=OllamaConfig(**config_dict["ollama"]),
)


def get_llm_model(model_type: Literal["simple_model", "complex_model"]):
config = load_config()
vendor = config.vendor.name
model_config = getattr(config, vendor)

if vendor == "openai":
from langchain_openai import ChatOpenAI

return ChatOpenAI(model=getattr(model_config, model_type))
elif vendor == "aws":
from langchain_aws import ChatBedrock

return ChatBedrock(
model_id=getattr(model_config, model_type),
region_name=model_config.region_name,
)
elif vendor == "ollama":
from langchain_ollama import ChatOllama

return ChatOllama(
model=getattr(model_config, model_type), base_url=model_config.base_url
)
else:
raise ValueError(f"Unknown LLM vendor: {vendor}")


def get_embeddings_model():
config = load_config()
vendor = config.vendor.name
model_config = getattr(config, vendor)

if vendor == "openai":
from langchain_openai import OpenAIEmbeddings

return OpenAIEmbeddings(model=model_config.embeddings_model)
elif vendor == "aws":
from langchain_aws import BedrockEmbeddings

return BedrockEmbeddings(
model_id=model_config.embeddings_model, region_name=model_config.region_name
)
elif vendor == "ollama":
from langchain_ollama import OllamaEmbeddings

return OllamaEmbeddings(
model=model_config.embeddings_model, base_url=model_config.base_url
)
else:
raise ValueError(f"Unknown embeddings vendor: {vendor}")

0 comments on commit bcc9e3b

Please sign in to comment.