Skip to content

Commit

Permalink
moving AI results to the SNS instead of SQS queue and added DLQ redrive
Browse files Browse the repository at this point in the history
  • Loading branch information
brainboost committed Dec 24, 2023
1 parent 17a010b commit 4467302
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 171 deletions.
13 changes: 3 additions & 10 deletions engines/bard.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
"__Secure-1PAPISID",
]
bucket_name = read_ssm_param(param_name="BOT_S3_BUCKET")
results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")


def process_command(input: str, context: UserContext) -> None:
Expand Down Expand Up @@ -148,15 +148,8 @@ def __process_payload(payload: Any, request_id: str) -> None:
payload["response"] = encode_message(response)
payload["engine"] = engine_type
logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
Expand Down
14 changes: 3 additions & 11 deletions engines/bing.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def create() -> Chatbot:
return chatbot


results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")

def __process_payload(payload: Any, request_id: str) -> None:
user_id = payload["user_id"]
Expand All @@ -130,15 +130,7 @@ def __process_payload(payload: Any, request_id: str) -> None:
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
__process_payload(payload, request_id)
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))

def sns_handler(event, context):
"""AWS SNS event handler"""
Expand Down
14 changes: 3 additions & 11 deletions engines/chat_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def create(context: UserContext) -> Chatbot:
return chatbot


results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")


def __process_payload(payload: Any, request_id: str) -> None:
Expand All @@ -90,16 +90,8 @@ def __process_payload(payload: Any, request_id: str) -> None:
payload["response"] = response
payload["response"] = encode_message(response)
payload["engine"] = engine_type
# logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
Expand Down
14 changes: 3 additions & 11 deletions engines/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def __set_title(prompt: str, conversation_id: str) -> str:
headers["Cookie"] = cookies_str
organization_id = __get_organization()

results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")


def __process_payload(payload: Any, request_id: str) -> None:
Expand All @@ -195,16 +195,8 @@ def __process_payload(payload: Any, request_id: str) -> None:
)
payload["response"] = encode_message(response)
payload["engine"] = engine_type
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
Expand Down
7 changes: 3 additions & 4 deletions engines/dalle_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def create() -> ImageGen:


imageGen = create()
results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")

def __process_payload(payload: Any, request_id: str) -> None:
prompt = payload["text"]
Expand Down Expand Up @@ -71,8 +71,7 @@ def __process_payload(payload: Any, request_id: str) -> None:
message = "\n".join(list)
payload["response"] = encode_message(message)
payload["engine"] = engine_type
# logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))


def sqs_handler(event, context):
Expand Down
14 changes: 3 additions & 11 deletions engines/deepl_tr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
auth_key = read_ssm_param(param_name="DEEPL_AUTHKEY")

translator = Translator(auth_key)
results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")


def __parse_languages(lang: str) -> list:
Expand All @@ -37,17 +37,9 @@ def __process_payload(payload: Any, request_id: str) -> None:

payload["engine"] = lang.replace("-", "\-")
payload["response"] = encode_message(result)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))


def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
Expand Down
13 changes: 1 addition & 12 deletions engines/ideogram_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@
retrieve_metadata_url = f"{base_url}/api/images/retrieve_metadata_request_id/"
get_images_url = f"{base_url}/api/images/direct/"

results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
ideogram_result_queue = results_queue.replace(
results_queue.split("/")[-1], "Ideogram-Result-Queue"
)
sqs = boto3.session.Session().client("sqs")
token = read_ssm_param(param_name="IDEOGRAM_TOKEN")
user_id = read_ssm_param(param_name="IDEOGRAM_USER")
Expand All @@ -40,6 +36,7 @@
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) \
Gecko/20100101 Firefox/117.0",
}
ideogram_result_queue = sqs.get_queue_url(QueueName="Ideogram-Result-Queue")["QueueUrl"]


def request_images(prompt: str) -> str:
Expand Down Expand Up @@ -87,14 +84,6 @@ def __process_payload(payload: Any, request_id: str) -> None:
payload["queue_url"] = ideogram_result_queue
send_retrieving_event(payload)

def sqs_handler(event, context):
"""AWS SQS event handler"""
request_id = context.aws_request_id
logging.info(f"Request ID: {request_id}")
for record in event["Records"]:
payload = json.loads(record["body"])
__process_payload(payload, request_id)

def sns_handler(event, context):
"""AWS SNS event handler"""
request_id = context.aws_request_id
Expand Down
6 changes: 3 additions & 3 deletions engines/ideogram_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
retrieve_metadata_url = "https://ideogram.ai/api/images/retrieve_metadata_request_id/"
get_images_url = "https://ideogram.ai/api/images/direct/"

results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")
sqs = boto3.session.Session().client("sqs")


Expand Down Expand Up @@ -84,5 +85,4 @@ def sqs_handler(event, context):
f"Saving conversation error. User_id: {user_id}_{payload['chat_id']}, item: {payload}",
exc_info=e,
)
# logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))
9 changes: 5 additions & 4 deletions engines/monsterapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def ask(
"response": escape_markdown_v2(response["message"]),
"engine": engine_type,
}
sqs.send_message(QueueUrl=results_queue, MessageBody=err_message)
sns = boto3.session.Session().client("sns")
sns.publish(TopicArn=result_topic, Message=json.dumps(err_message))
return

process_id = response_body["process_id"]
Expand All @@ -65,8 +66,7 @@ def ask(


token = read_ssm_param(param_name="MONSTERAPI_TOKEN")
results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
sqs = boto3.session.Session().client("sqs")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")

def __process_payload(payload: Any, request_id: str) -> None:
user_id = payload["user_id"]
Expand All @@ -79,7 +79,8 @@ def __process_payload(payload: Any, request_id: str) -> None:
question = payload["text"]
if "/ping" in question:
payload["response"] = "pong"
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
sns = boto3.session.Session().client("sns")
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))
return

if "command" in payload["type"]:
Expand Down
8 changes: 4 additions & 4 deletions engines/monsterapi_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
engine_type = "llama"
fetch_url = "https://api.monsterapi.ai/v1/status/"

results_queue = read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
result_topic = read_ssm_param(param_name="RESULT_SNS_TOPIC_ARN")
sns = boto3.session.Session().client("sns")
token = read_ssm_param(param_name="MONSTERAPI_TOKEN")
sqs = boto3.session.Session().client("sqs")

headers = {
"authorization": f"Bearer {token}",
Expand Down Expand Up @@ -81,5 +81,5 @@ def callback_handler(event, context) -> None:
f"Error on saving conversation_id: {process_id}, request_id:{process_id}, payload: {payload}",
exc_info=e,
)
logging.info(payload)
sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
# logging.info(payload)
sns.publish(TopicArn=result_topic, Message=json.dumps(payload))
66 changes: 39 additions & 27 deletions lambda/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import time
from typing import Any

import boto3
import boto3.session
Expand Down Expand Up @@ -40,7 +41,7 @@
logging.getLogger().setLevel("INFO")

user_config = UserConfig()
sns = boto3.client("sns")
sns = boto3.session.Session().client("sns")


telegram_token = read_ssm_param(param_name="TELEGRAM_TOKEN")
Expand All @@ -60,7 +61,6 @@

# Telegram commands


async def reset(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
if (
update.effective_user is None
Expand Down Expand Up @@ -249,33 +249,45 @@ async def redrive_dlq(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non
)


def __start_redrive_dlq() -> dict:
def __start_redrive_dlq() -> Any:
session = boto3.session.Session()
region = session.region_name
sts_client = session.client("sts")
account_id = sts_client.get_caller_identity()["Account"]
dlq_arn = f"arn:aws:sqs:{region}:{account_id}:Request-Queues-DLQ"
sqs = session.client("sqs")
try:
dlq_response = sqs.start_message_move_task(SourceArn=dlq_arn)
if dlq_response is not None:
handle = dlq_response["TaskHandle"]
logging.info(f"Redrive task started: {handle}")
while True:
list_response = sqs.list_message_move_tasks(SourceArn=dlq_arn)
results = list_response["Results"][0]
logging.info(f"Results: {results}")
if results["Status"] == "RUNNING":
logging.info("Delaying..")
time.sleep(2)
else:
break

logging.info(f"Finished: {results}")
return results
except ClientError as e:
logging.error(f"Redriving DLQ messages error :{e}")
return f"DLQ Redrive failed: {e}"
for queue_url in sqs.list_queues()['QueueUrls']:
if '-DLQ' in queue_url:
[region, account_id, name] = __parse_sqs_url(queue_url)
dlq_arn = f"arn:aws:sqs:{region}:{account_id}:{name}"
logging.info(f"DLQ ARN: {dlq_arn}")
try:
dlq_response = sqs.start_message_move_task(SourceArn=dlq_arn)
if dlq_response is not None:
handle = dlq_response["TaskHandle"]
logging.info(f"Redrive task started: {handle}")
except ClientError as e:
logging.error(f"Redriving DLQ messages error :{e}")
return f"DLQ Redrive failed: {e}"
done = False
while not done:
list_response = sqs.list_message_move_tasks(SourceArn=dlq_arn)
results = list_response["Results"]
logging.info(f"Results: {results}")
if len(results) == 0:
return []
for result in results:
if result["Status"] == "RUNNING":
logging.info("Delaying 2s ...")
time.sleep(2)
break
done = True
logging.info(f"Finished DLQ Redrive: {results}")
return results

def __parse_sqs_url(url: str) -> tuple[str, str, str]:
"""Parses the SQS URL and extracts the region, account ID, and queue name"""
parts = url.split('/')
region = parts[2].split('.')[1]
account_id = parts[3]
name = parts[4]
return region, account_id, name


# Translation handlers
Expand Down
3 changes: 2 additions & 1 deletion lambda/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def response_handler(event, context) -> None:
"""Result SQS processing handler."""

for record in event["Records"]:
payload = json.loads(record["body"])
payload = json.loads(record["Sns"]["Message"])
# payload = json.loads(record["body"])
chat_id = payload["chat_id"]
message_id = payload["message_id"]
message = decode_message(payload["response"])
Expand Down
Loading

0 comments on commit 4467302

Please sign in to comment.