From a992081cbf484ded7fdc8508e47e63bbda19ef11 Mon Sep 17 00:00:00 2001
From: Kenneth Myers
Date: Thu, 25 Apr 2024 04:05:29 -0400
Subject: [PATCH] Fix tests, update README, and update package setup

---
 README.md                  |  5 +++--
 model/ModelETL.py          |  6 +++---
 model/PredictETL.py        | 20 +++++++++++++++-----
 model/test_discordUtils.py |  6 +++---
 model/test_model.py        | 29 ++++++++++++++---------------
 patches/rds/newColumns.py  |  6 +++---
 pyproject.toml             | 13 +++++++------
 7 files changed, 48 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index 6f87b7f..50d3f09 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-![Python](https://img.shields.io/badge/python-3.8.15-blue.svg)
+![Python](https://img.shields.io/badge/python-3.12.3-blue.svg)
 
 # Viral Reddit Posts Model
 
@@ -15,8 +15,9 @@ The purpose of this repo is to:
 2. Installs - see the [prerequisites section on this page](https://developer.hashicorp.com/terraform/tutorials/aws-get-started/aws-build#prerequisites) for additional information, the steps are essentially:
    1. Install Terraform CLI
    2. Install AWS CLI and run `aws configure` and enter in your aws credentials.
+   3. Install JDK 17 (JDK 8, 11, or 17 are compatible with Spark 3.4.0).
 3. Clone this repository
-4. You can run the tests locally yourself by doing the following (it is recommended that you manage your python environments with something like [asdf](https://asdf-vm.com/) and use python==3.8.15 as your local runtime):
+4. You can run the tests locally yourself by doing the following (it is recommended that you manage your python environments with something like [asdf](https://asdf-vm.com/) and use python==3.12.3 as your local runtime):
 
 ```sh
 python -m venv venv  # this sets up a local virtual env using the current python runtime
diff --git a/model/ModelETL.py b/model/ModelETL.py
index cc29673..30a1aac 100755
--- a/model/ModelETL.py
+++ b/model/ModelETL.py
@@ -12,15 +12,15 @@
 import os
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.join(THIS_DIR, '../'))
-import viral_reddit_posts_utils.configUtils as cu
+import viral_reddit_posts_utils.config_utils as cu
 
 # Forcing Timezone keeps things consistent with running on aws and without it timestamps get additional
 # timezone conversions when writing to parquet. Setting spark timezone was not enough to fix this
 os.environ['TZ'] = 'UTC'
 
-cfg_file = cu.findConfig()
-cfg = cu.parseConfig(cfg_file)
+cfg_file = cu.find_config()
+cfg = cu.parse_config(cfg_file)
 
 spark = (
   SparkSession
diff --git a/model/PredictETL.py b/model/PredictETL.py
index fd482e1..6a5e9b7 100755
--- a/model/PredictETL.py
+++ b/model/PredictETL.py
@@ -6,12 +6,13 @@
 from datetime import datetime, timedelta
 from boto3.dynamodb.conditions import Key, Attr
 from pyspark.sql import SparkSession
+import pyspark.sql.functions as F
 import pandas as pd
 import sqlUtils as su
 import sys
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.join(THIS_DIR, '../'))
-import viral_reddit_posts_utils.configUtils as cu
+import viral_reddit_posts_utils.config_utils as cu
 
 os.environ['TZ'] = 'UTC'
 
@@ -61,6 +62,15 @@ def extract(self):
     self.postIdData = modelUtils.getPostIdSparkDataFrame(self.spark, risingTable, postsOfInterest, chunkSize=100)
 
+    # type issue https://stackoverflow.com/questions/76072664/convert-pyspark-dataframe-to-pandas-dataframe-fails-on-timestamp-column
+    self.postIdData = (
+      self.postIdData
+      .withColumn("loadDateUTC", F.date_format("loadDateUTC", "yyyy-MM-dd"))
+      .withColumn("loadTimeUTC", F.date_format("loadTimeUTC", "HH:mm:ss"))
+      .withColumn("loadTSUTC", F.date_format("loadTSUTC", "yyyy-MM-dd HH:mm:ss"))
+      .withColumn("createdTSUTC", F.date_format("createdTSUTC", "yyyy-MM-dd HH:mm:ss"))
+    )
+
     pandasTestDf = self.postIdData.limit(5).toPandas()
     print(pandasTestDf.to_string())
     print("Finished gathering Rising Data.")
@@ -116,7 +126,7 @@ def load(self, data, tableName):
     engine = self.engine
     data = data.set_index(['postId'])
     with engine.connect() as conn:
-      result = su.upsert_df(df=data, table_name=tableName, engine=conn)
+      result = su.upsert_df(df=data, table_name=tableName, engine=conn.connection)
     print("Finished writing to postgres")
     return
@@ -147,7 +157,7 @@ def filterPreviousViralData(self, data):
     postIds = list(data['postId'])
     sql = f"""select "postId", "stepUp", 1 as "matchFound" from public."scoredData" where "postId" in ('{"','".join(postIds)}') and "stepUp" = 1"""
     with engine.connect() as conn:
-      result = pd.read_sql(sql=sql, con=conn)
+      result = pd.read_sql(sql=sql, con=conn.connection)
     # join data together
     joinedData = pd.merge(data, result, on=['postId', 'stepUp'], how='left')
     # filter out where match found
@@ -199,9 +209,9 @@ def notifyUserAboutViralPosts(self, viralData):
   threshold = 0.29412  # eventually will probably put this in its own config file, maybe it differs per subreddit
   # modelName = 'models/Reddit_model_20230503-235329_GBM.sav'
-  # cfg_file = cu.findConfig()
+  # cfg_file = cu.find_config()
   cfg_file = 's3://data-kennethmyers/reddit.cfg'
-  cfg = cu.parseConfig(cfg_file)
+  cfg = cu.parse_config(cfg_file)
 
   spark = (
     SparkSession
diff --git a/model/test_discordUtils.py b/model/test_discordUtils.py
index 026e626..aa116aa 100644
--- a/model/test_discordUtils.py
+++ b/model/test_discordUtils.py
@@ -4,14 +4,14 @@
 import os
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.join(THIS_DIR, '../'))
-import viral_reddit_posts_utils.configUtils as cu
+import viral_reddit_posts_utils.config_utils as cu
 import responses
 
 
 @pytest.fixture(scope='module')
 def cfg():
-  cfg_file = cu.findConfig()
-  cfg = cu.parseConfig(cfg_file)
+  cfg_file = cu.find_config()
+  cfg = cu.parse_config(cfg_file)
   return cfg
diff --git a/model/test_model.py b/model/test_model.py
index 0f7cca7..0450df0 100644
--- a/model/test_model.py
+++ b/model/test_model.py
@@ -5,15 +5,12 @@
 import os
 from datetime import datetime, timedelta
 from pyspark.sql import SparkSession
+import pyspark.sql.functions as F
 import schema
 import boto3
 from moto import mock_dynamodb
-import sys
-THIS_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(THIS_DIR, '../lambdaFunctions/getRedditDataFunction/'))
-sys.path.append(os.path.join(THIS_DIR, '../'))
-import viral_reddit_posts_utils.configUtils as cu
-import tableDefinition
+import viral_reddit_posts_utils.config_utils as cu
+from lambda_functions.get_reddit_data_function import table_definition
 import pandas as pd
 import json
 from decimal import Decimal
@@ -22,10 +19,10 @@ IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
 
 os.environ['TZ'] = 'UTC'
-
+THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 @pytest.fixture(scope='module')
-def sampleRisingData():
+def sample_rising_data():
   d = pd.read_csv(os.path.join(THIS_DIR, 'test_data.csv'))
   # we need to change some of the values here so that they can be found by the extract method
   # particularly the load dates have to be within the last hour
@@ -49,7 +46,7 @@
 #   dynamodb = boto3.resource('dynamodb')
 #   # create table and write to sample data
 #   tableName = 'rising'
-#   td = tableDefinition.getTableDefinition(tableName=tableName)
+#   td = table_definition.get_table_definition(tableName=tableName)
 #   table = dynamodb.create_table(**td)
 #   with table.batch_writer() as batch:
 #     for item in json.loads(sampleRisingData, parse_float=Decimal):  # for each row obtained
@@ -102,8 +99,8 @@ def modelName():
 
 
 @pytest.fixture(scope='module')
 def cfg():
-  cfg_file = cu.findConfig()
-  cfg = cu.parseConfig(cfg_file)
+  cfg_file = cu.find_config()
+  cfg = cu.parse_config(cfg_file)
   return cfg
@@ -119,14 +116,14 @@ def engine(cfg):
 
 # as I mentioned earlier, I'd like to change this so the dynamodb resource is a fixture but it kept throwing errors
 @mock_dynamodb
-def test_extract(sampleRisingData, cfg, engine, model, modelName, spark, threshold):
+def test_extract(sample_rising_data, cfg, engine, model, modelName, spark, threshold):
   dynamodb = boto3.resource('dynamodb', region_name='us-east-2')
   # create table and write to sample data
   tableName = 'rising'
-  td = tableDefinition.getTableDefinition(tableName=tableName)
+  td = table_definition.get_table_definition(tableName=tableName)
   table = dynamodb.create_table(**td)
   with table.batch_writer() as batch:
-    for item in json.loads(sampleRisingData, parse_float=Decimal):  # for each row obtained
+    for item in json.loads(sample_rising_data, parse_float=Decimal):  # for each row obtained
       batch.put_item(
         Item=item  # json.loads(item, parse_float=Decimal) # helps with parsing float to Decimal
       )
@@ -208,7 +205,9 @@ def aggDataDf(spark):
     'maxScoreGrowth21_40m41_60m': 1.1,
     'maxNumCommentsGrowth21_40m41_60m': 0.5,
   }]
-  return spark.createDataFrame(testAggData, schema.aggDataSparkSchema).toPandas()
+  df = spark.createDataFrame(testAggData, schema.aggDataSparkSchema)
+  df = df.withColumn("createdTSUTC", F.date_format("createdTSUTC", "yyyy-MM-dd HH:mm:ss"))
+  return df.toPandas()
 
 
 def test_createPredictions(aggDataDf, pipeline):
diff --git a/patches/rds/newColumns.py b/patches/rds/newColumns.py
index 8aacd6f..5e68d5e 100644
--- a/patches/rds/newColumns.py
+++ b/patches/rds/newColumns.py
@@ -3,12 +3,12 @@
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.join(THIS_DIR, '../../'))
 sys.path.append(os.path.join(THIS_DIR, '../../model/'))
-import viral_reddit_posts_utils.configUtils as cu
+import viral_reddit_posts_utils.config_utils as cu
 import sqlUtils as su
 
-cfg_file = cu.findConfig()
-cfg = cu.parseConfig(cfg_file)
+cfg_file = cu.find_config()
+cfg = cu.parse_config(cfg_file)
 engine = su.makeEngine(cfg)
 
 newColumnsStr = """
diff --git a/pyproject.toml b/pyproject.toml
index f298fc4..a6b71d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,21 +10,22 @@ dynamic = ["version"]
 
 dependencies = [
   "boto3==1.26.117",
-  "matplotlib==3.3.4",
-  "numpy==1.21.6",  # required by pyspark
-  "pandas==1.3",  # 1.3 at least needed for M1 Mac, lower version required by pyspark
+  "matplotlib==3.8",
+  "numpy==1.26",
+  "pandas==2.2.2",
   "pg8000==1.29.4",  # this was easier to pip install than psycopg2
   "pyarrow==15.0.2",  # don't use low versions which pin lower versions of numpy that break on M1 Mac
-  "pyspark==3.4.0",  # using this version because py37 deprecated in pyspark 3.4.0
+  "pyspark==3.4.0",
   "requests==2.31.0",
-  "scikit-learn==1.0.2",
+  "scikit-learn==1.4.2",
   "seaborn==0.11.2",
   "shap==0.41.0",
   "sqlalchemy==1.4.46",  # originally tried 2.0.10, but this was incompatible with old versions of pandas https://stackoverflow.com/a/75282604/5034651
   "viral_reddit_posts_utils @ git+https://github.com/ViralRedditPosts/Utils.git@main",
+  "Reddit-Scraping @ git+https://github.com/ViralRedditPosts/Reddit-Scraping.git@main",
 ]
 
-requires-python = "== 3.8.15"
+requires-python = "== 3.12.3"
 
 authors = [
   {name = "Kenneth Myers", email = "myers.kenneth.james@gmail.com"},
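
Note on the extract() hunk above: the added withColumn calls format every timestamp column as a string before toPandas(), sidestepping the Spark 3.4 / pandas timestamp-conversion error linked in the code comment. Below is a minimal, self-contained sketch of that workaround, not code from this repo: the single-row DataFrame and local SparkSession are made up, and only the loadTSUTC column name and date_format pattern come from the patch.

```python
# Minimal sketch of the toPandas() timestamp workaround used in PredictETL.extract().
# The sample row and local SparkSession are illustrative only.
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("timestampWorkaroundDemo").getOrCreate()

df = (
  spark.createDataFrame([("abc123", "2024-04-25 03:59:00")], ["postId", "loadTSUTC"])
  .withColumn("loadTSUTC", F.to_timestamp("loadTSUTC"))
)

# Formatting the timestamp column as a string before toPandas() avoids the
# conversion error described in the linked Stack Overflow question.
pandasDf = (
  df
  .withColumn("loadTSUTC", F.date_format("loadTSUTC", "yyyy-MM-dd HH:mm:ss"))
  .limit(5)
  .toPandas()
)
print(pandasDf.to_string())
```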
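The load() and filterPreviousViralData() hunks pass conn.connection, the underlying DBAPI connection, instead of the SQLAlchemy Connection object, which keeps pd.read_sql and the upsert helper working with pandas 2.2 on SQLAlchemy 1.4. A hedged sketch of the read path follows; the connection string and post ids are placeholders, while the table and column names mirror the patch.

```python
# Sketch of the pd.read_sql() pattern from filterPreviousViralData(); the DSN is a placeholder.
import pandas as pd
from sqlalchemy import create_engine

engine = create_engine("postgresql+pg8000://user:password@localhost:5432/reddit")  # placeholder DSN

postIds = ["abc123", "def456"]  # illustrative post ids
idList = "','".join(postIds)
sql = f"""select "postId", "stepUp", 1 as "matchFound" from public."scoredData" where "postId" in ('{idList}') and "stepUp" = 1"""

with engine.connect() as conn:
  # conn.connection exposes the raw DBAPI connection; pandas accepts it directly,
  # though it may warn that only sqlite3 DBAPI connections are officially supported.
  result = pd.read_sql(sql=sql, con=conn.connection)
print(result)
```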
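Finally, test_extract() now pulls the DynamoDB table definition from lambda_functions.get_reddit_data_function.table_definition, presumably supplied by the Reddit-Scraping git dependency added to pyproject.toml. The rough sketch below shows that moto-backed setup in isolation; the helper name and the sample JSON string are hypothetical, while the region, table name, and get_table_definition(tableName=...) call mirror the patch.

```python
# Rough sketch of the mocked DynamoDB setup used by test_extract(). Assumes the
# Reddit-Scraping dependency provides lambda_functions.get_reddit_data_function.table_definition;
# the function name and the sample data string are hypothetical.
import json
from decimal import Decimal

import boto3
from moto import mock_dynamodb
from lambda_functions.get_reddit_data_function import table_definition


@mock_dynamodb
def load_sample_rising_data(sample_rising_data: str) -> int:
  """Create the mocked 'rising' table, load the sample rows, and return how many were written."""
  dynamodb = boto3.resource('dynamodb', region_name='us-east-2')
  td = table_definition.get_table_definition(tableName='rising')
  table = dynamodb.create_table(**td)
  with table.batch_writer() as batch:
    # parse_float=Decimal because DynamoDB does not accept Python floats
    for item in json.loads(sample_rising_data, parse_float=Decimal):
      batch.put_item(Item=item)
  return table.scan()['Count']  # the mocked table only exists while the decorator is active
```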