Skip to content

Commit

Permalink
using sns fanout to lambdas instead of to sqs
Browse files Browse the repository at this point in the history
  • Loading branch information
brainboost committed Dec 23, 2023
1 parent 85b8f4f commit 17a010b
Show file tree
Hide file tree
Showing 11 changed files with 318 additions and 236 deletions.
61 changes: 36 additions & 25 deletions engines/bard.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import re
from typing import Optional
from typing import Any, Optional

import boto3
import requests
Expand Down Expand Up @@ -124,34 +124,45 @@ def __as_markdown(input: str) -> str:
esc_pattern = re.compile(f"([{re.escape(r'._-+#|{}!=()<>[]')}])")
return re.sub(esc_pattern, r"\\\1", input)

def __process_payload(payload: Any, request_id: str) -> None:
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
if "command" in payload["type"]:
process_command(input=payload["text"], context=user_context)
return

instance = create(conversation_id=user_context.conversation_id)
response = ask(
text=payload["text"],
chatbot=instance,
context=user_context,
)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
if "command" in payload["type"]:
process_command(input=payload["text"], context=user_context)
return

instance = create(conversation_id=user_context.conversation_id)
response = ask(
text=payload["text"],
chatbot=instance,
context=user_context,
)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["Sns"]["Message"])
__process_payload(payload, request_id)

61 changes: 36 additions & 25 deletions engines/bing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import re
from typing import Any

import boto3
from EdgeGPT.EdgeGPT import CONVERSATION_STYLE_TYPE, Chatbot
Expand Down Expand Up @@ -104,35 +105,45 @@ def create() -> Chatbot:
results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")

def __process_payload(payload: Any, request_id: str) -> None:
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
if "command" in payload["type"]:
process_command(input=payload["text"], context=user_context)
return

instance = create()
style = payload["config"].get("style", "creative")
response = ask(
text=payload["text"],
chatbot=instance,
style=style,
context=user_context,
)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
if "command" in payload["type"]:
process_command(input=payload["text"], context=user_context)
return

instance = create()
style = payload["config"].get("style", "creative")
response = ask(
text=payload["text"],
chatbot=instance,
style=style,
context=user_context,
)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["Sns"]["Message"])
__process_payload(payload, request_id)
49 changes: 29 additions & 20 deletions engines/chat_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from collections import deque
from os import environ
from typing import Any

import boto3
from revChatGPT.V1 import Chatbot
Expand Down Expand Up @@ -73,29 +74,37 @@ def create(context: UserContext) -> Chatbot:
sqs = boto3.session.Session().client("sqs")


def __process_payload(payload: Any, request_id: str) -> None:
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
instance = create(user_context)
response = ask(payload["text"], instance, user_context)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = response
payload["response"] = encode_message(response)
payload["engine"] = engine_type
# logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
if "command" in payload["type"]:
process_command(input=payload["text"], context=user_context)
return

instance = create(user_context)
response = ask(payload["text"], instance, user_context)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = response
payload["response"] = encode_message(response)
payload["engine"] = engine_type
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["Sns"]["Message"])
__process_payload(payload, request_id)
49 changes: 30 additions & 19 deletions engines/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import re
import uuid
from typing import Any

import boto3
from curl_cffi import requests
Expand Down Expand Up @@ -176,31 +177,41 @@ def __set_title(prompt: str, conversation_id: str) -> str:
sqs = boto3.session.Session().client("sqs")


def __process_payload(payload: Any, request_id: str) -> None:
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
if "command" in payload["type"]:
process_command(input=payload["text"], context=user_context)
return

response = ask(payload["text"], context=user_context)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
if "command" in payload["type"]:
process_command(input=payload["text"], context=user_context)
return

response = ask(payload["text"], context=user_context)
user_context.save_conversation(
conversation={"request": payload["text"], "response": response},
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))

__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["Sns"]["Message"])
__process_payload(payload, request_id)

# if __name__ == "__main__":
# put_request("does DALL-E uses stable diffusion?")
88 changes: 50 additions & 38 deletions engines/dalle_img.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from typing import Any

import boto3
from BingImageCreator import ImageGen
Expand Down Expand Up @@ -32,49 +33,60 @@ def create() -> ImageGen:
results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")

def __process_payload(payload: Any, request_id: str) -> None:
prompt = payload["text"]
list = []
if prompt is None or not prompt.strip():
return
try:
list = imageGen.get_images(prompt)
except Exception as e:
if "prompt has been blocked" in str(e):
message = escape_markdown_v2(str(e))
list = [message]
else:
logging.error(e)
logging.info(payload)
logging.info(imageGen.session.__dict__)

payload["response"] = list
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
user_context.conversation_id = request_id
try:
user_context.save_conversation(
conversation=payload,
)
except Exception as e:
logging.error(
f"Saving conversation error. User_id: {user_id}_{payload['chat_id']}, item: {payload}",
exc_info=e,
)
logging.info(list)
message = "\n".join(list)
payload["response"] = encode_message(message)
payload["engine"] = engine_type
# logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))


def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
prompt = payload["text"]
list = []
if prompt is None or not prompt.strip():
return
try:
list = imageGen.get_images(prompt)
except Exception as e:
if "prompt has been blocked" in str(e):
message = escape_markdown_v2(str(e))
list = [message]
else:
logging.error(e)
logging.info(payload)
logging.info(imageGen.session.__dict__)
__process_payload(payload, request_id)

payload["response"] = list
user_id = payload["user_id"]
user_context = UserContext(
user_id=f"{user_id}_{payload['chat_id']}",
request_id=request_id,
engine_id=engine_type,
username=payload["username"],
)
user_context.conversation_id = request_id
try:
user_context.save_conversation(
conversation=payload,
)
except Exception as e:
logging.error(
f"Saving conversation error. User_id: {user_id}_{payload['chat_id']}, item: {payload}",
exc_info=e,
)

logging.info(list)
message = "\n".join(list)
payload["response"] = encode_message(message)
payload["engine"] = engine_type
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["Sns"]["Message"])
__process_payload(payload, request_id)
Loading

0 comments on commit 17a010b

Please sign in to comment.