Skip to content

Commit

Permalink
Merge pull request #560 from GoogleCloudPlatform/update_rag_scaffold
Browse files (browse the repository at this point in the history)
Update langchain_rag scaffold to use Gemini instead of PaLM
sanjanalreddy authored Dec 16, 2024
2 parents d4870ca + 764f9d8 commit 45ca25e
Showing 3 changed files with 29 additions and 24 deletions.
19 changes: 9 additions & 10 deletions scaffolds/langchain_rag/app/services.py
Original file line number Diff line number Diff line change
@@ -7,11 +7,10 @@
from app.settings import Config
from google.cloud import aiplatform
from langchain.chains import RetrievalQA
from langchain.document_loaders import DataFrameLoader
from langchain.embeddings import VertexAIEmbeddings
from langchain.llms import VertexAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_chroma import Chroma
from langchain_community.document_loaders import DataFrameLoader
from langchain_google_vertexai import VertexAI, VertexAIEmbeddings


class StorageService:
@@ -185,14 +184,14 @@ def create_rag_service():
staging_bucket=Config.BUCKET,
)
embedding = VertexAIEmbeddings(
model_name=Config.PALM_EMBBEDING_MODEL,
model_name=Config.GEMINI_EMBBEDING_MODEL,
)
llm = VertexAI(
model_name=Config.PALM_TEXT_MODEL,
temperature=Config.PALM_TEMPERATURE,
max_output_tokens=Config.PALM_MAX_OUTPUT_TOKENS,
top_p=Config.PALM_TOP_P,
top_k=Config.PALM_TOP_K,
model_name=Config.GEMINI_TEXT_MODEL,
temperature=Config.GEMINI_TEMPERATURE,
max_output_tokens=Config.GEMINI_MAX_OUTPUT_TOKENS,
top_p=Config.GEMINI_TOP_P,
top_k=Config.GEMINI_TOP_K,
)
return RAGService(
embedding_engine=embedding,
16 changes: 9 additions & 7 deletions scaffolds/langchain_rag/app/settings.py
Original file line number Diff line number Diff line change
@@ -34,14 +34,16 @@ class Config:
)

# Retrieval Service Config
PALM_TEXT_MODEL = os.environ.get("PALM_MODEL", "text-bison")
PALM_EMBBEDING_MODEL = os.environ.get(
"PALM_EMBBEDING_MODEL", "textembedding-gecko"
GEMINI_TEXT_MODEL = os.environ.get("GEMINI_MODEL", "gemini-1.5-flash-001")
GEMINI_EMBBEDING_MODEL = os.environ.get(
"GEMINI_EMBBEDING_MODEL", "text-embedding-004"
)
GEMINI_TEMPERATURE = float(os.environ.get("GEMINI_TEMPERATURE", 0.5))
GEMINI_TOP_P = float(os.environ.get("GEMINI_TOP_P", 0.95))
GEMINI_TOP_K = int(os.environ.get("GEMINI_TOP_K", 40))
GEMINI_MAX_OUTPUT_TOKENS = int(
os.environ.get("GEMINI_MAX_OUTPUT_TOKENS", 1024)
)
PALM_TEMPERATURE = float(os.environ.get("PALM_TEMPERATURE", 0.0))
PALM_TOP_P = float(os.environ.get("PALM_TOP_P", 0.95))
PALM_TOP_K = int(os.environ.get("PALM_TOP_K", 40))
PALM_MAX_OUTPUT_TOKENS = int(os.environ.get("PALM_MAX_OUTPUT_TOKENS", 1024))
# Retrieval chunking settings. Each value is read from its own environment
# variable and coerced to int; the second argument is the fallback default.
RETRIEVAL_CHUNK_SIZE = int(os.environ.get("RETRIEVAL_CHUNK_SIZE", 2000))
# Read the overlap from RETRIEVAL_OVERLAP_SIZE. It previously looked up
# RETRIEVAL_CHUNK_SIZE, so exporting a chunk size silently forced the overlap
# to the same value and the overlap could not be configured independently.
RETRIEVAL_OVERLAP_SIZE = int(os.environ.get("RETRIEVAL_OVERLAP_SIZE", 0))
RETRIEVAL_SEARCH_TYPE = os.environ.get(
18 changes: 11 additions & 7 deletions scaffolds/langchain_rag/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
Flask==3.0.0
gunicorn==20.1.0
pre-commit==3.7.1
pytest==7.0.1
google-cloud-aiplatform==1.35.0
pandas==2.0.3
langchain==0.0.315
chromadb==0.4.14
gcsfs
pre_commit==4.0.1
pytest==8.3.3
google-cloud-aiplatform==1.74.0
langchain==0.3.12
langchain-chroma==0.1.4
langchain-community==0.3.12
langchain-core==0.3.25
langchain-google-vertexai==2.0.9
langchain-text-splitters==0.3.3
jupyterlab
pydantic==2.9.2
pandas==2.2.2

0 comments on commit 45ca25e

Please sign in to comment.