From f3769203e6eae510c0d75290de2c38bcc896d2a0 Mon Sep 17 00:00:00 2001
From: Kartik Ohri
Date: Thu, 9 Jan 2025 22:21:31 +0530
Subject: [PATCH] Refactor create messages and stats validation into class

---
 listenbrainz_spark/path.py                    |   4 +-
 .../stats/incremental/listener/artist.py      |  25 ++--
 .../stats/incremental/listener/entity.py      | 136 +++---------------
 .../incremental/listener/release_group.py     |  38 +++--
 .../stats/incremental/user/entity.py          |  19 ++-
 listenbrainz_spark/stats/listener/entity.py   |  83 +----------
 listenbrainz_spark/stats/user/__init__.py     |   1 -
 7 files changed, 60 insertions(+), 246 deletions(-)

diff --git a/listenbrainz_spark/path.py b/listenbrainz_spark/path.py
index ef8d80a8f5..630a7d1653 100644
--- a/listenbrainz_spark/path.py
+++ b/listenbrainz_spark/path.py
@@ -8,9 +8,7 @@
 LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
 LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'sitewide')
 LISTENBRAINZ_USER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'user')
-
-LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY = os.path.join('/', 'listener_stats_aggregates')
-LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY = os.path.join('/', 'listener_stats_bookkeeping')
+LISTENBRAINZ_LISTENER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'listener')
 
 # MLHD+ dump files
 MLHD_PLUS_RAW_DATA_DIRECTORY = os.path.join("/", "mlhd-raw")
diff --git a/listenbrainz_spark/stats/incremental/listener/artist.py b/listenbrainz_spark/stats/incremental/listener/artist.py
index 9aa8e17cce..a514bb2ac1 100644
--- a/listenbrainz_spark/stats/incremental/listener/artist.py
+++ b/listenbrainz_spark/stats/incremental/listener/artist.py
@@ -5,35 +5,26 @@
 from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME
 from listenbrainz_spark.stats import run_query
 from listenbrainz_spark.stats.incremental.listener.entity import EntityListener
-from listenbrainz_spark.stats.incremental.user.entity import UserEntity
 
 
 class ArtistEntityListener(EntityListener):
 
-    def __init__(self):
-        super().__init__(entity="artists")
+    def __init__(self, stats_range, database):
+        super().__init__(entity="artists", stats_range=stats_range, database=database, message_type="entity_listener")
 
     def get_cache_tables(self) -> List[str]:
        return [ARTIST_COUNTRY_CODE_DATAFRAME]
 
     def get_partial_aggregate_schema(self):
         return StructType([
-            StructField('artist_name', StringType(), nullable=False),
-            StructField('artist_mbid', StringType(), nullable=True),
-            StructField('user_id', IntegerType(), nullable=False),
-            StructField('listen_count', IntegerType(), nullable=False),
+            StructField("artist_name", StringType(), nullable=False),
+            StructField("artist_mbid", StringType(), nullable=True),
+            StructField("user_id", IntegerType(), nullable=False),
+            StructField("listen_count", IntegerType(), nullable=False),
         ])
 
-    def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
-        query = f"""
-            WITH incremental_artists AS (
-                SELECT DISTINCT artist_mbid FROM {incremental_aggregate}
-            )
-            SELECT *
-              FROM {existing_aggregate} ea
-             WHERE EXISTS(SELECT 1 FROM incremental_artists iu WHERE iu.artist_mbid = ea.artist_mbid)
-        """
-        return run_query(query)
+    def get_entity_id(self):
+        return "artist_mbid"
 
     def aggregate(self, table, cache_tables):
         cache_table = cache_tables[0]
diff --git a/listenbrainz_spark/stats/incremental/listener/entity.py b/listenbrainz_spark/stats/incremental/listener/entity.py
index 406cd42643..ef727d9fa1 100644
--- a/listenbrainz_spark/stats/incremental/listener/entity.py
+++ b/listenbrainz_spark/stats/incremental/listener/entity.py
@@ -1,132 +1,32 @@
 import abc
 import logging
-from datetime import datetime
-from pathlib import Path
-from typing import List
-
-from pyspark.errors import AnalysisException
-from pyspark.sql import DataFrame
-from pyspark.sql.types import StructType, StructField, TimestampType
-
-import listenbrainz_spark
-from listenbrainz_spark import hdfs_connection
-from listenbrainz_spark.config import HDFS_CLUSTER_URI
-from listenbrainz_spark.path import INCREMENTAL_DUMPS_SAVE_PATH, \
-    LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY, LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY
-from listenbrainz_spark.stats import run_query
-from listenbrainz_spark.utils import read_files_from_HDFS, get_listens_from_dump
+from datetime import date
+from typing import Optional
+
+from listenbrainz_spark.path import LISTENBRAINZ_LISTENER_STATS_DIRECTORY
+from listenbrainz_spark.stats.incremental.user.entity import UserEntity
 
 logger = logging.getLogger(__name__)
 
-BOOKKEEPING_SCHEMA = StructType([
-    StructField('from_date', TimestampType(), nullable=False),
-    StructField('to_date', TimestampType(), nullable=False),
-    StructField('created', TimestampType(), nullable=False),
-])
-
 
-class EntityListener(abc.ABC):
-
-    def __init__(self, entity):
-        self.entity = entity
-
-    def get_existing_aggregate_path(self, stats_range) -> str:
-        return f"{LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY}/{self.entity}/{stats_range}"
 
-    def get_bookkeeping_path(self, stats_range) -> str:
-        return f"{LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY}/{self.entity}/{stats_range}"
+class EntityListener(UserEntity, abc.ABC):
 
-    def get_partial_aggregate_schema(self) -> StructType:
-        raise NotImplementedError()
+    def __init__(self, entity: str, stats_range: str, database: Optional[str], message_type: Optional[str]):
+        if not database:
+            database = f"{entity}_listeners_{stats_range}_{date.today().strftime('%Y%m%d')}"
+        super().__init__(entity, stats_range, database, message_type)
 
-    def aggregate(self, table, cache_tables) -> DataFrame:
-        raise NotImplementedError()
+    def get_table_prefix(self) -> str:
+        return f"{self.entity}_listener_{self.stats_range}"
 
-    def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
-        raise NotImplementedError()
+    def get_base_path(self) -> str:
+        return LISTENBRAINZ_LISTENER_STATS_DIRECTORY
 
-    def combine_aggregates(self, existing_aggregate, incremental_aggregate) -> DataFrame:
+    def get_entity_id(self):
         raise NotImplementedError()
 
-    def get_top_n(self, final_aggregate, N) -> DataFrame:
-        raise NotImplementedError()
-
-    def get_cache_tables(self) -> List[str]:
-        raise NotImplementedError()
-
-    def generate_stats(self, stats_range: str, from_date: datetime,
-                       to_date: datetime, top_entity_limit: int):
-        cache_tables = []
-        for idx, df_path in enumerate(self.get_cache_tables()):
-            df_name = f"entity_data_cache_{idx}"
-            cache_tables.append(df_name)
-            read_files_from_HDFS(df_path).createOrReplaceTempView(df_name)
-
-        metadata_path = self.get_bookkeeping_path(stats_range)
-        try:
-            metadata = listenbrainz_spark \
-                .session \
-                .read \
-                .schema(BOOKKEEPING_SCHEMA) \
-                .json(f"{HDFS_CLUSTER_URI}{metadata_path}") \
-                .collect()[0]
-            existing_from_date, existing_to_date = metadata["from_date"], metadata["to_date"]
-            existing_aggregate_usable = existing_from_date.date() == from_date.date()
-        except AnalysisException:
-            existing_aggregate_usable = False
-            logger.info("Existing partial aggregate not found!")
-
-        prefix = f"entity_listener_{self.entity}_{stats_range}"
-        existing_aggregate_path = self.get_existing_aggregate_path(stats_range)
-
-        only_inc_entities = True
-
-        if not hdfs_connection.client.status(existing_aggregate_path, strict=False) or not existing_aggregate_usable:
-            table = f"{prefix}_full_listens"
-            get_listens_from_dump(from_date, to_date, include_incremental=False).createOrReplaceTempView(table)
-
-            logger.info("Creating partial aggregate from full dump listens")
-            hdfs_connection.client.makedirs(Path(existing_aggregate_path).parent)
-            full_df = self.aggregate(table, cache_tables)
-            full_df.write.mode("overwrite").parquet(existing_aggregate_path)
-
-            hdfs_connection.client.makedirs(Path(metadata_path).parent)
-            metadata_df = listenbrainz_spark.session.createDataFrame(
-                [(from_date, to_date, datetime.now())],
-                schema=BOOKKEEPING_SCHEMA
-            )
-            metadata_df.write.mode("overwrite").json(metadata_path)
-            only_inc_entities = False
-
-        full_df = read_files_from_HDFS(existing_aggregate_path)
-
-        if hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False):
-            table = f"{prefix}_incremental_listens"
-            read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH) \
-                .createOrReplaceTempView(table)
-            inc_df = self.aggregate(table, cache_tables)
-        else:
-            inc_df = listenbrainz_spark.session.createDataFrame([], schema=self.get_partial_aggregate_schema())
-            only_inc_entities = False
-
-        full_table = f"{prefix}_existing_aggregate"
-        full_df.createOrReplaceTempView(full_table)
-
-        inc_table = f"{prefix}_incremental_aggregate"
-        inc_df.createOrReplaceTempView(inc_table)
-
-        if only_inc_entities:
-            existing_table = f"{prefix}_filtered_aggregate"
-            filtered_aggregate_df = self.filter_existing_aggregate(full_table, inc_table)
-            filtered_aggregate_df.createOrReplaceTempView(existing_table)
-        else:
-            existing_table = full_table
-
-        combined_df = self.combine_aggregates(existing_table, inc_table)
-
-        combined_table = f"{prefix}_combined_aggregate"
-        combined_df.createOrReplaceTempView(combined_table)
-        results_df = self.get_top_n(combined_table, top_entity_limit)
+    def items_per_message(self):
+        return 10000
 
-        return only_inc_entities, results_df.toLocalIterator()
-    
\ No newline at end of file
+    def parse_one_user_stats(self, entry: dict):
+        return entry
diff --git a/listenbrainz_spark/stats/incremental/listener/release_group.py b/listenbrainz_spark/stats/incremental/listener/release_group.py
index deda7d91c0..eb587ad27d 100644
--- a/listenbrainz_spark/stats/incremental/listener/release_group.py
+++ b/listenbrainz_spark/stats/incremental/listener/release_group.py
@@ -2,43 +2,37 @@
 
 from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType
 
-from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME, RELEASE_METADATA_CACHE_DATAFRAME, \
+from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME, \
     RELEASE_GROUP_METADATA_CACHE_DATAFRAME
 from listenbrainz_spark.stats import run_query
 from listenbrainz_spark.stats.incremental.listener.entity import EntityListener
-from listenbrainz_spark.stats.incremental.user.entity import UserEntity
 
 
 class ReleaseGroupEntityListener(EntityListener):
 
-    def __init__(self):
-        super().__init__(entity="release_groups")
+    def __init__(self, stats_range, database):
+        super().__init__(
+            entity="release_groups", stats_range=stats_range,
+            database=database, message_type="entity_listener"
+        )
 
     def get_cache_tables(self) -> List[str]:
         return [RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME]
 
     def get_partial_aggregate_schema(self):
         return StructType([
-            StructField('release_group_mbid', StringType(), nullable=False),
-            StructField('release_group_name', StringType(), nullable=False),
-            StructField('release_group_artist_name', StringType(), nullable=False),
-            StructField('artist_credit_mbids', ArrayType(StringType()), nullable=False),
-            StructField('caa_id', IntegerType(), nullable=True),
-            StructField('caa_release_mbid', StringType(), nullable=True),
-            StructField('user_id', IntegerType(), nullable=False),
-            StructField('listen_count', IntegerType(), nullable=False),
+            StructField("release_group_mbid", StringType(), nullable=False),
+            StructField("release_group_name", StringType(), nullable=False),
+            StructField("release_group_artist_name", StringType(), nullable=False),
+            StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False),
+            StructField("caa_id", IntegerType(), nullable=True),
+            StructField("caa_release_mbid", StringType(), nullable=True),
+            StructField("user_id", IntegerType(), nullable=False),
+            StructField("listen_count", IntegerType(), nullable=False),
         ])
 
-    def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
-        query = f"""
-            WITH incremental_release_groups AS (
-                SELECT DISTINCT release_group_mbid FROM {incremental_aggregate}
-            )
-            SELECT *
-              FROM {existing_aggregate} ea
-             WHERE EXISTS(SELECT 1 FROM incremental_release_groups iu WHERE iu.release_group_mbid = ea.release_group_mbid)
-        """
-        return run_query(query)
+    def get_entity_id(self):
+        return "release_group_mbid"
 
     def aggregate(self, table, cache_tables):
         rel_cache_table = cache_tables[0]
diff --git a/listenbrainz_spark/stats/incremental/user/entity.py b/listenbrainz_spark/stats/incremental/user/entity.py
index be29ca7702..240f5ecf26 100644
--- a/listenbrainz_spark/stats/incremental/user/entity.py
+++ b/listenbrainz_spark/stats/incremental/user/entity.py
@@ -1,5 +1,4 @@
 import abc
-import json
 import logging
 from datetime import date, datetime
 from typing import Optional, Iterator, Dict, Tuple
@@ -15,10 +14,8 @@
 from listenbrainz_spark.path import LISTENBRAINZ_USER_STATS_DIRECTORY
 from listenbrainz_spark.stats import run_query
 from listenbrainz_spark.stats.incremental import IncrementalStats
-from listenbrainz_spark.stats.user import USERS_PER_MESSAGE
 from listenbrainz_spark.utils import read_files_from_HDFS
 
-
 logger = logging.getLogger(__name__)
 
 entity_model_map = {
@@ -31,7 +28,7 @@
 
 class UserEntity(IncrementalStats, abc.ABC):
 
-    def __init__(self, entity: str, stats_range: str = None, database: str = None, message_type: str = None,
+    def __init__(self, entity: str, stats_range: str = None, database: str = None, message_type: str = None,
                  from_date: datetime = None, to_date: datetime = None):
         super().__init__(entity, stats_range, from_date, to_date)
         if database:
@@ -46,14 +43,22 @@ def get_base_path(self) -> str:
 
     def get_table_prefix(self) -> str:
         return f"user_{self.entity}_{self.stats_range}"
 
+    def get_entity_id(self):
+        return "user_id"
+
+    def items_per_message(self):
+        """ Get the number of items to chunk per message """
+        return 25
+
     def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
+        entity_id = self.get_entity_id()
         query = f"""
             WITH incremental_users AS (
-                SELECT DISTINCT user_id FROM {incremental_aggregate}
+                SELECT DISTINCT {entity_id} FROM {incremental_aggregate}
             )
             SELECT *
              FROM {existing_aggregate} ea
-             WHERE EXISTS(SELECT 1 FROM incremental_users iu WHERE iu.user_id = ea.user_id)
+             WHERE EXISTS(SELECT 1 FROM incremental_users iu WHERE iu.{entity_id} = ea.{entity_id})
         """
         return run_query(query)
@@ -131,7 +136,7 @@ def create_messages(self, only_inc_users, results: DataFrame) -> Iterator[Dict]:
         to_ts = int(self.to_date.timestamp())
 
         data = results.toLocalIterator()
-        for entries in chunked(data, USERS_PER_MESSAGE):
+        for entries in chunked(data, self.items_per_message()):
             multiple_user_stats = []
             for entry in entries:
                 row = entry.asDict(recursive=True)
diff --git a/listenbrainz_spark/stats/listener/entity.py b/listenbrainz_spark/stats/listener/entity.py
index d0f8434f00..c681d7c01a 100644
--- a/listenbrainz_spark/stats/listener/entity.py
+++ b/listenbrainz_spark/stats/listener/entity.py
@@ -1,95 +1,22 @@
 import logging
-from datetime import datetime
 from typing import Iterator, Optional, Dict
 
-from more_itertools import chunked
-
-from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME, ARTIST_COUNTRY_CODE_DATAFRAME, \
-    RELEASE_GROUP_METADATA_CACHE_DATAFRAME
-from listenbrainz_spark.stats import get_dates_for_stats_range
 from listenbrainz_spark.stats.incremental.listener.artist import ArtistEntityListener
 from listenbrainz_spark.stats.incremental.listener.release_group import ReleaseGroupEntityListener
-from listenbrainz_spark.stats.listener import artist, release_group
 
 logger = logging.getLogger(__name__)
 
-entity_handler_map = {
-    "artists": artist.get_listeners,
-    "release_groups": release_group.get_listeners,
-}
-
-entity_cache_map = {
-    "artists": [ARTIST_COUNTRY_CODE_DATAFRAME],
-    "release_groups": [RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME],
-}
-
 incremental_entity_obj_map = {
-    "artists": ArtistEntityListener(),
-    "release_groups": ReleaseGroupEntityListener(),
+    "artists": ArtistEntityListener,
+    "release_groups": ReleaseGroupEntityListener,
 }
 
-ENTITIES_PER_MESSAGE = 10000  # number of entities per message
 NUMBER_OF_TOP_LISTENERS = 10  # number of top listeners to retain for user stats
 
 
 def get_listener_stats(entity: str, stats_range: str, database: str = None) -> Iterator[Optional[Dict]]:
     """ Get the top listeners for all entity for specified stats_range """
     logger.debug(f"Calculating {entity}_listeners_{stats_range}...")
-
-    from_date, to_date = get_dates_for_stats_range(stats_range)
-    entity_obj = incremental_entity_obj_map[entity]
-    only_inc_entities, data = entity_obj.generate_stats(stats_range, from_date, to_date, NUMBER_OF_TOP_LISTENERS)
-    return create_messages(only_inc_entities, data=data, entity=entity, stats_range=stats_range, from_date=from_date,
-                           to_date=to_date, database=database)
-
-
-def create_messages(only_inc_entities, data, entity: str, stats_range: str, from_date: datetime, to_date: datetime, database: str = None) \
-        -> Iterator[Optional[Dict]]:
-    """
-    Create messages to send the data to the webserver via RabbitMQ
-
-    Args:
-        data: Data to sent to the webserver
-        entity: The entity for which statistics are calculated, i.e 'artists',
-            'releases' or 'recordings'
-        stats_range: The range for which the statistics have been calculated
-        from_date: The start time of the stats
-        to_date: The end time of the stats
-        database: the name of the database in which the webserver should store the data
-
-    Returns:
-        messages: A list of messages to be sent via RabbitMQ
-    """
-    if database is None:
-        database = f"{entity}_listeners_{stats_range}_{datetime.today().strftime('%Y%m%d')}"
-
-    if only_inc_entities:
-        yield {
-            "type": "couchdb_data_start",
-            "database": database
-        }
-
-    from_ts = int(from_date.timestamp())
-    to_ts = int(to_date.timestamp())
-
-    for entries in chunked(data, ENTITIES_PER_MESSAGE):
-        multiple_entity_stats = []
-        for entry in entries:
-            processed_stat = entry.asDict(recursive=True)
-            multiple_entity_stats.append(processed_stat)
-
-        yield {
-            "type": "entity_listener",
-            "stats_range": stats_range,
-            "from_ts": from_ts,
-            "to_ts": to_ts,
-            "entity": entity,
-            "data": multiple_entity_stats,
-            "database": database
-        }
-
-    if only_inc_entities:
-        yield {
-            "type": "couchdb_data_end",
-            "database": database
-        }
+    entity_cls = incremental_entity_obj_map[entity]
+    entity_obj = entity_cls(stats_range, database)
+    return entity_obj.main(NUMBER_OF_TOP_LISTENERS)
diff --git a/listenbrainz_spark/stats/user/__init__.py b/listenbrainz_spark/stats/user/__init__.py
index 8a201fcddd..e69de29bb2 100644
--- a/listenbrainz_spark/stats/user/__init__.py
+++ b/listenbrainz_spark/stats/user/__init__.py
@@ -1 +0,0 @@
-USERS_PER_MESSAGE = 25
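
Illustrative usage sketch (not part of the patch): after this refactor, get_listener_stats() only looks
up the listener class, instantiates it with the stats range and optional database name, and delegates to
the main() pipeline inherited from UserEntity, which yields the message dicts that used to come from the
module-level create_messages(). The snippet below is a minimal sketch assuming a fully configured
listenbrainz_spark session with listens and metadata caches already on HDFS; the print call is only for
illustration, not part of the request_consumer wiring.

    from listenbrainz_spark.stats.listener.entity import get_listener_stats

    # Data messages are chunked per EntityListener.items_per_message() (10000 rows for listener stats);
    # each message dict carries "type" and "database", and data messages also carry "entity",
    # "stats_range", "from_ts"/"to_ts" and the "data" payload.
    for message in get_listener_stats(entity="artists", stats_range="week"):
        print(message["type"], message.get("database"))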