diff --git a/lambdaFunctions/getRedditDataFunction/lambda_function.py b/lambdaFunctions/getRedditDataFunction/lambda_function.py
index 37844d8..9c2d1ca 100644
--- a/lambdaFunctions/getRedditDataFunction/lambda_function.py
+++ b/lambdaFunctions/getRedditDataFunction/lambda_function.py
@@ -37,25 +37,25 @@ def lambda_handler(event, context):
   schema = tableDefinition.schema
   topN = 25
   view = 'rising'
-  risingData = ru.getRedditData(reddit=reddit, subreddit=subreddit, view=view, schema=schema, topN=topN)
-  risingData = ru.deduplicateRedditData(risingData)
+  risingData = ru.get_reddit_data(reddit=reddit, subreddit=subreddit, view=view, schema=schema, top_n=topN)
+  risingData = ru.deduplicate_reddit_data(risingData)
 
   # Push to DynamoDB
   tableName = f"{view}-{os.environ['ENV']}"
-  risingTable = ru.getTable(tableName, dynamodb_resource)
-  ru.batchWriter(risingTable, risingData, schema)
+  risingTable = ru.get_table(tableName, dynamodb_resource)
+  ru.batch_writer(risingTable, risingData, schema)
 
   # Get Hot Reddit data
   print("\tGetting Hot Data")
   schema = tableDefinition.schema
   topN = 3
   view = 'hot'
-  hotData = ru.getRedditData(reddit=reddit, subreddit=subreddit, view=view, schema=schema, topN=topN)
-  hotData = ru.deduplicateRedditData(hotData)
+  hotData = ru.get_reddit_data(reddit=reddit, subreddit=subreddit, view=view, schema=schema, top_n=topN)
+  hotData = ru.deduplicate_reddit_data(hotData)
 
   # Push to DynamoDB
   tableName = f"{view}-{os.environ['ENV']}"
-  hotTable = ru.getTable(tableName, dynamodb_resource)
-  ru.batchWriter(hotTable, hotData, schema)
+  hotTable = ru.get_table(tableName, dynamodb_resource)
+  ru.batch_writer(hotTable, hotData, schema)
 
   return 200
diff --git a/lambdaFunctions/getRedditDataFunction/redditUtils.py b/lambdaFunctions/getRedditDataFunction/redditUtils.py
index d7ddac7..6ca6d49 100644
--- a/lambdaFunctions/getRedditDataFunction/redditUtils.py
+++ b/lambdaFunctions/getRedditDataFunction/redditUtils.py
@@ -1,52 +1,59 @@
-from datetime import datetime
-from collections import namedtuple
+from datetime import datetime, UTC
+from collections import namedtuple, OrderedDict
 import tableDefinition
 import json
 from decimal import Decimal
 import pickle
+from praw import Reddit
 
-def saveTestReddit(reddit, filename):
-  pickle.dump(reddit, open(filename, 'wb'))
-
-
-def getRedditData(reddit, subreddit, topN=25, view='rising', schema=tableDefinition.schema, time_filter=None, verbose=False):
+def get_reddit_data(
+    reddit: Reddit,
+    subreddit: str,
+    top_n: int = 25,
+    view: str = 'rising',
+    schema: OrderedDict = tableDefinition.schema,
+    time_filter: str | None = None,
+    verbose: bool = False
+):
   """
   Uses PRAW to get data from reddit using defined parameters. Returns data in a list of row based data.
 
   :param reddit: PRAW reddit object
   :param subreddit: subreddit name
-  :param topN: Number of posts to return
+  :param top_n: Number of posts to return
   :param view: view to look at the subreddit. rising, top, hot
   :param schema: schema that describes the data. Dynamo is technically schema-less
   :param time_filter: range of time to look at the data. all, day, hour, month, week, year
   :param verbose: if True then prints more information
   :return: list[Row[schema]], Row is a namedtuple defined by the schema
   """
-  assert topN <= 25  # some, like rising, cap out at 25 and this also is to limit data you're working with
+  assert top_n <= 25  # some, like rising, cap out at 25 and this also is to limit data you're working with
   assert view in {'rising', 'top', 'hot'}
-  topN += 2  # increment by 2 because of sticky posts
+  top_n += 2  # increment by 2 because of sticky posts
   if view == 'top':
     assert time_filter in {"all", "day", "hour", "month", "week", "year"}
 
-  subredditObject = reddit.subreddit(subreddit)
+  subreddit_object = reddit.subreddit(subreddit)
+  top_n_posts = None
   if view == 'rising':
-    topNposts = subredditObject.rising(limit=topN)
+    top_n_posts = subreddit_object.rising(limit=top_n)
   elif view == 'hot':
-    topNposts = subredditObject.hot(limit=topN)
+    top_n_posts = subreddit_object.hot(limit=top_n)
   elif view == 'top':
-    topNposts = subredditObject.top(time_filter=time_filter, limit=topN)
+    top_n_posts = subreddit_object.top(time_filter=time_filter, limit=top_n)
 
-  now = datetime.utcnow().replace(tzinfo=None, microsecond=0)
-  columns = schema.keys()
-  Row = namedtuple("Row", columns)
-  dataCollected = []
-  subscribers = subredditObject.subscribers
-  activeUsers = subredditObject.accounts_active
-  print(f'\tSubscribers: {subscribers}\n\tActive users: {activeUsers}')
-  for submission in topNposts:
+  now = datetime.now(UTC).replace(tzinfo=UTC, microsecond=0)
+  columns = list(schema.keys())
+  Row = namedtuple(typename="Row", field_names=columns)
+  data_collected = []
+  subscribers = subreddit_object.subscribers
+  active_users = subreddit_object.accounts_active
+  print(f'\tSubscribers: {subscribers}\n\tActive users: {active_users}')
+
+  for submission in top_n_posts:
     if submission.stickied:
       continue  # skip stickied posts
-    createdTSUTC = datetime.utcfromtimestamp(submission.created_utc)
+    createdTSUTC = datetime.fromtimestamp(submission.created_utc, UTC)
     timeSincePost = now - createdTSUTC
     timeElapsedMin = timeSincePost.seconds // 60
     timeElapsedDays = timeSincePost.days
@@ -60,19 +67,19 @@ def getRedditData(reddit, subreddit, topN=25, view='rising', schema=tableDefinit
     gildings = submission.gildings
     numGildings = sum(gildings.values())
     row = Row(
-      postId=postId, subreddit=subreddit, subscribers=subscribers, activeUsers=activeUsers,
+      postId=postId, subreddit=subreddit, subscribers=subscribers, activeUsers=active_users,
       title=title, createdTSUTC=str(createdTSUTC),
       timeElapsedMin=timeElapsedMin, score=score, numComments=numComments,
       upvoteRatio=upvoteRatio, numGildings=numGildings,
       loadTSUTC=str(now), loadDateUTC=str(now.date()), loadTimeUTC=str(now.time()))
-    dataCollected.append(row)
+    data_collected.append(row)
     if verbose:
       print(row)
       print()
-  return dataCollected[:topN-2]
+  return data_collected
 
 
-def deduplicateRedditData(data):
+def deduplicate_reddit_data(data):
   """
   Deduplicates the reddit data. Sometimes there are duplicate keys which throws an error when writing to dynamo.
   It is unclear why this happens but I suspect it is an issue with PRAW.
@@ -92,7 +99,7 @@ def deduplicateRedditData(data):
   return newData
 
 
-def getTable(tableName, dynamodb_resource):
+def get_table(tableName, dynamodb_resource):
   table = dynamodb_resource.Table(tableName)
 
   # Print out some data about the table.
@@ -100,7 +107,7 @@ def getTable(tableName, dynamodb_resource):
   return table
 
 
-def batchWriter(table, data, schema):
+def batch_writer(table, data, schema):
   """
   https://boto3.amazonaws.com/v1/documentation/api/latest/guide/dynamodb.html#batch-writing
   I didn't bother with dealing with duplicates because shouldn't be a problem with this type of data
diff --git a/lambdaFunctions/getRedditDataFunction/test_lambda.py b/lambdaFunctions/getRedditDataFunction/test_lambda.py
index b618aa2..2791da8 100644
--- a/lambdaFunctions/getRedditDataFunction/test_lambda.py
+++ b/lambdaFunctions/getRedditDataFunction/test_lambda.py
@@ -1,3 +1,4 @@
+from datetime import datetime, UTC, timedelta
 import pytest
 import redditUtils as ru
 import praw
@@ -9,51 +10,93 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.join(THIS_DIR, '../../'))
 import viral_reddit_posts_utils.configUtils as cu
-import pickle
 from moto import mock_dynamodb
+from unittest.mock import patch, Mock
+from dataclasses import dataclass
 
 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
+HIT_REDDIT = False  # set this to true if you want to test on realtime reddit data
+PATH_OF_THIS_FILE = os.path.dirname(os.path.abspath(__file__))
 
 
 @pytest.fixture(scope='module')
-def cfg():
-  cfg_file = cu.findConfig()
-  cfg = cu.parseConfig(cfg_file)
-  return cfg
-
-
-@pytest.fixture(scope='module')
-def reddit(cfg):
-  if IN_GITHUB_ACTIONS:
-    pass
-  redditcfg = cfg['reddit_api']
-  return praw.Reddit(
-    client_id=f"{redditcfg['CLIENTID']}",
-    client_secret=f"{redditcfg['CLIENTSECRET']}",
-    password=f"{redditcfg['PASSWORD']}",
-    user_agent=f"Post Extraction (by u/{redditcfg['USERNAME']})",
-    username=f"{redditcfg['USERNAME']}",
-  )
-
-
-def test_getRedditData(reddit):
+def reddit() -> praw.Reddit:
+  if not HIT_REDDIT:  # just load the fake config
+    cfg_file = os.path.join(PATH_OF_THIS_FILE, "../../example_reddit.cfg")
+  else:  # Try to get the real config file
+    try:
+      cfg_file = cu.findConfig()
+    except RuntimeError as e:
+      print(e)
+      cfg_file = os.path.join(PATH_OF_THIS_FILE, "../../example_reddit.cfg")
+  print(f"{cfg_file=}")
+  cfg = cu.parseConfig(cfg_file)
+  reddit_cfg = cfg['reddit_api']
+  return praw.Reddit(
+    client_id=f"{reddit_cfg['CLIENTID']}",
+    client_secret=f"{reddit_cfg['CLIENTSECRET']}",
+    password=f"{reddit_cfg['PASSWORD']}",
+    user_agent=f"Post Extraction (by u/{reddit_cfg['USERNAME']})",
+    username=f"{reddit_cfg['USERNAME']}",
+  )
+
+
+@dataclass
+class SubredditSample():
+  subscribers = 10000000
+  accounts_active = 10000
+
+  @dataclass
+  class SampleRisingSubmission():
+    # set created time 15 min before now so it gets filtered into selected data
+    created_utc = int((datetime.now(UTC).replace(tzinfo=UTC, microsecond=0) - timedelta(minutes=15)).timestamp())
+    stickied = False
+    id = '1c3dwli'
+    title = 'My son and my ferret. 😂'
+    score = 28
+    num_comments = 1
+    upvote_ratio = 0.86
+    gildings = {}
+
+  @staticmethod
+  def rising_generator():
+    yield SubredditSample.SampleRisingSubmission
+
+  @staticmethod
+  def rising(limit):
+    generator = SubredditSample.rising_generator()
+    return generator
+
+
+@patch(target="praw.models.helpers.SubredditHelper.__call__", return_value=SubredditSample)
+def test_get_reddit_data(
+    mock_subreddit: Mock,
+    reddit: praw.Reddit
+):
   subreddit = "pics"
-  ru.getRedditData(
+  data_collected = ru.get_reddit_data(
     reddit,
     subreddit,
-    topN=25,
+    top_n=25,
     view='rising',
     schema=tableDefinition.schema,
     time_filter=None,
-    verbose=True)
+    verbose=True
+  )
+  if not HIT_REDDIT:  # fake data
+    row = data_collected[0]
+    assert row.subscribers == 10000000
+    assert row.activeUsers == 10000
+    assert row.title == 'My son and my ferret. 😂'
+    assert row.postId == '1c3dwli'
 
 
 @pytest.fixture(scope='module')
-def duplicatedData():
+def duplicated_data():
   schema = tableDefinition.schema
-  columns = schema.keys()
-  Row = namedtuple("Row", columns)
+  columns = list(schema.keys())
+  Row = namedtuple(typename="Row", field_names=columns)
   # these are identical examples except one has a later loadTSUTC
   return [
     Row(subscribers=10000000, activeUsers=10000,
@@ -67,17 +110,15 @@ def duplicatedData():
   ]
 
 
-def test_deduplicateRedditData(duplicatedData):
-  newData = ru.deduplicateRedditData(duplicatedData)
-  assert len(newData) == 1
+def test_deduplicate_reddit_data(duplicated_data):
+  new_data = ru.deduplicate_reddit_data(duplicated_data)
+  assert len(new_data) == 1
   print("test_deduplicateRedditData complete")
 
 
 @mock_dynamodb
 class TestBatchWriter:
-
-
-  def classSetUp(self):
+  def class_set_up(self):
     """
     If we left this at top level of the class then it won't be skipped by `skip` and `skipif`
     furthermore we can't have __init__ in a Test Class, so this is called prior to each test
@@ -85,16 +126,16 @@ class TestBatchWriter:
     """
     dynamodb = boto3.resource('dynamodb', region_name='us-east-2')
     # create table and write to sample data
-    tableName = 'rising'
-    td = tableDefinition.getTableDefinition(tableName=tableName)
+    table_name = 'rising'
+    td = tableDefinition.getTableDefinition(tableName=table_name)
     self.testTable = dynamodb.create_table(**td)
     self.schema = tableDefinition.schema
     self.columns = self.schema.keys()
-    self.Row = namedtuple("Row", self.columns)
+    self.Row = namedtuple(typename="Row", field_names=self.columns)
 
   @pytest.mark.xfail(reason="BatchWriter fails on duplicate keys. This might xpass, possibly a fault in mock object.")
-  def test_duplicateData(self):
-    self.classSetUp()
+  def test_duplicate_data(self):
+    self.class_set_up()
     testTable = self.testTable
     schema = self.schema
     Row = self.Row
@@ -107,12 +148,11 @@ class TestBatchWriter:
         subreddit='pics', title='Magnolia tree blooming in my friends yard', createdTSUTC='2023-04-30 04:19:43',
         timeElapsedMin=44, score=3, numComments=0, upvoteRatio=1.0, numGildings=0)
     ]
-    from redditUtils import batchWriter
-    batchWriter(table=testTable, data=data, schema=schema)
+    ru.batch_writer(table=testTable, data=data, schema=schema)
     print("duplicateDataTester test complete")
 
-  def test_uniqueData(self):
-    self.classSetUp()
+  def test_unique_data(self):
+    self.class_set_up()
     testTable = self.testTable
     schema = self.schema
     Row = self.Row
@@ -127,13 +167,12 @@ class TestBatchWriter:
        subreddit='pics', title='A piece of wood sticking up in front of a fire.', createdTSUTC='2023-04-30 04:29:23',
        timeElapsedMin=34, score=0, numComments=0, upvoteRatio=0.4, numGildings=0)
     ]
-    from redditUtils import batchWriter
-    batchWriter(table=testTable, data=data, schema=schema)
+    ru.batch_writer(table=testTable, data=data, schema=schema)
     print("uniqueDataTester test complete")
 
-  def test_diffPrimaryIndexSameSecondIndex(self):
-    self.classSetUp()
-    testTable = self.testTable
+  def test_diff_primary_index_same_second_index(self):
+    self.class_set_up()
+    test_table = self.testTable
     schema = self.schema
     Row = self.Row
@@ -147,6 +186,6 @@ class TestBatchWriter:
        subreddit='pics', title='Magnolia tree blooming in my friends yard', createdTSUTC='2023-04-30 04:19:43',
        timeElapsedMin=44, score=3, numComments=0, upvoteRatio=1.0, numGildings=0)
     ]
-    from redditUtils import batchWriter
-    batchWriter(table=testTable, data=data, schema=schema)
+
+    ru.batch_writer(table=test_table, data=data, schema=schema)
    print("diffPrimaryIndexSameSecondIndexTester test complete")