Merge pull request #16 from igooch/messages
Adds the ability to ingest the optional messages field for the chat endpoint, which stores the history of the chat.
zmerlynn authored Feb 29, 2024
2 parents 0eb7376 + 417884b commit 205ba78
Showing 5 changed files with 31 additions and 24 deletions.
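For orientation, here is a minimal sketch of how a client might exercise the new field once this change is deployed. Only the payload keys come from the diffs below; the base URL and the /genai/chat route are assumptions for illustration.

import requests

# Hypothetical client call: the base URL and route are placeholders, not part of this commit.
payload = {
    "prompt": "What did I just ask you about?",
    "context": "You are a helpful in-game guide.",
    "message_history": [
        {"author": "user", "content": "Where is the blacksmith?"},
        {"author": "bot", "content": "Head north past the fountain."},
    ],
    "max_output_tokens": 1024,
    "temperature": 0.2,
    "top_p": 0.8,
}

response = requests.post("http://genai-api.genai.svc/genai/chat", json=payload)
print(response.json())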
genai/api/genai_api/src/main.py (15 changes: 9 additions & 6 deletions)
@@ -20,7 +20,8 @@
import io
import json
import requests

from typing import List
from vertexai.language_models import ChatMessage

logging.basicConfig(
level=logging.DEBUG,
@@ -51,7 +52,7 @@


app = FastAPI(
docs_url='/genai_docs',
redoc_url=None,
title="GenAI Quickstart APIs",
description="Core APIs for the GenAI Quickstart for Gaming",
@@ -64,7 +65,7 @@
)


GENAI_GEMINI_ENDPOINT= os.environ['GENAI_GEMINI_ENDPOINT']
GENAI_TEXT_ENDPOINT = os.environ['GENAI_TEXT_ENDPOINT']
GENAI_CHAT_ENDPOINT = os.environ['GENAI_CHAT_ENDPOINT']
GENAI_CODE_ENDPOINT = os.environ['GENAI_CODE_ENDPOINT']
@@ -102,6 +103,7 @@ class Payload_Vertex_Gemini(BaseModel):
class Payload_Chat(BaseModel):
prompt: str
context: str | None = ''
message_history: List[ChatMessage] | None = []
max_output_tokens: int | None = 1024
temperature: float | None = 0.2
top_p: float | None = 0.8
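The new message_history field is typed as a list of vertexai ChatMessage objects. Assuming ChatMessage is a plain dataclass exposing author and content, pydantic coerces the JSON objects a caller posts into ChatMessage instances, which is why the test payload further down sends plain dicts. A rough sketch of that equivalence:

# Illustration only: incoming JSON objects are validated into ChatMessage values
# (assumes ChatMessage exposes 'author' and 'content' fields).
from vertexai.language_models import ChatMessage

raw_history = [{"author": "me", "content": "my content"}]
history = [ChatMessage(author=m["author"], content=m["content"]) for m in raw_history]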
@@ -223,7 +225,7 @@ def genai_gemini(payload: Payload_Vertex_Gemini):
def genai_text(payload: Payload_Text):
try:
request_payload = {
'prompt': payload.prompt,
'max_output_tokens': payload.max_output_tokens,
'temperature': payload.temperature,
'top_p': payload.top_p,
@@ -246,6 +248,7 @@ def genai_chat(payload: Payload_Chat):
request_payload = {
'prompt': payload.prompt,
'context': payload.context,
'message_history': payload.message_history,
'max_output_tokens': payload.max_output_tokens,
'temperature': payload.temperature,
'top_p': payload.top_p,
@@ -266,7 +269,7 @@ def genai_chat(payload: Payload_Chat):
def genai_code(payload: Payload_Code):
try:
request_payload = {
'prompt': payload.prompt,
'max_output_tokens': payload.max_output_tokens,
'temperature': payload.temperature,
'top_p': payload.top_p,
@@ -287,7 +290,7 @@ def genai_code(payload: Payload_Code):
def genai_image(payload: Payload_Image):
try:
request_payload = {
'prompt': payload.prompt,
'number_of_images': payload.number_of_images,
'seed': payload.seed,
}
genai/api/genai_api/src/test_main.py (9 changes: 5 additions & 4 deletions)
@@ -23,7 +23,7 @@

@mock.patch('requests.post')
def test_genai_text(mock_post):

# Define a mock response content as a JSON string
expected_response = {'mocked_key': 'mocked_value'}
mock_response_content = json.dumps(expected_response).encode()
@@ -56,7 +56,7 @@ def test_genai_text(mock_post):

@mock.patch('requests.post')
def test_genai_chat(mock_post):

# Define a mock response content as a JSON string
expected_response = {'mocked_key': 'mocked_value'}
mock_response_content = json.dumps(expected_response).encode()
@@ -73,6 +73,7 @@ def test_genai_chat(mock_post):
payload = {
"prompt": "test prompt",
"context": "my context",
"message_history": [{"author": "me", "content": "my content"}],
"max_output_tokens": 1024,
"temperature": 0.2,
"top_p": 0.8,
@@ -90,7 +91,7 @@

@mock.patch('requests.post')
def test_genai_code(mock_post):

# Define a mock response content as a JSON string
expected_response = {'mocked_key': 'mocked_value'}
mock_response_content = json.dumps(expected_response).encode()
@@ -123,7 +124,7 @@ def test_genai_code(mock_post):

@mock.patch('requests.post')
def test_genai_image(mock_post):

# Define a mock response content as a JSON string
expected_response = {'mocked_key': 'mocked_value'}
mock_response_content = json.dumps(expected_response).encode()
genai/api/vertex_chat_api/src/main.py (11 changes: 6 additions & 5 deletions)
@@ -14,13 +14,12 @@

from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
from utils.model_util import Google_Cloud_GenAI
import io
import os, sys
import json
import requests
import logging
from typing import List
from vertexai.language_models import ChatMessage

logging.basicConfig(
level=logging.DEBUG,
@@ -68,6 +67,7 @@ def get_gcp_metadata():
class Payload_Vertex_Chat(BaseModel):
prompt: str
context: str | None = ''
message_history: List[ChatMessage] | None = []
max_output_tokens: int | None = 1024
temperature: float | None = 0.2
top_p: float | None = 0.8
@@ -88,6 +88,7 @@ def vertex_llm_chat(payload: Payload_Vertex_Chat):
request_payload = {
'prompt': payload.prompt,
'context': payload.context,
'message_history': payload.message_history,
'max_output_tokens': payload.max_output_tokens,
'temperature': payload.temperature,
'top_p': payload.top_p,
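The visible hunk only extends request_payload; presumably that dict is then handed to the Google_Cloud_GenAI wrapper imported above, whose call_llm (see model_util.py below) now accepts message_history. A hedged sketch of that hand-off:

# Assumption: the dict built above is unpacked into the model wrapper's call_llm;
# 'cloud_genai' is a hypothetical instance name and the real call site is outside
# the visible hunks of this diff.
response_text = cloud_genai.call_llm(**request_payload)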
genai/api/vertex_chat_api/src/test_main.py (1 change: 1 addition & 0 deletions)
@@ -40,6 +40,7 @@ def test_genai(mock_post):
payload = {
"prompt": "test prompt",
"context": "my test context",
"messages": [{"author": "me", "content": "my content"}],
"max_output_tokens": 1024,
"temperature": 0.2,
"top_p": 0.8,
genai/api/vertex_chat_api/src/utils/model_util.py (19 changes: 10 additions & 9 deletions)
@@ -25,20 +25,20 @@ class Google_Cloud_GenAI:
def __init__(self, GCP_PROJECT_ID, GCP_REGION, MODEL_TYPE):
if GCP_PROJECT_ID=="":
print(f'[ WARNING ] GCP_PROJECT_ID ENV variable is empty. Be sure to set the GCP_PROJECT_ID ENV variable.')

if GCP_REGION=="":
print(f'[ WARNING ] GCP_REGION ENV variable is empty. Be sure to set the GCP_REGION ENV variable.')

if MODEL_TYPE=="":
print(f'[ WARNING ] MODEL_TYPE ENV variable is empty. Be sure to set the MODEL_TYPE ENV variable.')

self.GCP_PROJECT_ID = GCP_PROJECT_ID
self.GCP_REGION = GCP_REGION
self.MODEL_TYPE = MODEL_TYPE
self.pretrained_model = f'{MODEL_TYPE.lower()}@001'

self.vertexai = vertexai.init(project=GCP_PROJECT_ID, location=GCP_REGION)

if MODEL_TYPE.lower() == 'text-bison':
self.model = TextGenerationModel.from_pretrained(self.pretrained_model)
elif MODEL_TYPE.lower() == 'chat-bison':
@@ -52,7 +52,7 @@ def __init__(self, GCP_PROJECT_ID, GCP_REGION, MODEL_TYPE):
print(f'[ ERROR ] No MODEL_TYPE specified or MODEL_TYPE is incorrect. Expecting MODEL_TYPE ENV var of "text-bison", "chat-bison", "code-bison", or "codechat-bison".')
sys.exit()

def call_llm(self, prompt, temperature=0.2, max_output_tokens=256, top_p=0.8, top_k=40, context='', chat_examples=[], code_suffix=''):
def call_llm(self, prompt, temperature=0.2, max_output_tokens=256, top_p=0.8, top_k=40, context='', chat_examples=[], message_history=[], code_suffix=''):
if self.MODEL_TYPE.lower() == 'text-bison':
try:
parameters = {
@@ -70,7 +70,7 @@ def call_llm(self, prompt, temperature=0.2, max_output_tokens=256, top_p=0.8, to
except Exception as e:
print(f'[ EXCEPTION ] At call_llm for text-bison. {e}')
return ''

elif self.MODEL_TYPE.lower() == 'chat-bison':
try:
'''
@@ -89,6 +89,7 @@ def call_llm(self, prompt, temperature=0.2, max_output_tokens=256, top_p=0.8, to
chat = self.model.start_chat(
context=context,
examples=chat_examples,
message_history=message_history,
temperature=temperature,
max_output_tokens=max_output_tokens,
top_p=top_p,
@@ -100,16 +101,16 @@ def call_llm(self, prompt, temperature=0.2, max_output_tokens=256, top_p=0.8, to
except Exception as e:
print(f'[ EXCEPTION ] At call_chat for chat-bison. {e}')
return ''

elif self.MODEL_TYPE.lower() == 'code-bison':
'''A language model that generates code.'''
try:
response = self.model.predict(prefix=prompt, temperature=temperature, max_output_tokens=max_output_tokens, suffix=code_suffix)
return response
except Exception as e:
print(f'[ EXCEPTION ] At call_chat for codechat-bison. {e}')
return ''

elif self.MODEL_TYPE.lower() == 'codechat-bison':
'''CodeChatModel represents a model that is capable of completing code.'''
try:
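Putting the pieces together, a minimal usage sketch of the updated wrapper for the chat-bison branch; the project, region, and message contents are placeholders, and the import path is assumed from this repository's src layout:

from vertexai.language_models import ChatMessage
from utils.model_util import Google_Cloud_GenAI  # import path assumed from this repo's layout

# Placeholder project and region; 'chat-bison' selects the ChatModel branch above.
genai = Google_Cloud_GenAI("my-gcp-project", "us-central1", "chat-bison")

history = [
    ChatMessage(author="user", content="Where is the blacksmith?"),
    ChatMessage(author="bot", content="Head north past the fountain."),
]

reply = genai.call_llm(
    prompt="And how far is that?",
    context="You are a helpful town guide.",
    message_history=history,
)
print(reply)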
