fixing tests, updating readme, updating package setup
kennethjmyers committed Apr 25, 2024
1 parent dac8053 commit a992081
Showing 7 changed files with 48 additions and 37 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
![Python](https://img.shields.io/badge/python-3.8.15-blue.svg)
![Python](https://img.shields.io/badge/python-3.12.3-blue.svg)

# Viral Reddit Posts Model

@@ -15,8 +15,9 @@ The purpose of this repo is to:
2. Installs - see the [prerequisites section on this page](https://developer.hashicorp.com/terraform/tutorials/aws-get-started/aws-build#prerequisites) for additional information, the steps are essentially:
    1. Install Terraform CLI
    2. Install AWS CLI and run `aws configure` and enter in your aws credentials.
    3. JDK 17 installed (8, 11 or 17 are compatible with spark 3.4.0)
3. Clone this repository
4. You can run the tests locally yourself by doing the following (it is recommended that you manage your python environments with something like [asdf](https://asdf-vm.com/) and use python==3.8.15 as your local runtime):
4. You can run the tests locally yourself by doing the following (it is recommended that you manage your python environments with something like [asdf](https://asdf-vm.com/) and use python==3.12.3 as your local runtime):

```sh
python -m venv venv # this sets up a local virtual env using the current python runtime
```
6 changes: 3 additions & 3 deletions model/ModelETL.py
@@ -12,15 +12,15 @@
import os
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(THIS_DIR, '../'))
import viral_reddit_posts_utils.configUtils as cu
import viral_reddit_posts_utils.config_utils as cu


# Forcing Timezone keeps things consistent with running on aws and without it timestamps get additional
# timezone conversions when writing to parquet. Setting spark timezone was not enough to fix this
os.environ['TZ'] = 'UTC'

cfg_file = cu.findConfig()
cfg = cu.parseConfig(cfg_file)
cfg_file = cu.find_config()
cfg = cu.parse_config(cfg_file)

spark = (
SparkSession
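The hunk above is cut off before the builder options. As a reference for the timezone comment, here is a minimal sketch (not the repo's exact builder) of a SparkSession pinned to UTC at both the OS and the Spark SQL session level:

```python
# Minimal sketch, not the repo's exact builder: force UTC before Spark starts
# and again at the Spark SQL session level.
import os
from pyspark.sql import SparkSession

os.environ['TZ'] = 'UTC'  # OS-level timezone, keeps parquet timestamps consistent with AWS

spark = (
    SparkSession
    .builder
    .appName("ModelETL")  # hypothetical app name
    .config("spark.sql.session.timeZone", "UTC")  # per the comment above, this alone was not enough
    .getOrCreate()
)
```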
20 changes: 15 additions & 5 deletions model/PredictETL.py
@@ -6,12 +6,13 @@
from datetime import datetime, timedelta
from boto3.dynamodb.conditions import Key, Attr
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pandas as pd
import sqlUtils as su
import sys
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(THIS_DIR, '../'))
import viral_reddit_posts_utils.configUtils as cu
import viral_reddit_posts_utils.config_utils as cu


os.environ['TZ'] = 'UTC'
@@ -61,6 +62,15 @@ def extract(self):

self.postIdData = modelUtils.getPostIdSparkDataFrame(self.spark, risingTable, postsOfInterest, chunkSize=100)

# type issue https://stackoverflow.com/questions/76072664/convert-pyspark-dataframe-to-pandas-dataframe-fails-on-timestamp-column
self.postIdData = (
self.postIdData
.withColumn("loadDateUTC", F.date_format("loadDateUTC", "yyyy-MM-dd"))
.withColumn("loadTimeUTC", F.date_format("loadTimeUTC", "HH:mm:ss"))
.withColumn("loadTSUTC", F.date_format("loadTSUTC", "yyyy-MM-dd HH:mm:ss"))
.withColumn("createdTSUTC", F.date_format("createdTSUTC", "yyyy-MM-dd HH:mm:ss"))
)

pandasTestDf = self.postIdData.limit(5).toPandas()
print(pandasTestDf.to_string())
print("Finished gathering Rising Data.")
@@ -116,7 +126,7 @@ def load(self, data, tableName):
engine = self.engine
data = data.set_index(['postId'])
with engine.connect() as conn:
result = su.upsert_df(df=data, table_name=tableName, engine=conn)
result = su.upsert_df(df=data, table_name=tableName, engine=conn.connection)
print("Finished writing to postgres")
return

@@ -147,7 +157,7 @@ def filterPreviousViralData(self, data):
postIds = list(data['postId'])
sql = f"""select "postId", "stepUp", 1 as "matchFound" from public."scoredData" where "postId" in ('{"','".join(postIds)}') and "stepUp" = 1"""
with engine.connect() as conn:
result = pd.read_sql(sql=sql, con=conn)
result = pd.read_sql(sql=sql, con=conn.connection)
# join data together
joinedData = pd.merge(data, result, on=['postId', 'stepUp'], how='left')
# filter out where match found
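Both changes in this file pass `conn.connection` — the raw DBAPI connection underneath the SQLAlchemy `Connection` — to pandas and the upsert helper. A minimal sketch of that pattern with a placeholder DSN (credentials and database name are hypothetical):

```python
# Minimal sketch (placeholder DSN): engine.connect() yields a SQLAlchemy Connection;
# .connection exposes the underlying DBAPI connection handed to pd.read_sql here.
import pandas as pd
from sqlalchemy import create_engine

engine = create_engine("postgresql+pg8000://user:password@localhost:5432/reddit")  # hypothetical
with engine.connect() as conn:
    raw_conn = conn.connection  # raw DBAPI-level connection
    result = pd.read_sql(sql='select 1 as "ok"', con=raw_conn)
```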
@@ -199,9 +209,9 @@ def notifyUserAboutViralPosts(self, viralData):
threshold = 0.29412 # eventually will probably put this in its own config file, maybe it differs per subreddit
# modelName = 'models/Reddit_model_20230503-235329_GBM.sav'

# cfg_file = cu.findConfig()
# cfg_file = cu.find_config()
cfg_file = 's3://data-kennethmyers/reddit.cfg'
cfg = cu.parseConfig(cfg_file)
cfg = cu.parse_config(cfg_file)

spark = (
SparkSession
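Across the files in this commit the renamed config helpers are called exactly as the old camelCase ones were. A minimal sketch of the pattern (the helpers come from the `viral_reddit_posts_utils` package; the S3 URI is the one hard-coded above):

```python
# Minimal sketch of the renamed config helpers as used throughout this commit.
import viral_reddit_posts_utils.config_utils as cu

cfg_file = cu.find_config()                       # locate a local reddit.cfg
# cfg_file = 's3://data-kennethmyers/reddit.cfg'  # or point at the S3 copy, as in __main__ above
cfg = cu.parse_config(cfg_file)
```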
6 changes: 3 additions & 3 deletions model/test_discordUtils.py
@@ -4,14 +4,14 @@
import os
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(THIS_DIR, '../'))
import viral_reddit_posts_utils.configUtils as cu
import viral_reddit_posts_utils.config_utils as cu
import responses


@pytest.fixture(scope='module')
def cfg():
cfg_file = cu.findConfig()
cfg = cu.parseConfig(cfg_file)
cfg_file = cu.find_config()
cfg = cu.parse_config(cfg_file)
return cfg


29 changes: 14 additions & 15 deletions model/test_model.py
@@ -5,15 +5,12 @@
import os
from datetime import datetime, timedelta
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import schema
import boto3
from moto import mock_dynamodb
import sys
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(THIS_DIR, '../lambdaFunctions/getRedditDataFunction/'))
sys.path.append(os.path.join(THIS_DIR, '../'))
import viral_reddit_posts_utils.configUtils as cu
import tableDefinition
import viral_reddit_posts_utils.config_utils as cu
from lambda_functions.get_reddit_data_function import table_definition
import pandas as pd
import json
from decimal import Decimal
@@ -22,10 +19,10 @@

IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
os.environ['TZ'] = 'UTC'

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

@pytest.fixture(scope='module')
def sampleRisingData():
def sample_rising_data():
d = pd.read_csv(os.path.join(THIS_DIR, 'test_data.csv'))
# we need to change some of the values here so that they can be found by the extract method
# particularly the load dates have to be within the last hour
@@ -49,7 +46,7 @@ def sampleRisingData():
# dynamodb = boto3.resource('dynamodb')
# # create table and write to sample data
# tableName = 'rising'
# td = tableDefinition.getTableDefinition(tableName=tableName)
# td = table_definition.gettable_definition(tableName=tableName)
# table = dynamodb.create_table(**td)
# with table.batch_writer() as batch:
# for item in json.loads(sampleRisingData, parse_float=Decimal): # for each row obtained
@@ -102,8 +99,8 @@ def modelName():

@pytest.fixture(scope='module')
def cfg():
cfg_file = cu.findConfig()
cfg = cu.parseConfig(cfg_file)
cfg_file = cu.find_config()
cfg = cu.parse_config(cfg_file)
return cfg


@@ -119,14 +116,14 @@ def engine(cfg):

# as I mentioned earlier, I'd like to change this so the dynamodb resource is a fixture but it kept throwing errors
@mock_dynamodb
def test_extract(sampleRisingData, cfg, engine, model, modelName, spark, threshold):
def test_extract(sample_rising_data, cfg, engine, model, modelName, spark, threshold):
dynamodb = boto3.resource('dynamodb', region_name='us-east-2')
# create table and write to sample data
tableName = 'rising'
td = tableDefinition.getTableDefinition(tableName=tableName)
td = table_definition.get_table_definition(tableName=tableName)
table = dynamodb.create_table(**td)
with table.batch_writer() as batch:
for item in json.loads(sampleRisingData, parse_float=Decimal): # for each row obtained
for item in json.loads(sample_rising_data, parse_float=Decimal): # for each row obtained
batch.put_item(
Item=item # json.loads(item, parse_float=Decimal) # helps with parsing float to Decimal
)
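For context, a stripped-down, self-contained version of what `test_extract` does under `@mock_dynamodb` (the key schema here is a guess for illustration; the repo builds the real definition via `table_definition.get_table_definition`):

```python
# Minimal sketch with a simplified, hypothetical key schema: moto intercepts the boto3
# calls, so the table is created and written entirely in memory.
import boto3
from moto import mock_dynamodb


@mock_dynamodb
def test_rising_table_roundtrip():
    dynamodb = boto3.resource('dynamodb', region_name='us-east-2')
    table = dynamodb.create_table(
        TableName='rising',
        KeySchema=[{'AttributeName': 'postId', 'KeyType': 'HASH'}],
        AttributeDefinitions=[{'AttributeName': 'postId', 'AttributeType': 'S'}],
        BillingMode='PAY_PER_REQUEST',
    )
    with table.batch_writer() as batch:
        batch.put_item(Item={'postId': 'abc123'})  # placeholder item
    assert table.scan()['Count'] == 1
```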
@@ -208,7 +205,9 @@ def aggDataDf(spark):
'maxScoreGrowth21_40m41_60m': 1.1,
'maxNumCommentsGrowth21_40m41_60m': 0.5,
}]
return spark.createDataFrame(testAggData, schema.aggDataSparkSchema).toPandas()
df = spark.createDataFrame(testAggData, schema.aggDataSparkSchema)
df = df.withColumn("createdTSUTC", F.date_format("createdTSUTC", "yyyy-MM-dd HH:mm:ss"))
return df.toPandas()


def test_createPredictions(aggDataDf, pipeline):
6 changes: 3 additions & 3 deletions patches/rds/newColumns.py
@@ -3,12 +3,12 @@
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(THIS_DIR, '../../'))
sys.path.append(os.path.join(THIS_DIR, '../../model/'))
import viral_reddit_posts_utils.configUtils as cu
import viral_reddit_posts_utils.config_utils as cu
import sqlUtils as su


cfg_file = cu.findConfig()
cfg = cu.parseConfig(cfg_file)
cfg_file = cu.find_config()
cfg = cu.parse_config(cfg_file)
engine = su.makeEngine(cfg)

newColumnsStr = """
13 changes: 7 additions & 6 deletions pyproject.toml
@@ -10,21 +10,22 @@ dynamic = ["version"]

dependencies = [
"boto3==1.26.117",
"matplotlib==3.3.4",
"numpy==1.21.6", # required by pyspark
"pandas==1.3", # 1.3 at least needed for M1 Mac, lower version required by pyspark
"matplotlib==3.8",
"numpy==1.26",
"pandas==2.2.2", # 1.3 at least needed for M1 Mac
"pg8000==1.29.4", # this was easier to pip install than psycopg2
"pyarrow==15.0.2", # don't use low versions which pin lower versions of numpy that break on M1 Mac
"pyspark==3.4.0", # using this version because py37 deprecated in pyspark 3.4.0
"pyspark==3.4.0",
"requests==2.31.0",
"scikit-learn==1.0.2",
"scikit-learn==1.4.2",
"seaborn==0.11.2",
"shap==0.41.0",
"sqlalchemy==1.4.46", # originally tried 2.0.10, but this was incompatible with old versions of pandas https://stackoverflow.com/a/75282604/5034651,
"viral_reddit_posts_utils @ git+https://github.com/ViralRedditPosts/Utils.git@main",
"Reddit-Scraping @ git+https://github.com/ViralRedditPosts/Reddit-Scraping.git@main",
]

requires-python = "== 3.8.15"
requires-python = "== 3.12.3"

authors = [
{name = "Kenneth Myers", email = "[email protected]"},
