Commit
Merge branch 'main' of https://github.com/SonyShrestha/VBP_Joint_Project
Showing 2 changed files with 222 additions and 0 deletions.
@@ -0,0 +1,81 @@
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, rand, col
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, FloatType

import numpy as np
from datetime import datetime, timedelta

spark = SparkSession.builder \
    .appName("Data Processing with Spark") \
    .getOrCreate()

df = spark.read.parquet('purchases_nearing_expiry')

# Generate a synthetic expiry date 30-365 days after the purchase date
def generate_expiry_date(purchase_date):
    purchase_datetime = datetime.strptime(purchase_date, '%Y-%m-%d')
    added_days = np.random.randint(30, 365)
    return (purchase_datetime + timedelta(days=added_days)).strftime('%Y-%m-%d')

expiry_udf = udf(generate_expiry_date, StringType())
df = df.withColumn("unit_price", col("unit_price").cast(FloatType()))

df = df.withColumn('expiry_date', expiry_udf(col('purchase_date')))

# Discount the unit price in proportion to how much of the product has already been consumed
def calculate_expected_price(unit_price, percentage_consumed):
    discount_factor = 1 - (percentage_consumed * np.random.uniform(0.1, 0.5))
    return float(unit_price) * discount_factor

# Register the UDF with the appropriate return type
price_udf = udf(calculate_expected_price, DoubleType())

# Apply the UDFs
df = df.withColumn('percentage_consumed', rand())
df = df.withColumn('expected_price', price_udf(col('unit_price'), col('percentage_consumed')))

# def generate_optional_fields(customer_id, expiry_date):
#     if np.random.rand() > 0.2:
#         expiry_datetime = datetime.strptime(expiry_date, '%Y-%m-%d')
#         subtract_days = np.random.randint(1, 15)
#         selling_date = (expiry_datetime - timedelta(days=subtract_days)).strftime('%Y-%m-%d')
#         return customer_id, selling_date  # Modify logic to generate a different customer ID
#     return None, None

# For roughly 80% of rows, simulate a resale: a different buying customer and a selling date shortly before expiry
def generate_optional_fields(customer_id, expiry_date):
    if np.random.rand() > 0.2:
        expiry_datetime = datetime.strptime(expiry_date, '%Y-%m-%d')
        subtract_days = np.random.randint(1, 15)
        selling_date = (expiry_datetime - timedelta(days=subtract_days)).strftime('%Y-%m-%d')
        # Simulate a different customer ID; ensure this logic is valid for your use case
        new_customer_id = str(int(customer_id) + 1)  # Example modification
        return (new_customer_id, selling_date)
    return (None, None)

# fields_udf = udf(generate_optional_fields, StringType())
# Register the UDF with a struct return type
fields_udf = udf(generate_optional_fields, StructType([
    StructField("buying_customer_id", StringType(), True),
    StructField("selling_date", StringType(), True)
]))

# df = df.withColumn('new_fields', fields_udf(col('customer_id'), col('expiry_date')))
# df = df.withColumn('buying_customer_id', col('new_fields').getItem(0))
# df = df.withColumn('selling_date', col('new_fields').getItem(1))
# df = df.drop('new_fields')

df = df.withColumn('new_fields', fields_udf(col('customer_id'), col('expiry_date')))
df = df.withColumn('buying_customer_id', col('new_fields').getItem('buying_customer_id'))
df = df.withColumn('selling_date', col('new_fields').getItem('selling_date'))
df = df.drop('new_fields')

df.show()
df.write.mode('overwrite').parquet('platform_customer_pricing_data_output')
spark.stop()
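The generator above reads a 'purchases_nearing_expiry' parquet but only touches the customer_id, purchase_date, and unit_price columns. As a rough sketch only, the snippet below builds a tiny input with just those three columns so the script can be smoke-tested locally; the column set and sample values are assumptions, and the real dataset presumably carries more fields.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Build sample purchases input").getOrCreate()

# Hypothetical rows: customer_id, purchase_date, unit_price (assumed minimal schema)
sample_rows = [
    ("101", "2024-01-05", 4.99),
    ("102", "2024-02-10", 12.50),
    ("103", "2024-03-01", 7.25),
]
sample_df = spark.createDataFrame(sample_rows, ["customer_id", "purchase_date", "unit_price"])
sample_df.write.mode("overwrite").parquet("purchases_nearing_expiry")
spark.stop()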
@@ -0,0 +1,141 @@
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, isnull, lit
from pyspark.sql.types import DoubleType, IntegerType
import datetime
import json
import requests
import numpy as np
import shutil
import os

# Initialize the Spark session
spark = SparkSession.builder \
    .appName("Dynamic Pricing Model") \
    .getOrCreate()

# Load configuration
with open("business_config.json", "r") as config_file:
    config = json.load(config_file)

# Check whether today is a public holiday in the configured country
def is_today_a_holiday():
    country_code = config["country_code"]
    current_year = datetime.datetime.now().year
    today = datetime.datetime.now().strftime("%Y-%m-%d")
    url = f"https://date.nager.at/api/v3/publicholidays/{current_year}/{country_code}"
    try:
        response = requests.get(url)
        response.raise_for_status()
        holidays = response.json()
        return any(holiday['date'] == today for holiday in holidays)
    except requests.RequestException as e:
        print(f"Error fetching holiday data: {e}")
        return False

# Check if today is a holiday
is_holiday_today = is_today_a_holiday()

# Define the UDF to calculate days to expiry based on today's date
def get_days_to_expiry(expiry_date):
    today_date = datetime.date.today()
    if isinstance(expiry_date, datetime.date):
        return (expiry_date - today_date).days
    else:
        expiry_date = datetime.datetime.strptime(str(expiry_date), "%Y-%m-%d").date()
        return (expiry_date - today_date).days

days_to_expiry_udf = udf(get_days_to_expiry, IntegerType())

# Define the UDF for the longevity factor
# def longevity_factor(avg_expiry_days):
#     return float(1 - np.exp(-avg_expiry_days / 365))

# Products with a longer average shelf life get a smaller scale, and therefore a smaller total discount
def longevity_factor(avg_expiry_days):
    return float(np.exp(-avg_expiry_days / 365) * 2)

longevity_factor_udf = udf(longevity_factor, DoubleType())

# Define the UDF for rule-based pricing
def rule_based_pricing(days_to_expiry, consumption_rate, base_price, avg_expiry_days):
    # Configuration for discounts and thresholds
    holiday_discount = config["pricing_rules"]["holiday_discount"]
    threshold_days_high = config["pricing_rules"]["threshold_days_high"]
    discount_high = config["pricing_rules"]["discount_high"]
    threshold_days_medium = config["pricing_rules"]["threshold_days_medium"]
    discount_medium = config["pricing_rules"]["discount_medium"]
    discount_low_high_consumption = config["pricing_rules"]["discount_low_high_consumption"]
    discount_low_low_consumption = config["pricing_rules"]["discount_low_low_consumption"]
    min_price = config["pricing_rules"].get("min_price", 0)  # Minimum price floor

    # Determine if it's a holiday for a possible holiday discount or premium
    base_discount = holiday_discount if is_holiday_today else 1.0

    # Calculate the longevity scale to adjust pricing based on average expiry days
    longevity_scale = longevity_factor(avg_expiry_days)

    # Determine the discount factor based on expiry thresholds
    if days_to_expiry > threshold_days_high:
        discount_factor = discount_medium
    elif threshold_days_medium < days_to_expiry <= threshold_days_high:
        discount_factor = discount_high
    else:
        discount_factor = discount_low_high_consumption if consumption_rate > 0.5 else discount_low_low_consumption

    # Calculate the total discount to be applied
    total_discount = discount_factor * longevity_scale * base_discount

    # Ensure the total discount does not exceed 100%
    total_discount = min(total_discount, 1)

    # Calculate the final price, ensuring it does not fall below the minimum price
    final_price = max(base_price * (1 - total_discount), min_price)

    return final_price

pricing_udf = udf(rule_based_pricing, DoubleType())

# Define the path to the parquet file
parquet_path = "platform_customer_pricing_data_output"
df = spark.read.parquet(parquet_path)

df.show(20)

# Filter and modify the rows that have not been resold yet
df_to_update = df.filter(isnull("selling_date") & isnull("buying_customer_id"))
df_to_update = df_to_update.withColumn("days_to_expiry", days_to_expiry_udf(col("expected_expiry_date")))
df_to_update = df_to_update.withColumn("longevity_scale", longevity_factor_udf(col("avg_expiry_days")))
# Note: the precomputed longevity_scale column is passed as the UDF's avg_expiry_days argument,
# so longevity_factor is applied to it a second time inside rule_based_pricing
df_to_update = df_to_update.withColumn("dynamic_price", pricing_udf(col("days_to_expiry"), col("percentage_consumed"), col("unit_price"), col("longevity_scale")))
# df_to_update = df_to_update.drop("days_to_expiry")  # Drop

# Extract unchanged rows and add the new columns as nulls so the schemas line up for the union
df_unchanged = df.filter(~(isnull("selling_date") & isnull("buying_customer_id")))
df_unchanged = df_unchanged.withColumn("days_to_expiry", lit(None).cast(IntegerType()))
df_unchanged = df_unchanged.withColumn("longevity_scale", lit(None).cast(DoubleType()))
df_unchanged = df_unchanged.withColumn("dynamic_price", lit(None).cast(DoubleType()))

# Combine updated and unchanged data
df_final = df_to_update.union(df_unchanged)
df_final.show()

# Write the combined DataFrame to a temporary location
temp_output_path = "temp_dynamic_pricing"
df_final.write.mode("overwrite").parquet(temp_output_path)

# Replace the original output directory with the updated one
def replace_original_with_temp(original_path, temp_path):
    try:
        if os.path.exists(original_path):
            shutil.rmtree(original_path)
        os.rename(temp_path, original_path)
    except Exception as e:
        print(f"Failed to replace original file with updated file: {e}")
        raise

replace_original_with_temp(parquet_path, temp_output_path)

# Stop the Spark session
spark.stop()
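The pricing script expects a business_config.json alongside it, but that config file is not part of this commit. The sketch below writes a config covering every key read by is_today_a_holiday and rule_based_pricing; the key names are taken from the script, while all values are illustrative assumptions rather than tuned business rules.

import json

# Placeholder values (assumptions); only the key names come from the pricing script above
sample_config = {
    "country_code": "ES",
    "pricing_rules": {
        "holiday_discount": 1.1,
        "threshold_days_high": 30,
        "threshold_days_medium": 10,
        "discount_high": 0.30,
        "discount_medium": 0.15,
        "discount_low_high_consumption": 0.50,
        "discount_low_low_consumption": 0.40,
        "min_price": 0.5
    }
}

with open("business_config.json", "w") as f:
    json.dump(sample_config, f, indent=2)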