Skip to content

Commit

Permalink
Remove pydantic validation of stats
Browse files Browse the repository at this point in the history
  • Loading branch information
amCap1712 committed Jan 10, 2025
1 parent 92b301c commit ba1f8b6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 34 deletions.
8 changes: 2 additions & 6 deletions listenbrainz/spark/request_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,9 @@ def request_entity_stats(type_, range_, entity, database):
"entity": entity
}

if not database:
if not database and type_ != "listeners":
today = date.today().strftime("%Y%m%d")
if type_ == "listeners":
prefix = f"{entity}_listeners"
else:
prefix = type_
database = f"{prefix}_{range_}_{today}"
database = f"{type_}_{range_}_{today}"

params["database"] = database

Expand Down
10 changes: 5 additions & 5 deletions listenbrainz_spark/stats/incremental/listener/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def generate_stats(self, stats_range: str, from_date: datetime,
prefix = f"entity_listener_{self.entity}_{stats_range}"
existing_aggregate_path = self.get_existing_aggregate_path(stats_range)

only_inc_users = True
only_inc_entities = True

if not hdfs_connection.client.status(existing_aggregate_path, strict=False) or not existing_aggregate_usable:
table = f"{prefix}_full_listens"
Expand All @@ -96,7 +96,7 @@ def generate_stats(self, stats_range: str, from_date: datetime,
schema=BOOKKEEPING_SCHEMA
)
metadata_df.write.mode("overwrite").json(metadata_path)
only_inc_users = False
only_inc_entities = False

full_df = read_files_from_HDFS(existing_aggregate_path)

Expand All @@ -107,15 +107,15 @@ def generate_stats(self, stats_range: str, from_date: datetime,
inc_df = self.aggregate(table, cache_tables)
else:
inc_df = listenbrainz_spark.session.createDataFrame([], schema=self.get_partial_aggregate_schema())
only_inc_users = False
only_inc_entities = False

full_table = f"{prefix}_existing_aggregate"
full_df.createOrReplaceTempView(full_table)

inc_table = f"{prefix}_incremental_aggregate"
inc_df.createOrReplaceTempView(inc_table)

if only_inc_users:
if only_inc_entities:
existing_table = f"{prefix}_filtered_aggregate"
filtered_aggregate_df = self.filter_existing_aggregate(full_table, inc_table)
filtered_aggregate_df.createOrReplaceTempView(existing_table)
Expand All @@ -128,5 +128,5 @@ def generate_stats(self, stats_range: str, from_date: datetime,
combined_df.createOrReplaceTempView(combined_table)
results_df = self.get_top_n(combined_table, top_entity_limit)

return only_inc_users, results_df.toLocalIterator()
return only_inc_entities, results_df.toLocalIterator()

25 changes: 2 additions & 23 deletions listenbrainz_spark/stats/listener/entity.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import json
import logging
from datetime import datetime
from typing import Iterator, Optional, Dict, List
from typing import Iterator, Optional, Dict

from more_itertools import chunked
from pydantic import ValidationError

from data.model.entity_listener_stat import EntityListenerRecord, ArtistListenerRecord, \
ReleaseGroupListenerRecord
from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME, ARTIST_COUNTRY_CODE_DATAFRAME, \
RELEASE_GROUP_METADATA_CACHE_DATAFRAME
from listenbrainz_spark.stats import get_dates_for_stats_range
Expand All @@ -22,11 +18,6 @@
"release_groups": release_group.get_listeners,
}

entity_model_map = {
"artists": ArtistListenerRecord,
"release_groups": ReleaseGroupListenerRecord,
}

entity_cache_map = {
"artists": [ARTIST_COUNTRY_CODE_DATAFRAME],
"release_groups": [RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME],
Expand All @@ -52,18 +43,6 @@ def get_listener_stats(entity: str, stats_range: str, database: str = None) -> I
to_date=to_date, database=database)


def parse_one_entity_stats(entry, entity: str, stats_range: str) \
        -> Optional[Dict]:
    """Convert one listener-stats row to a dict and validate it.

    Args:
        entry: row object exposing ``asDict(recursive=True)``
            (presumably a Spark ``Row`` -- confirm against the caller).
        entity: key into ``entity_model_map`` selecting the validation model.
        stats_range: stats range name; used only in the error log message.

    Returns:
        The row as a plain dict when the model accepts it, otherwise ``None``
        (the validation failure is logged, not raised).
    """
    # Convert outside the ``try`` so ``data`` is always bound when the
    # handler below formats it into the log message.
    data = entry.asDict(recursive=True)
    try:
        # The model instance is discarded; it is built only to validate.
        entity_model_map[entity](**data)
    except ValidationError:
        logger.error(f"""ValidationError while calculating {stats_range} listeners of {entity}.
Data: {json.dumps(data, indent=2)}""", exc_info=True)
        return None
    return data


def create_messages(only_inc_entities, data, entity: str, stats_range: str, from_date: datetime, to_date: datetime, database: str = None) \
-> Iterator[Optional[Dict]]:
"""
Expand Down Expand Up @@ -96,7 +75,7 @@ def create_messages(only_inc_entities, data, entity: str, stats_range: str, from
for entries in chunked(data, ENTITIES_PER_MESSAGE):
multiple_entity_stats = []
for entry in entries:
processed_stat = parse_one_entity_stats(entry, entity, stats_range)
processed_stat = entry.asDict(recursive=True)
multiple_entity_stats.append(processed_stat)

yield {
Expand Down

0 comments on commit ba1f8b6

Please sign in to comment.