Skip to content

Commit

Permalink
Merge pull request #5 from san99tiago/sqs-batch-processing
Browse files Browse the repository at this point in the history
Enabled SQS-batch processing with partial failures approach
  • Loading branch information
san99tiago authored May 1, 2023
2 parents a98d454 + 83dd6a2 commit c6fc6e1
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 92 deletions.
1 change: 1 addition & 0 deletions cdk/stacks/cdk_api_gateway_sqs_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def configure_sqs_event_source_for_lambda(self):
self.queue,
enabled=True,
batch_size=5,
report_batch_item_failures=True, # Necessary for processing batches from SQS
)
self.lambda_function.add_event_source(self.sqs_event_source)

Expand Down
74 changes: 41 additions & 33 deletions lambda/src/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,49 +9,57 @@
# External imports
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools.utilities.data_classes import event_source, SQSEvent
from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, process_partial_response
from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord


tracer = Tracer(service="MessagesAPIService")
logger = Logger(service="MessagesAPIService", log_uncaught_exceptions=True)
logger.append_keys(owner="san99tiago")
logger = Logger(service="MessagesAPIService", log_uncaught_exceptions=True, owner="san99tiago")
processor = BatchProcessor(event_type=EventType.SQS)


@tracer.capture_method
def process_message(record: SQSRecord):
    """Process a single SQS record from the batch.

    Called once per record by the batch processor; raising here marks
    only this record as a partial failure (it will be retried by SQS)
    instead of failing the whole batch.

    :param record: the SQS record to process; its body is expected to be
        a JSON object containing a "Message" key.
    :return: True when the record was processed successfully.
    :raises Exception: when the body is not valid JSON or lacks "Message".
    """
    # Add message id for each log statement so we know which message is being processed
    logger.append_keys(message_id=record.message_id)

    # Batch will call this function for each record and will handle partial failures
    logger.info(record.body)

    # Simulate a "time" processing delay for the messages
    logger.debug("Processing message")
    tracer.put_annotation(key="sqs_id", value=record.message_id)
    time.sleep(4)
    logger.debug("Finished processing message")

    try:
        # Validate "Message" key on input, otherwise return failure
        # Note: if input does not contain "Message" key, this will raise an error
        message = json.loads(record.body)["Message"]
        logger.info(f"Message: {message}")
        return True
    except Exception:  # narrow from bare except: don't swallow SystemExit/KeyboardInterrupt
        logger.exception("Failed to process message")
        raise


@logger.inject_lambda_context(log_event=True)
@tracer.capture_lambda_handler
def handler(event: dict, context: LambdaContext) -> PartialItemFailureResponse:
    """Lambda entry point for SQS-batch events with partial-failure reporting.

    Delegates each record to ``process_message`` via
    ``process_partial_response``; records whose handler raised are
    collected and returned as ``batchItemFailures`` so SQS retries only
    those messages.

    :param event: raw SQS event payload (dict with a "Records" list).
    :param context: Lambda context object supplied by the runtime.
    :return: partial-batch-failure response consumed by the SQS
        event-source mapping (requires report_batch_item_failures=True).
    """
    logger.info("Starting messages processing")
    tracer.put_metadata(key="details", value="messages processing handler")

    number_of_records = len(event.get("Records", []))
    logger.debug(f"Number of messages is: {number_of_records}")
    tracer.put_metadata(key="total_messages", value=number_of_records)

    batch_response = process_partial_response(
        event=event,
        record_handler=process_message,
        processor=processor,
        context=context,
    )
    logger.info("Finished messages processing")

    return batch_response
File renamed without changes.
File renamed without changes.
38 changes: 0 additions & 38 deletions lambda/tests/test_event_02_good_multiple.json

This file was deleted.

31 changes: 10 additions & 21 deletions lambda/tests/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import unittest

# External imports
from aws_lambda_powertools.utilities.data_classes import event_source, SQSEvent
from aws_lambda_powertools.utilities.data_classes import SQSEvent
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from moto import mock_sts

# Add path to find lambda directory for own imports
Expand Down Expand Up @@ -38,25 +39,12 @@ def test_process_messages_success_single(self):
Test successful process_messages call for a single message.
"""
# Load pre-configured event for current test case
event = self.load_test_event("test_event_01_good_single.json")
event = self.load_test_event("test_event_01_good.json")

# Middleware to load event with correct SQSEvent data class
# Middleware to load event with correct SQSEvent and SQSRecord data classes
event_sqs = SQSEvent(event)
result = _lambda.process_messages(event_sqs)

self.assertEqual(result, True)

@mock_sts()
def test_process_messages_success_multiple(self):
"""
Test successful process_messages call for multiple messages.
"""
# Load pre-configured event for current test case
event = self.load_test_event("test_event_02_good_multiple.json")

# Middleware to load event with correct SQSEvent data class
event_sqs = SQSEvent(event)
result = _lambda.process_messages(event_sqs)
sqs_record = SQSRecord(event_sqs.get("Records", [])[0])
result = _lambda.process_message(sqs_record)

self.assertEqual(result, True)

Expand All @@ -66,14 +54,15 @@ def test_process_messages_error(self):
Test errors on process_messages call due to wrong message format.
"""
# Load pre-configured event for current test case
event = self.load_test_event("test_event_03_bad.json")
event = self.load_test_event("test_event_02_bad.json")

# Middleware to load event with correct SQSEvent data class
# Middleware to load event with correct SQSEvent and SQSRecord data classes
event_sqs = SQSEvent(event)
sqs_record = SQSRecord(event_sqs.get("Records", [])[0])

# Expected an exception intentionally, otherwise fails
with self.assertRaises(Exception):
_lambda.process_messages(event_sqs)
_lambda.process_message(sqs_record)


if __name__ == "__main__":
Expand Down

0 comments on commit c6fc6e1

Please sign in to comment.