improved testing
kennethjmyers committed Apr 14, 2024
1 parent 9048982 commit c493873
Showing 3 changed files with 133 additions and 87 deletions.
16 changes: 8 additions & 8 deletions lambdaFunctions/getRedditDataFunction/lambda_function.py
@@ -37,25 +37,25 @@ def lambda_handler(event, context):
schema = tableDefinition.schema
topN = 25
view = 'rising'
-risingData = ru.getRedditData(reddit=reddit, subreddit=subreddit, view=view, schema=schema, topN=topN)
-risingData = ru.deduplicateRedditData(risingData)
+risingData = ru.get_reddit_data(reddit=reddit, subreddit=subreddit, view=view, schema=schema, top_n=topN)
+risingData = ru.deduplicate_reddit_data(risingData)

# Push to DynamoDB
tableName = f"{view}-{os.environ['ENV']}"
-risingTable = ru.getTable(tableName, dynamodb_resource)
-ru.batchWriter(risingTable, risingData, schema)
+risingTable = ru.get_table(tableName, dynamodb_resource)
+ru.batch_writer(risingTable, risingData, schema)

# Get Hot Reddit data
print("\tGetting Hot Data")
schema = tableDefinition.schema
topN = 3
view = 'hot'
-hotData = ru.getRedditData(reddit=reddit, subreddit=subreddit, view=view, schema=schema, topN=topN)
-hotData = ru.deduplicateRedditData(hotData)
+hotData = ru.get_reddit_data(reddit=reddit, subreddit=subreddit, view=view, schema=schema, top_n=topN)
+hotData = ru.deduplicate_reddit_data(hotData)

# Push to DynamoDB
tableName = f"{view}-{os.environ['ENV']}"
-hotTable = ru.getTable(tableName, dynamodb_resource)
-ru.batchWriter(hotTable, hotData, schema)
+hotTable = ru.get_table(tableName, dynamodb_resource)
+ru.batch_writer(hotTable, hotData, schema)

return 200
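
Note on the renamed helpers: a quick local smoke test of this handler might look like the sketch below. This is a hypothetical invocation, not part of the commit; it assumes valid PRAW credentials, reachable DynamoDB tables, and that ENV (used in the table-name f-strings above) is the only environment variable required.

    # Hypothetical local smoke test (assumption: credentials and tables exist;
    # table names resolve like "rising-dev" / "hot-dev" per the f-strings above).
    import os
    os.environ["ENV"] = "dev"

    from lambda_function import lambda_handler

    status = lambda_handler(event={}, context=None)  # both unused in the lines shown
    assert status == 200
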
65 changes: 36 additions & 29 deletions lambdaFunctions/getRedditDataFunction/redditUtils.py
@@ -1,52 +1,59 @@
-from datetime import datetime
-from collections import namedtuple
+from datetime import datetime, UTC
+from collections import namedtuple, OrderedDict
import tableDefinition
import json
from decimal import Decimal
import pickle
+from praw import Reddit


def saveTestReddit(reddit, filename):
pickle.dump(reddit, open(filename, 'wb'))


-def getRedditData(reddit, subreddit, topN=25, view='rising', schema=tableDefinition.schema, time_filter=None, verbose=False):
+def get_reddit_data(
+  reddit: Reddit,
+  subreddit: str,
+  top_n: int = 25,
+  view: str = 'rising',
+  schema: OrderedDict = tableDefinition.schema,
+  time_filter: str | None = None,
+  verbose: bool = False
+):
"""
Uses PRAW to get data from reddit using defined parameters. Returns data in a list of row based data.
:param reddit: PRAW reddit object
:param subreddit: subreddit name
-:param topN: Number of posts to return
+:param top_n: Number of posts to return
:param view: view to look at the subreddit. rising, top, hot
:param schema: schema that describes the data. Dynamo is technically schema-less
:param time_filter: range of time to look at the data. all, day, hour, month, week, year
:param verbose: if True then prints more information
:return: list[Row[schema]], Row is a namedtuple defined by the schema
"""
-assert topN <= 25 # some, like rising, cap out at 25 and this also is to limit data you're working with
+assert top_n <= 25 # some, like rising, cap out at 25 and this also is to limit data you're working with
assert view in {'rising', 'top' , 'hot'}
-topN += 2 # increment by 2 because of sticky posts
+top_n += 2 # increment by 2 because of sticky posts
if view == 'top':
assert time_filter in {"all", "day", "hour", "month", "week", "year"}
-subredditObject = reddit.subreddit(subreddit)
+subreddit_object = reddit.subreddit(subreddit)
+top_n_posts = None
if view == 'rising':
-topNposts = subredditObject.rising(limit=topN)
+top_n_posts = subreddit_object.rising(limit=top_n)
elif view == 'hot':
-topNposts = subredditObject.hot(limit=topN)
+top_n_posts = subreddit_object.hot(limit=top_n)
elif view == 'top':
-topNposts = subredditObject.top(time_filter=time_filter, limit=topN)
+top_n_posts = subreddit_object.top(time_filter=time_filter, limit=top_n)

-now = datetime.utcnow().replace(tzinfo=None, microsecond=0)
-columns = schema.keys()
-Row = namedtuple("Row", columns)
-dataCollected = []
-subscribers = subredditObject.subscribers
-activeUsers = subredditObject.accounts_active
-print(f'\tSubscribers: {subscribers}\n\tActive users: {activeUsers}')
-for submission in topNposts:
+now = datetime.now(UTC).replace(tzinfo=UTC, microsecond=0)
+columns = list(schema.keys())
+Row = namedtuple(typename="Row", field_names=columns)
+data_collected = []
+subscribers = subreddit_object.subscribers
+active_users = subreddit_object.accounts_active
+print(f'\tSubscribers: {subscribers}\n\tActive users: {active_users}')
+
+for submission in top_n_posts:
if submission.stickied:
continue # skip stickied posts
-createdTSUTC = datetime.utcfromtimestamp(submission.created_utc)
+createdTSUTC = datetime.fromtimestamp(submission.created_utc, UTC)
timeSincePost = now - createdTSUTC
timeElapsedMin = timeSincePost.seconds // 60
timeElapsedDays = timeSincePost.days
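
Note on the timestamp change above: datetime.utcnow() and datetime.utcfromtimestamp() return naive datetimes and are deprecated as of Python 3.12; datetime.now(UTC) and datetime.fromtimestamp(ts, UTC) return timezone-aware ones (datetime.UTC itself needs Python 3.11+). A minimal sketch of the difference:

    # Timezone-aware UTC datetimes, mirroring the arithmetic in get_reddit_data.
    from datetime import datetime, UTC

    now = datetime.now(UTC).replace(microsecond=0)      # tzinfo is UTC
    post_ts = datetime.fromtimestamp(1713100000, UTC)   # aware; replaces utcfromtimestamp()
    elapsed_min = (now - post_ts).seconds // 60         # same arithmetic as timeElapsedMin
    # Mixing in a naive datetime.utcnow() here would raise TypeError on subtraction.
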
@@ -60,19 +67,19 @@ def getRedditData(reddit, subreddit, topN=25, view='rising', schema=tableDefinit
gildings = submission.gildings
numGildings = sum(gildings.values())
row = Row(
-postId=postId, subreddit=subreddit, subscribers=subscribers, activeUsers=activeUsers,
+postId=postId, subreddit=subreddit, subscribers=subscribers, activeUsers=active_users,
title=title, createdTSUTC=str(createdTSUTC),
timeElapsedMin=timeElapsedMin, score=score, numComments=numComments,
upvoteRatio=upvoteRatio, numGildings=numGildings,
loadTSUTC=str(now), loadDateUTC=str(now.date()), loadTimeUTC=str(now.time()))
-dataCollected.append(row)
+data_collected.append(row)
if verbose:
print(row)
print()
-return dataCollected[:topN-2]
+return data_collected


-def deduplicateRedditData(data):
+def deduplicate_reddit_data(data):
"""
Deduplicates the reddit data. Sometimes there are duplicate keys which throws an error
when writing to dynamo. It is unclear why this happens but I suspect it is an issue with PRAW.
@@ -92,15 +99,15 @@ def deduplicateRedditData(data):
return newData


-def getTable(tableName, dynamodb_resource):
+def get_table(tableName, dynamodb_resource):
table = dynamodb_resource.Table(tableName)

# Print out some data about the table.
print(f"Item count in table: {table.item_count}") # this only updates every 6 hours
return table


-def batchWriter(table, data, schema):
+def batch_writer(table, data, schema):
"""
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/dynamodb.html#batch-writing
I didn't bother with dealing with duplicates because shouldn't be a problem with this type of data
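
Note: the bodies of deduplicate_reddit_data and batch_writer are collapsed in this diff. From the docstring above and test_deduplicate_reddit_data below (two rows that are identical except for a later loadTSUTC collapse to one), a sketch consistent with that behavior could be the following; the actual implementation may differ.

    # Hedged sketch only; the real body is collapsed in this diff.
    # Assumption: keep the row with the latest loadTSUTC per postId.
    def deduplicate_reddit_data_sketch(data):
        latest = {}
        for row in data:
            # loadTSUTC strings ("YYYY-MM-DD HH:MM:SS") sort chronologically
            if row.postId not in latest or row.loadTSUTC > latest[row.postId].loadTSUTC:
                latest[row.postId] = row
        return list(latest.values())
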
139 changes: 89 additions & 50 deletions lambdaFunctions/getRedditDataFunction/test_lambda.py
@@ -1,3 +1,4 @@
+from datetime import datetime, UTC, timedelta
import pytest
import redditUtils as ru
import praw
@@ -9,51 +10,93 @@
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(THIS_DIR, '../../'))
import viral_reddit_posts_utils.configUtils as cu
import pickle
from moto import mock_dynamodb
+from unittest.mock import patch, Mock
+from dataclasses import dataclass


IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
+HIT_REDDIT = False # set this to true if you want to test on realtime reddit data
+PATH_OF_THIS_FILE = os.path.dirname(os.path.abspath(__file__))


-@pytest.fixture(scope='module')
-def cfg():
-  cfg_file = cu.findConfig()
-  cfg = cu.parseConfig(cfg_file)
-  return cfg


@pytest.fixture(scope='module')
-def reddit(cfg):
-  if IN_GITHUB_ACTIONS:
-    pass
-  redditcfg = cfg['reddit_api']
-  return praw.Reddit(
-    client_id=f"{redditcfg['CLIENTID']}",
-    client_secret=f"{redditcfg['CLIENTSECRET']}",
-    password=f"{redditcfg['PASSWORD']}",
-    user_agent=f"Post Extraction (by u/{redditcfg['USERNAME']})",
-    username=f"{redditcfg['USERNAME']}",
-  )


-def test_getRedditData(reddit):
+def reddit() -> praw.Reddit:
+  if not HIT_REDDIT: # just load the fake config
+    cfg_file = os.path.join(PATH_OF_THIS_FILE, "../../example_reddit.cfg")
+  else: # Try to get the real config file
+    try:
+      cfg_file = cu.findConfig()
+    except RuntimeError as e:
+      print(e)
+      cfg_file = os.path.join(PATH_OF_THIS_FILE, "../../example_reddit.cfg")
+  print(f"{cfg_file=}")
+  cfg = cu.parseConfig(cfg_file)
+  reddit_cfg = cfg['reddit_api']
+  return praw.Reddit(
+    client_id=f"{reddit_cfg['CLIENTID']}",
+    client_secret=f"{reddit_cfg['CLIENTSECRET']}",
+    password=f"{reddit_cfg['PASSWORD']}",
+    user_agent=f"Post Extraction (by u/{reddit_cfg['USERNAME']})",
+    username=f"{reddit_cfg['USERNAME']}",
+  )


+@dataclass
+class SubredditSample():
+  subscribers = 10000000
+  accounts_active = 10000
+
+  @dataclass
+  class SampleRisingSubmission():
+    # set created time 15 min before now so it gets filtered into selected data
+    created_utc = int((datetime.now(UTC).replace(tzinfo=UTC, microsecond=0) - timedelta(minutes=15)).timestamp())
+    stickied = False
+    id = '1c3dwli'
+    title = 'My son and my ferret. 😂'
+    score = 28
+    num_comments = 1
+    upvote_ratio = 0.86
+    gildings = {}
+
+  @staticmethod
+  def rising_generator():
+    yield SubredditSample.SampleRisingSubmission
+
+  @staticmethod
+  def rising(limit):
+    generator = SubredditSample.rising_generator()
+    return generator


+@patch(target="praw.models.helpers.SubredditHelper.__call__", return_value = SubredditSample)
+def test_get_reddit_data(
+  mock_subreddit:Mock,
+  reddit: praw.Reddit
+):
subreddit = "pics"
-ru.getRedditData(
+data_collected = ru.get_reddit_data(
reddit,
subreddit,
-topN=25,
+top_n=25,
view='rising',
schema=tableDefinition.schema,
time_filter=None,
-verbose=True)
+verbose=True
+)
+if not HIT_REDDIT: # fake data
+  row = data_collected[0]
+  assert row.subscribers == 10000000
+  assert row.activeUsers == 10000
+  assert row.title == 'My son and my ferret. 😂'
+  assert row.postId == '1c3dwli'


@pytest.fixture(scope='module')
-def duplicatedData():
+def duplicated_data():
schema = tableDefinition.schema
-columns = schema.keys()
-Row = namedtuple("Row", columns)
+columns = list(schema.keys())
+Row = namedtuple(typename="Row", field_names=columns)
# these are identical examples except one has a later loadTSUTC
return [
Row(subscribers=10000000, activeUsers=10000,
@@ -67,34 +110,32 @@ def duplicatedData():
]


-def test_deduplicateRedditData(duplicatedData):
-  newData = ru.deduplicateRedditData(duplicatedData)
-  assert len(newData) == 1
+def test_deduplicate_reddit_data(duplicated_data):
+  new_data = ru.deduplicate_reddit_data(duplicated_data)
+  assert len(new_data) == 1
print("test_deduplicateRedditData complete")


@mock_dynamodb
class TestBatchWriter:


-def classSetUp(self):
+def class_set_up(self):
"""
If we left this at top level of the class then it won't be skipped by `skip` and `skipif`
furthermore we can't have __init__ in a Test Class, so this is called prior to each test
:return:
"""
dynamodb = boto3.resource('dynamodb', region_name='us-east-2')
# create table and write to sample data
-tableName = 'rising'
-td = tableDefinition.getTableDefinition(tableName=tableName)
+table_name = 'rising'
+td = tableDefinition.getTableDefinition(tableName=table_name)
self.testTable = dynamodb.create_table(**td)
self.schema = tableDefinition.schema
self.columns = self.schema.keys()
-self.Row = namedtuple("Row", self.columns)
+self.Row = namedtuple(typename="Row", field_names=self.columns)

@pytest.mark.xfail(reason="BatchWriter fails on duplicate keys. This might xpass, possibly a fault in mock object.")
-def test_duplicateData(self):
-  self.classSetUp()
+def test_duplicate_data(self):
+  self.class_set_up()
testTable = self.testTable
schema = self.schema
Row=self.Row
@@ -107,12 +148,11 @@ def test_duplicateData(self):
subreddit='pics', title='Magnolia tree blooming in my friends yard', createdTSUTC='2023-04-30 04:19:43',
timeElapsedMin=44, score=3, numComments=0, upvoteRatio=1.0, numGildings=0)
]
-from redditUtils import batchWriter
-batchWriter(table=testTable, data=data, schema=schema)
+ru.batch_writer(table=testTable, data=data, schema=schema)
print("duplicateDataTester test complete")

-def test_uniqueData(self):
-  self.classSetUp()
+def test_unique_data(self):
+  self.class_set_up()
testTable = self.testTable
schema = self.schema
Row = self.Row
@@ -127,13 +167,12 @@ def test_uniqueData(self):
subreddit='pics', title='A piece of wood sticking up in front of a fire.', createdTSUTC='2023-04-30 04:29:23',
timeElapsedMin=34, score=0, numComments=0, upvoteRatio=0.4, numGildings=0)
]
-from redditUtils import batchWriter
-batchWriter(table=testTable, data=data, schema=schema)
+ru.batch_writer(table=testTable, data=data, schema=schema)
print("uniqueDataTester test complete")

-def test_diffPrimaryIndexSameSecondIndex(self):
-  self.classSetUp()
-  testTable = self.testTable
+def test_diff_primary_index_same_second_index(self):
+  self.class_set_up()
+  test_table = self.testTable
schema = self.schema
Row = self.Row

@@ -147,6 +186,6 @@ def test_diffPrimaryIndexSameSecondIndex(self):
subreddit='pics', title='Magnolia tree blooming in my friends yard', createdTSUTC='2023-04-30 04:19:43',
timeElapsedMin=44, score=3, numComments=0, upvoteRatio=1.0, numGildings=0)
]
-from redditUtils import batchWriter
-batchWriter(table=testTable, data=data, schema=schema)
+ru.batch_writer(table=test_table, data=data, schema=schema)
print("diffPrimaryIndexSameSecondIndexTester test complete")
