Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relaxed safety settings for gemini #28

Merged
merged 6 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions engines/claude.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import json
import logging
import os
import re
import uuid
from typing import Any
from typing import Any, Optional
from urllib.parse import urlparse

import boto3
import boto3.session
import requests as req
from curl_cffi import requests

from .common_utils import (
Expand Down Expand Up @@ -50,6 +54,53 @@ def process_command(input: str, context: UserContext) -> None:
logging.error(f"Unknown command {command}")


def get_content_type(file_path):
    """Return the MIME type to declare for ``file_path``, chosen by extension.

    The extension is compared case-insensitively. Only ``.pdf``, ``.txt`` and
    ``.csv`` are recognised; any other (or missing) extension is reported as
    ``application/octet-stream``.
    """
    known_types = {
        ".pdf": "application/pdf",
        ".txt": "text/plain",
        ".csv": "text/csv",
    }
    _, suffix = os.path.splitext(file_path)
    return known_types.get(suffix.lower(), "application/octet-stream")


def upload_attachment(s3_uri: str) -> Optional[str]:
    """Download an attachment from S3 and upload it to Claude's document converter.

    Args:
        s3_uri: Location of the attachment, in the form ``s3://bucket/key``.

    Returns:
        The decoded JSON body of the ``convert_document`` response when the
        endpoint answers with HTTP 200; ``None`` when the URI has no bucket or
        file name, or when the upload is rejected.
    """
    url = "https://claude.ai/api/convert_document"
    parsed = urlparse(s3_uri)
    s3_bucket = parsed.netloc
    # urlparse keeps the leading "/" in the path, but S3 object keys do not
    # start with one; passing it through makes the lookup miss the object.
    s3_key = parsed.path.lstrip("/")
    file_name = s3_key.split("/")[-1]
    if not s3_bucket or not file_name:
        logging.error(f"Cannot derive bucket and file name from '{s3_uri}'")
        return None

    logging.info(f"Downloading file {file_name} from s3 bucket {s3_bucket}")
    tmp_file = f"/tmp/{file_name}"
    try:
        session = boto3.session.Session()
        session.client("s3").download_file(
            Bucket=s3_bucket, Key=s3_key, Filename=tmp_file
        )
        # The context manager closes the handle even if the POST raises.
        with open(tmp_file, "rb") as attachment:
            files = {
                "file": (file_name, attachment, get_content_type(tmp_file)),
                "orgUuid": (None, organization_id),
            }
            response = req.post(url, headers=headers, files=files)
    finally:
        # Always drop the temporary copy, including on download/upload errors.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

    logging.info(f"Uploaded file {file_name}, response '{response.status_code}'")
    if response.status_code == 200:
        return response.json()

    logging.error(f"POST upload returned {response.status_code} {response.reason}")
    logging.info(response.content.decode("utf-8"))
    # NOTE(review): the request headers are deliberately no longer logged;
    # they presumably carry the claude.ai session credentials - confirm.
    logging.info(files)
    return None


def ask(text: str, context: UserContext, attachment=None):
if "/ping" in text:
return "pong"
Expand All @@ -58,15 +109,17 @@ def ask(text: str, context: UserContext, attachment=None):
__set_conversation(conversation_id=conversation_uuid)
context.conversation_id = conversation_uuid
__set_title(prompt=text, conversation_id=conversation_uuid)
attachments = []
# attachment_response = upload_attachment(attachment)
# logging.info(attachment_response)
# attachments = [attachment_response]
# if attachment:
# logging.info(f"Uploading attachment {attachment}")
# attachment_response = upload_attachment(attachment)
# if attachment_response:
# attachments = [attachment_response]
# else:
# return {"File upload failed. Please try again."}
if not attachment:
attachments = []
# logging.error("File upload failed: {}".format(attachment))
attachments = []
payload = json.dumps(
{
"completion": {
Expand Down Expand Up @@ -188,15 +241,17 @@ def __process_payload(payload: Any, request_id: str) -> None:
if "command" in payload["type"]:
process_command(input=payload["text"], context=user_context)
return

response = ask(payload["text"], context=user_context)
response = ask(
text=payload["text"], context=user_context, attachment=payload.get("file", None)
)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))


def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
Expand All @@ -205,5 +260,6 @@ def sns_handler(event, context):
payload = json.loads(record["Sns"]["Message"])
__process_payload(payload, request_id)


# if __name__ == "__main__":
# put_request("does DALL-E uses stable diffusion?")
# ask("a ile kosztuje ta usługa miesięczne?")
198 changes: 198 additions & 0 deletions engines/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import json
import logging
import os
import re
import uuid
from pathlib import Path
from typing import Any, Optional
from urllib.parse import urlparse

import boto3
import google.generativeai as genai
from google.ai.generativelanguage import (
Part,
)
from google.generativeai.client import (
_ClientManager,
)

from .common_utils import encode_message, read_ssm_param
from .user_context import UserContext

# Configure the root logger and make INFO-level records visible.
logging.basicConfig()
logging.getLogger().setLevel("INFO")

# Engine identifier: passed to UserContext and attached to every published result.
engine_type = "gemini"

# Deployment settings resolved once, at import time, via read_ssm_param.
# bucket_name is the S3 bucket get_image downloads attachments from;
# result_topic is the SNS topic __process_payload publishes answers to.
bucket_name = read_ssm_param(param_name="BOT_S3_BUCKET")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")
# Model handles start empty; create() assigns them before ask() uses them.
_model: Optional[genai.GenerativeModel] = None
_vision_model: Optional[genai.GenerativeModel] = None


def process_command(input: str, context: UserContext) -> None:
    """Execute a bot command for the given user.

    Args:
        input: Raw command text; one leading ``/`` is stripped and the rest is
            lower-cased before matching.
        context: The user's context. A command containing ``reset`` calls
            ``context.reset_conversation()``.

    Any other command is only logged as an error.
    """
    # str.removeprefix takes its argument positionally only; the previous
    # keyword form (prefix="/") raised TypeError on every command.
    command = input.removeprefix("/").lower()
    logging.info(f"Processing command {command} for {context.user_id}")
    if "reset" in command:
        context.reset_conversation()
        logging.info(f"Conversation hass been reset for {context.user_id}")
        return
    logging.error(f"Unknown command {command}")


def get_image(s3_uri: str) -> str:
    """Fetch the attachment named by ``s3_uri`` into ``/tmp`` and return its path.

    Only the last path segment of the URI is used: the object is read from
    the ``att/`` prefix of the module-level ``bucket_name`` bucket.

    Raises:
        FileNotFoundError: if the local file is missing after the download.
    """
    file_name = urlparse(s3_uri).path.split("/")[-1]
    logging.info(f"Downloading file 'att/{file_name}' from s3 bucket {bucket_name}")
    tmp_file = f"/tmp/{file_name}"

    s3_client = boto3.session.Session().client("s3")
    s3_client.download_file(
        Bucket=bucket_name,
        Key=f"att/{file_name}",
        Filename=tmp_file,
    )

    img = Path(tmp_file)
    if img.exists():
        return tmp_file

    logging.error(
        f"File {tmp_file} does not exist. Problem to download from s3 '{s3_uri}'"
    )
    raise FileNotFoundError(f"Could not find image: {img}")


def ask(
    text: str,
    file_path: str,
    context: UserContext,
) -> str:
    """Send a prompt, optionally with an image, to Gemini and return the reply.

    Args:
        text: The user's prompt. Any text containing ``/ping`` short-circuits
            to ``"pong"`` without calling the model.
        file_path: S3 URI of an image to send with the prompt, or a falsy
            value for a text-only request.
        context: The user's context. For text-only requests a fresh
            conversation id is assigned when none is set yet.

    Returns:
        The streamed answer, concatenated and passed through ``__as_markdown``.
    """
    # Function-scope import keeps this block self-contained.
    import mimetypes

    if "/ping" in text:
        return "pong"

    if file_path:
        logging.info("Downloading image")
        image = Path(get_image(file_path))
        # Declare the image's real type instead of always claiming JPEG;
        # anything unrecognised keeps the previous "image/jpeg" value.
        guessed_type, _ = mimetypes.guess_type(image.name)
        if guessed_type and guessed_type.startswith("image/"):
            mime_type = guessed_type
        else:
            mime_type = "image/jpeg"
        try:
            image_bytes = image.read_bytes()
        finally:
            # The content is in memory now; remove the temporary copy so
            # downloads do not pile up in /tmp.
            image.unlink(missing_ok=True)

        logging.info("Asking Gemini...")
        response = _vision_model.generate_content(
            [
                Part(
                    inline_data={
                        "mime_type": mime_type,
                        "data": image_bytes,
                    }
                ),
                Part(text=text),
            ],
            stream=True,
        )
    else:
        if context.conversation_id is None:
            context.conversation_id = str(uuid.uuid4())

        # NOTE: only the current prompt is sent; earlier turns of the
        # conversation are not included in the request.
        logging.info("Asking Gemini...")
        response = _model.generate_content(
            text,
            stream=True,
        )

    # Chunks without parts, or whose first part has no text, are skipped.
    answer = "".join(
        chunk.parts[0].text
        for chunk in response
        if len(chunk.parts) >= 1 and "text" in chunk.parts[0]
    )
    return __as_markdown(answer)


def create(conversation_id: Optional[str]) -> None:
    """Initialize model API https://ai.google.dev/api

    Reads the proxy URL and API key via ``read_ssm_param``, exports the proxy
    as ``http_proxy``, and rebuilds the module-level ``_model`` (gemini-pro)
    and ``_vision_model`` (gemini-pro-vision). Both share one generation
    config and use the ``BLOCK_ONLY_HIGH`` threshold for every harm category.

    Args:
        conversation_id: Not used by this function.
    """
    global _model, _vision_model

    logging.info("Create chatbot instance")
    proxy_url = read_ssm_param(param_name="SOCKS5_URL")
    os.environ["http_proxy"] = proxy_url
    logging.info(f"Initializing Google AI module with proxy {proxy_url}")

    generation_config = {
        "temperature": 0.8,
        "top_p": 1,
        "top_k": 32,
        "max_output_tokens": 4096,
    }
    harm_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    safety_settings = [
        {"category": category, "threshold": "BLOCK_ONLY_HIGH"}
        for category in harm_categories
    ]

    _model = genai.GenerativeModel(
        model_name="gemini-pro",
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    _vision_model = genai.GenerativeModel(
        model_name="gemini-pro-vision",
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    # Attach a client configured with our API key to each model explicitly.
    api_key = read_ssm_param(param_name="GEMINI_API_KEY")
    client_manager = _ClientManager()
    client_manager.configure(api_key=api_key)
    for model in (_model, _vision_model):
        model._client = client_manager.get_default_client("generative")
    logging.info("Google AI module initialized")


def __as_markdown(input: str) -> str:
    """Rewrite model output so that markdown control characters are escaped.

    Three passes are applied in order: a lone asterisk gets a backslash in
    front of it, a run of two or more asterisks collapses to a single one,
    and each of the characters ``._-+#|{}!=()<>[]`` is prefixed with a
    backslash.
    """
    lone_star = re.compile(r"(?<!\*)\*(?!\*)")
    star_run = re.compile(r"\*{2,}")
    esc_pattern = re.compile(f"([{re.escape(r'._-+#|{}!=()<>[]')}])")

    text = lone_star.sub("\\\\*", input)
    text = star_run.sub("*", text)
    return esc_pattern.sub(r"\\\1", text)


def __process_payload(payload: Any, request_id: str) -> None:
    """Handle one decoded request and publish the answer to the result topic.

    A payload whose ``type`` contains ``command`` is handed to
    ``process_command`` and nothing is published. Otherwise the models are
    initialised, Gemini is asked, the exchange is saved on the user context,
    and the payload - extended with ``response`` and ``engine`` - is published
    to SNS.

    Args:
        payload: Decoded message; read keys are ``user_id``, ``chat_id``,
            ``username``, ``type``, ``text`` and, optionally, ``file``.
        request_id: Identifier of the current invocation.
    """
    # The context is keyed per user *and* chat.
    context_key = f"{payload['user_id']}_{payload['chat_id']}"
    user_context = UserContext(
        user_id=context_key,
        request_id=request_id,
        engine_id=engine_type,
        username=payload["username"],
    )

    is_command = "command" in payload["type"]
    if is_command:
        process_command(input=payload["text"], context=user_context)
        return

    create(conversation_id=user_context.conversation_id)
    response = ask(
        text=payload["text"],
        file_path=payload.get("file", None),
        context=user_context,
    )
    exchange = {"request": payload["text"], "response": response}
    user_context.save_conversation(conversation=exchange)

    payload["response"] = encode_message(response)
    payload["engine"] = engine_type
    message = json.dumps(payload)
    sns.publish(TopicArn=result_topic, Message=message)


def sns_handler(event, context):
    """AWS SNS event handler.

    Logs the invocation's request id, then decodes the JSON message of each
    record in ``event["Records"]`` and processes them one by one, in order.
    """
    request_id = context.aws_request_id
    logging.info(f"Request ID: {request_id}")
    # Lazy generator: each message is extracted right before it is processed.
    messages = (record["Sns"]["Message"] for record in event["Records"])
    for message in messages:
        __process_payload(json.loads(message), request_id)
1 change: 1 addition & 0 deletions engines/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ git+https://github.com/brainboost/EdgeGPT.git
git+https://github.com/brainboost/ChatGPT.git
git+https://github.com/brainboost/BingImageCreator.git
bardapi
google-generativeai
deepl
websockets
requests
Expand Down
Loading