Use an SQS queue for the downloader; generalize dispatcher (#112)
austinbyers authored Apr 17, 2018
1 parent ff6bdb3 commit 5a35e01
Showing 37 changed files with 1,045 additions and 818 deletions.
1 change: 1 addition & 0 deletions .coveragerc
@@ -4,3 +4,4 @@ omit=venv/*

[report]
fail_under=85
show_missing = True
2 changes: 1 addition & 1 deletion .pylintrc
@@ -239,7 +239,7 @@ ignore-docstrings=yes
ignore-imports=no

# Minimum lines number of a similarity.
min-similarity-lines=4
min-similarity-lines=7


[SPELLING]
4 changes: 2 additions & 2 deletions .travis.yml
@@ -8,8 +8,8 @@ install:
script:
- coverage run manage.py unit_test
- coverage report # Required coverage threshold specified in .coveragerc
- find . -name '*.py' -not -path './docs/source/*' -exec pylint '{}' + # Config in .pylintrc
- mypy . --ignore-missing-imports
- pylint lambda_functions rules tests *.py -j 1 # Config in .pylintrc
- mypy lambda_functions rules *.py --disallow-untyped-defs --ignore-missing-imports --warn-unused-ignores
- bandit -r . # Configuration in .bandit
- sphinx-build -W docs/source docs/build
after_success:
27 changes: 20 additions & 7 deletions lambda_functions/analyzer/analyzer_aws_lib.py
@@ -7,7 +7,8 @@
from botocore.exceptions import ClientError

if __package__:
from lambda_functions.analyzer.binary_info import BinaryInfo
# BinaryInfo is imported here just for the type annotation - the cyclic import will resolve
from lambda_functions.analyzer.binary_info import BinaryInfo # pylint: disable=cyclic-import
from lambda_functions.analyzer.common import LOGGER
else:
# mypy complains about duplicate definitions
@@ -24,6 +25,10 @@
SQS = boto3.resource('sqs')


class FileDownloadError(Exception):
"""File can't be downloaded from S3 with a 4XX error code - do not retry."""


def download_from_s3(
bucket_name: str, object_key: str, download_path: str) -> Tuple[str, Dict[str, str]]:
"""Download an object from S3 to the given download path.
@@ -35,11 +40,21 @@ def download_from_s3(
Returns:
Last modified timestamp (i.e. object upload timestamp), object metadata.
Raises:
FileDownloadError: If the file couldn't be downloaded because of a 4XX client error (do not retry).
"""
s3_object = S3.Object(bucket_name, object_key)
s3_object.download_file(download_path)
last_modified = str(s3_object.last_modified) # UTC timestamp, e.g. '2017-09-04 04:49:06-00:00'
return last_modified, s3_object.metadata
try:
s3_object = S3.Object(bucket_name, object_key)
s3_object.download_file(download_path)
# UTC timestamp, e.g. '2017-09-04 04:49:06-00:00'
last_modified = str(s3_object.last_modified)
return last_modified, s3_object.metadata
except ClientError as error:
if 400 <= error.response['ResponseMetadata']['HTTPStatusCode'] < 500:
raise FileDownloadError(error)
else:
raise


def _elide_string_middle(text: str, max_length: int) -> str:
@@ -81,8 +96,6 @@ def delete_sqs_messages(queue_url: str, receipts: List[str]) -> None:
queue_url: The URL of the SQS queue containing the messages.
receipts: List of SQS receipt handles.
"""
if not receipts:
return
LOGGER.info('Deleting %d SQS receipt(s) from %s', len(receipts), queue_url)
SQS.Queue(queue_url).delete_messages(
Entries=[
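The new error handling in download_from_s3 separates non-retryable client errors from everything else: a 4XX ClientError (for example, an object that no longer exists) is wrapped in FileDownloadError, while any other error is re-raised so the normal retry path still applies. A minimal caller sketch under that assumption; the bucket, key, and download path below are hypothetical and not part of this commit:

    # Hypothetical caller sketch (assumes analyzer_aws_lib and LOGGER are imported as in main.py).
    try:
        last_modified, metadata = analyzer_aws_lib.download_from_s3(
            'example-bucket', 'example/key', '/tmp/example_binary')
    except analyzer_aws_lib.FileDownloadError:
        # 4XX from S3: the object cannot be fetched, so log and skip instead of retrying.
        LOGGER.exception('Skipping object which could not be downloaded')
    else:
        LOGGER.info('Downloaded object last modified at %s', last_modified)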
4 changes: 2 additions & 2 deletions lambda_functions/analyzer/binary_info.py
@@ -59,7 +59,7 @@ def _download_from_s3(self) -> None:
self.bucket_name, self.object_key, self.download_path)
self.download_time_ms = (time.time() - start_time) * 1000

def __enter__(self):
def __enter__(self) -> Any: # mypy/typing doesn't support recursive type yet
"""Download the binary from S3 and run YARA analysis."""
self._download_from_s3()
self.computed_sha, self.computed_md5 = file_hash.compute_hashes(self.download_path)
@@ -71,7 +71,7 @@ def __enter__(self):

return self

def __exit__(self, exception_type, exception_value, traceback):
def __exit__(self, exception_type: Any, exception_value: Any, traceback: Any) -> None:
"""Shred the downloaded binary and delete it from disk."""
# Note: This runs even during exception handling (it is the "with" context).
subprocess.check_call(['shred', '--remove', self.download_path])
139 changes: 100 additions & 39 deletions lambda_functions/analyzer/main.py
@@ -1,14 +1,14 @@
"""AWS Lambda function for testing a binary against a list of YARA rules."""
# Expects the following environment variables:
# SQS_QUEUE_URL: URL of the queue from which messages originated (needed for message deletion).
# YARA_MATCHES_DYNAMO_TABLE_NAME: Name of the Dynamo table which stores YARA match results.
# YARA_ALERTS_SNS_TOPIC_ARN: ARN of the SNS topic which should be alerted on a YARA match.
# Expects a binary YARA rules file to be at './compiled_yara_rules.bin'
import json
import os
from typing import Any, Dict
from typing import Any, Dict, Generator, List, Tuple
import urllib.parse

from botocore.exceptions import ClientError as BotoError
from botocore.exceptions import ClientError

if __package__:
# Imported by unit tests or other external code.
@@ -29,28 +29,90 @@
NUM_YARA_RULES = ANALYZER.num_rules


def analyze_lambda_handler(event_data: Dict[str, Any], lambda_context) -> Dict[str, Dict[str, Any]]:
"""Lambda function entry point.
def _s3_objects(s3_records: List[Dict[str, Any]]) -> Generator[Tuple[str, str], None, None]:
"""Build list of objects in the given S3 record.
Args:
event_data: [dict] of the form: {
'Records': [
{
"s3": {
"object": {
"key": "FileName.txt"
},
"bucket": {
"name": "mybucket"
}
s3_records: List of S3 event records: [
{
's3': {
'object': {
'key': (str)
},
'bucket': {
'name': (str)
}
}
},
...
]
Yields:
(bucket_name, object_key) string tuple
"""
for record in s3_records:
try:
bucket_name = record['s3']['bucket']['name']
object_key = urllib.parse.unquote_plus(record['s3']['object']['key'])
yield bucket_name, object_key
except (KeyError, TypeError):
LOGGER.exception('Skipping invalid S3 record %s', record)


def _objects_to_analyze(event: Dict[str, Any]) -> Generator[Tuple[str, str], None, None]:
"""Parse the invocation event into a list of objects to analyze.
Args:
event: Invocation event, from either the dispatcher or an S3 bucket
Yields:
(bucket_name, object_key) string tuples to analyze
"""
if set(event) == {'messages', 'queue_url'}:
LOGGER.info('Invoked from dispatcher with %d messages', len(event['messages']))
for sqs_record in event['messages']:
try:
s3_records = json.loads(sqs_record['body'])['Records']
except (json.JSONDecodeError, KeyError, TypeError):
LOGGER.exception('Skipping invalid SQS message %s', sqs_record)
continue
yield from _s3_objects(s3_records)
else:
LOGGER.info('Invoked with dictionary (S3 Event)')
yield from _s3_objects(event['Records'])


def analyze_lambda_handler(event: Dict[str, Any], lambda_context: Any) -> Dict[str, Dict[str, Any]]:
"""Analyzer Lambda function entry point.
Args:
event: SQS message batch sent by the dispatcher: {
'messages': [
{
'body': (str) JSON-encoded S3 put event: {
'Records': [
{
's3': {
'object': {
'key': (str)
},
'bucket': {
'name': (str)
}
}
},
...
]
},
'receipt': (str) SQS message receipt handle,
'receive_count': (int) Approx. # of times this has been received
},
...
],
'SQSReceipts': [...] # SQS receipt handles (to be deleted after processing).
'queue_url': (str) SQS queue url from which the message originated
}
There can be any number of S3 objects, but no more than 10 SQS receipts.
The Records are the same format as the S3 Put event, which means the analyzer could be
directly linked to an S3 bucket notification if needed.
Alternatively, the event can be an S3 Put Event dictionary (with no sqs information).
This allows the analyzer to be linked directly to an S3 bucket notification if needed.
lambda_context: LambdaContext object (with .function_version).
Returns:
@@ -70,36 +132,35 @@ def analyze_lambda_handler(event_data: Dict[str, Any], lambda_context) -> Dict[s
try:
lambda_version = int(lambda_context.function_version)
except ValueError:
LOGGER.warning('Invoked $LATEST instead of a versioned function')
lambda_version = -1

LOGGER.info('Processing %d record(s)', len(event_data['Records']))
for record in event_data['Records']:
bucket_name = record['s3']['bucket']['name']
s3_key = urllib.parse.unquote_plus(record['s3']['object']['key'])
LOGGER.info('Analyzing "%s:%s"', bucket_name, s3_key)
for bucket_name, object_key in _objects_to_analyze(event):
LOGGER.info('Analyzing "%s:%s"', bucket_name, object_key)

with binary_info.BinaryInfo(bucket_name, s3_key, ANALYZER) as binary:
result[binary.s3_identifier] = binary.summary()
binaries.append(binary)
try:
with binary_info.BinaryInfo(bucket_name, object_key, ANALYZER) as binary:
result[binary.s3_identifier] = binary.summary()
binaries.append(binary)
except analyzer_aws_lib.FileDownloadError:
LOGGER.exception('Unable to download %s from %s', object_key, bucket_name)
continue

if binary.yara_matches:
LOGGER.warning('%s matched YARA rules: %s', binary, binary.matched_rule_ids)
binary.save_matches_and_alert(
lambda_version, os.environ['YARA_MATCHES_DYNAMO_TABLE_NAME'],
os.environ['YARA_ALERTS_SNS_TOPIC_ARN'])
else:
LOGGER.info('%s did not match any YARA rules', binary)
if binary.yara_matches:
LOGGER.warning('%s matched YARA rules: %s', binary, binary.matched_rule_ids)
binary.save_matches_and_alert(
lambda_version, os.environ['YARA_MATCHES_DYNAMO_TABLE_NAME'],
os.environ['YARA_ALERTS_SNS_TOPIC_ARN'])

# Delete all of the SQS receipts (mark them as completed).
analyzer_aws_lib.delete_sqs_messages(
os.environ['SQS_QUEUE_URL'],
event_data.get('SQSReceipts', [])
)
receipts_to_delete = [msg['receipt'] for msg in event.get('messages', [])]
if receipts_to_delete:
analyzer_aws_lib.delete_sqs_messages(event['queue_url'], receipts_to_delete)

# Publish metrics.
try:
analyzer_aws_lib.put_metric_data(NUM_YARA_RULES, binaries)
except BotoError:
except ClientError:
LOGGER.exception('Error saving metric data')

return result
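For reference, a dispatcher-style invocation event for the new handler would look roughly like the sketch below; the structure mirrors the docstring above, and the bucket, key, receipt handle, and queue URL are hypothetical placeholders:

    # Hypothetical example event in the shape analyze_lambda_handler now expects.
    example_event = {
        'messages': [
            {
                'body': '{"Records": [{"s3": {"bucket": {"name": "example-bucket"},'
                        ' "object": {"key": "example-key"}}}]}',
                'receipt': 'example-receipt-handle',
                'receive_count': 1
            }
        ],
        'queue_url': 'https://sqs.us-east-1.amazonaws.com/123456789012/example-queue'
    }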
6 changes: 3 additions & 3 deletions lambda_functions/batcher/main.py
@@ -9,7 +9,7 @@
import json
import logging
import os
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import boto3

@@ -164,7 +164,7 @@ def __init__(self, bucket_name: str, prefix: Optional[str],
self.finished = False # Have we finished enumerating all of the S3 bucket?

@property
def continuation_token(self):
def continuation_token(self) -> str:
return self.kwargs.get('ContinuationToken')

def next_page(self) -> List[str]:
@@ -187,7 +187,7 @@ def next_page(self) -> List[str]:
return [obj['Key'] for obj in response['Contents']]


def batch_lambda_handler(event: Dict[str, str], lambda_context) -> int:
def batch_lambda_handler(event: Dict[str, str], lambda_context: Any) -> int:
"""Entry point for the batch Lambda function.
Args:
19 changes: 7 additions & 12 deletions lambda_functions/build.py
@@ -25,11 +25,10 @@
DISPATCH_ZIPFILE = 'lambda_dispatcher'

DOWNLOAD_SOURCE = os.path.join(LAMBDA_DIR, 'downloader', 'main.py')
DOWNLOAD_DEPENDENCIES = os.path.join(LAMBDA_DIR, 'downloader', 'cbapi_1.3.4.zip')
DOWNLOAD_ZIPFILE = 'lambda_downloader'


def _build_analyzer(target_directory):
def _build_analyzer(target_directory: str) -> None:
"""Build the YARA analyzer Lambda deployment package."""
print('Creating analyzer deploy package...')
pathlib.Path(os.path.join(ANALYZE_SOURCE, 'main.py')).touch()
@@ -58,23 +57,23 @@ def _build_analyzer(target_directory):
shutil.rmtree(temp_package_dir)


def _build_batcher(target_directory):
def _build_batcher(target_directory: str) -> None:
"""Build the batcher Lambda deployment package."""
print('Creating batcher deploy package...')
pathlib.Path(BATCH_SOURCE).touch() # Change last modified time to force new Lambda deploy
with zipfile.ZipFile(os.path.join(target_directory, BATCH_ZIPFILE + '.zip'), 'w') as pkg:
pkg.write(BATCH_SOURCE, os.path.basename(BATCH_SOURCE))


def _build_dispatcher(target_directory):
def _build_dispatcher(target_directory: str) -> None:
"""Build the dispatcher Lambda deployment package."""
print('Creating dispatcher deploy package...')
pathlib.Path(DISPATCH_SOURCE).touch()
with zipfile.ZipFile(os.path.join(target_directory, DISPATCH_ZIPFILE + '.zip'), 'w') as pkg:
pkg.write(DISPATCH_SOURCE, os.path.basename(DISPATCH_SOURCE))


def _build_downloader(target_directory):
def _build_downloader(target_directory: str) -> None:
"""Build the downloader Lambda deployment package."""
print('Creating downloader deploy package...')
pathlib.Path(DOWNLOAD_SOURCE).touch()
@@ -83,12 +82,8 @@ def _build_downloader(target_directory):
if os.path.exists(temp_package_dir):
shutil.rmtree(temp_package_dir)

# Extract cbapi library.
with zipfile.ZipFile(DOWNLOAD_DEPENDENCIES, 'r') as deps:
deps.extractall(temp_package_dir)

# Pip install backoff library (has no native dependencies).
pip.main(['install', '--quiet', '--target', temp_package_dir, 'backoff'])
# Pip install cbapi library (has no native dependencies).
pip.main(['install', '--quiet', '--target', temp_package_dir, 'cbapi==1.3.6'])

# Copy Lambda code into the package.
shutil.copy(DOWNLOAD_SOURCE, temp_package_dir)
@@ -98,7 +93,7 @@ def _build_downloader(target_directory):
shutil.rmtree(temp_package_dir)


def build(target_directory, downloader=False):
def build(target_directory: str, downloader: bool = False) -> None:
"""Build Lambda deployment packages.
Args: