Incremental user stats #3115

Merged
20 commits merged on Jan 24, 2025
Changes from all commits
6 changes: 2 additions & 4 deletions listenbrainz/spark/request_manage.py
@@ -103,11 +103,9 @@ def request_user_stats(type_, range_, entity, database):
     if type_ in ["entity", "listener"] and entity:
         params["entity"] = entity
 
-    if not database:
+    if not database and type_ != "entity":
         today = date.today().strftime("%Y%m%d")
-        if type_ == "entity":
-            prefix = entity
-        elif type_ == "listeners":
+        if type_ == "listeners":
             prefix = f"{entity}_listeners"
         else:
             prefix = type_
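Not part of the diff: a minimal standalone sketch of how the database prefix is chosen after this change. The helper name pick_database_prefix is hypothetical, and the way the prefix is later combined with the stats range and today's date lies below this hunk, so it is assumed rather than shown.

from datetime import date

def pick_database_prefix(type_, entity, database):
    """Hypothetical helper mirroring the changed branch above."""
    if not database and type_ != "entity":
        today = date.today().strftime("%Y%m%d")
        if type_ == "listeners":
            prefix = f"{entity}_listeners"
        else:
            prefix = type_
        # assumption: code below this hunk joins prefix, range and date into the database name
        return prefix, today
    return None, None

print(pick_database_prefix("listeners", "artists", None))  # ('artists_listeners', '<today>')
print(pick_database_prefix("entity", "artists", None))     # (None, None): entity stats now skip this block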
1 change: 1 addition & 0 deletions listenbrainz_spark/path.py
@@ -7,6 +7,7 @@
 
 LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
 LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'sitewide')
+LISTENBRAINZ_USER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY, 'user')
 
 # MLHD+ dump files
 MLHD_PLUS_RAW_DATA_DIRECTORY = os.path.join("/", "mlhd-raw")
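For reference, outside the PR: a quick standalone check of the path string the new constant resolves to, using only the definitions shown above.

import os

LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'sitewide')
LISTENBRAINZ_USER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY, 'user')

print(LISTENBRAINZ_USER_STATS_DIRECTORY)  # /stats/sitewide/user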
18 changes: 15 additions & 3 deletions listenbrainz_spark/stats/common/listening_activity.py
@@ -55,7 +55,7 @@ def get_two_quarters_ago_offset(_date: date) -> relativedelta:
     return relativedelta(month=4, day=1)
 
 
-def _get_time_range_bounds(stats_range: str) -> Tuple[datetime, datetime, relativedelta, str, str]:
+def _get_time_range_bounds(stats_range: str, year: int = None) -> Tuple[datetime, datetime, relativedelta, str, str]:
     """ Returns the start time, end time, segment step size, python date format and spark
     date format to use for calculating the listening activity stats
 
@@ -65,13 +65,25 @@ def _get_time_range_bounds(stats_range: str) -> Tuple[datetime, datetime, relativedelta, str, str]:
     Python date format reference: https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
     Spark date format reference: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
 
+    If stats_range is set to year_in_music then the year must also be provided.
+
     .. note::
 
         other stats uses a different function (get_dates_for_stats_range) to calculate
         time ranges. if making modifications here, remember to check and update that as well
     """
     latest_listen_ts = get_latest_listen_ts()
 
+    if stats_range == "year_in_music":
+        if year is None:
+            raise ValueError("year is required when stats_range is set to year_in_music")
+        from_date = datetime(year, 1, 1)
+        to_date = datetime.combine(date(year, 12, 31), time.max)
+        step = relativedelta(days=+1)
+        date_format = "%d %B %Y"
+        spark_date_format = "dd MMMM y"
+        return from_date, to_date, step, date_format, spark_date_format
+
     if stats_range == "all_time":
         # all_time stats range is easy, just return time from LASTFM founding
         # to the latest listen we have in spark
@@ -190,7 +202,7 @@ def _create_time_range_df(from_date, to_date, step, date_format, spark_date_format):
     time_range_df.createOrReplaceTempView("time_range")
 
 
-def setup_time_range(stats_range: str) -> Tuple[datetime, datetime, relativedelta, str, str]:
+def setup_time_range(stats_range: str, year: int = None) -> Tuple[datetime, datetime, relativedelta, str, str]:
     """
     Sets up time range buckets needed to calculate listening activity stats and
     returns the start and end time of the time range.
@@ -203,6 +215,6 @@ def setup_time_range(stats_range: str) -> Tuple[datetime, datetime, relativedelta, str, str]:
     will return 1st of last year as the start time and the current date as the
     end time in this example.
     """
-    from_date, to_date, step, date_format, spark_date_format = _get_time_range_bounds(stats_range)
+    from_date, to_date, step, date_format, spark_date_format = _get_time_range_bounds(stats_range, year)
     _create_time_range_df(from_date, to_date, step, date_format, spark_date_format)
     return from_date, to_date, step, date_format, spark_date_format
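An editor's sketch, not from the PR: what the new year_in_music branch returns, with the Spark and database dependencies of the real module stripped out. The values are derived from the branch above; calling it through setup_time_range("year_in_music", year=2024) is the intended entry point.

from datetime import date, datetime, time
from dateutil.relativedelta import relativedelta

def year_in_music_bounds(year: int):
    # Mirrors the new branch in _get_time_range_bounds: one bucket per day of the given year.
    from_date = datetime(year, 1, 1)
    to_date = datetime.combine(date(year, 12, 31), time.max)
    step = relativedelta(days=+1)
    return from_date, to_date, step, "%d %B %Y", "dd MMMM y"

from_date, to_date, step, date_format, spark_date_format = year_in_music_bounds(2024)
print(from_date, to_date)  # 2024-01-01 00:00:00 2024-12-31 23:59:59.999999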
18 changes: 14 additions & 4 deletions listenbrainz_spark/stats/incremental/__init__.py
@@ -42,15 +42,23 @@ class IncrementalStats(abc.ABC):
     is absent in incremental listens.
     """
 
-    def __init__(self, entity: str, stats_range: str):
+    def __init__(self, entity: str, stats_range: str = None, from_date: datetime = None, to_date: datetime = None):
         """
             Args:
                 entity: The entity for which statistics are generated.
                 stats_range: The statistics range to calculate the stats for.
+                from_date: date from which listens to use for this stat
+                to_date: date until which listens to use for this stat
+
+            If both from_date and to_date are specified, they will be used instead of stats_range.
         """
         self.entity = entity
-        self.stats_range = stats_range
-        self.from_date, self.to_date = get_dates_for_stats_range(stats_range)
+        if from_date and to_date:
+            self.from_date, self.to_date = from_date, to_date
+            self.stats_range = f"{self.from_date.strftime('%Y%m%d')}_{self.to_date.strftime('%Y%m%d')}"
+        else:
+            self.stats_range = stats_range
+            self.from_date, self.to_date = get_dates_for_stats_range(stats_range)
         self._cache_tables = []
 
     @abc.abstractmethod
@@ -145,7 +153,8 @@ def partial_aggregate_usable(self) -> bool:
                 .json(f"{HDFS_CLUSTER_URI}{metadata_path}") \
                 .collect()[0]
             existing_from_date, existing_to_date = metadata["from_date"], metadata["to_date"]
-            existing_aggregate_fresh = existing_from_date.date() == self.from_date.date()
+            existing_aggregate_fresh = existing_from_date.date() == self.from_date.date() \
+                and existing_to_date.date() <= self.to_date.date()
         except AnalysisException:
             existing_aggregate_fresh = False
 
@@ -182,6 +191,7 @@ def create_partial_aggregate(self) -> DataFrame:
         return full_df
 
     def incremental_dump_exists(self) -> bool:
+        """ Check if incremental dump exists. """
         return hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False)
 
     def create_incremental_aggregate(self) -> DataFrame:
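Not part of the diff: a Spark-free sketch of how the reworked IncrementalStats constructor resolves its date window, assuming get_dates_for_stats_range still returns a (from_date, to_date) pair for a named range. The helper name resolve_stats_window is hypothetical.

from datetime import datetime

def resolve_stats_window(stats_range=None, from_date=None, to_date=None, get_dates_for_stats_range=None):
    # Explicit dates win and the range label is derived from them;
    # otherwise fall back to the pre-existing named-range lookup.
    if from_date and to_date:
        return f"{from_date.strftime('%Y%m%d')}_{to_date.strftime('%Y%m%d')}", from_date, to_date
    range_from, range_to = get_dates_for_stats_range(stats_range)
    return stats_range, range_from, range_to

label, start, end = resolve_stats_window(from_date=datetime(2024, 1, 1), to_date=datetime(2024, 12, 31, 23, 59, 59))
print(label)  # 20240101_20241231

The partial_aggregate_usable change above then treats an existing aggregate as reusable only if it starts on the same from_date and ends no later than to_date.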
4 changes: 2 additions & 2 deletions listenbrainz_spark/stats/incremental/sitewide/release.py
@@ -9,8 +9,8 @@
 
 class ReleaseSitewideEntity(SitewideEntity):
 
-    def __init__(self):
-        super().__init__(entity="releases")
+    def __init__(self, stats_range):
+        super().__init__(entity="releases", stats_range=stats_range)
 
     def get_cache_tables(self) -> List[str]:
         return [RELEASE_METADATA_CACHE_DATAFRAME]
Empty file.
124 changes: 124 additions & 0 deletions listenbrainz_spark/stats/incremental/user/artist.py
@@ -0,0 +1,124 @@
from typing import List

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME
from listenbrainz_spark.stats import run_query
from listenbrainz_spark.stats.incremental.user.entity import UserEntity


class ArtistUserEntity(UserEntity):
""" See base class IncrementalStats for documentation. """

def __init__(self, stats_range, database, message_type, from_date=None, to_date=None):
super().__init__(entity="artists", stats_range=stats_range, database=database, message_type=message_type,
from_date=from_date, to_date=to_date)

def get_cache_tables(self) -> List[str]:
return [ARTIST_COUNTRY_CODE_DATAFRAME]

def get_partial_aggregate_schema(self):
return StructType([
StructField("user_id", IntegerType(), nullable=False),
StructField("artist_name", StringType(), nullable=False),
StructField("artist_mbid", StringType(), nullable=True),
StructField("listen_count", IntegerType(), nullable=False),
])

def aggregate(self, table, cache_tables):
cache_table = cache_tables[0]
result = run_query(f"""
WITH exploded_listens AS (
SELECT user_id
, artist_name AS artist_credit_name
, explode_outer(artist_credit_mbids) AS artist_mbid
FROM {table}
), listens_with_mb_data as (
SELECT user_id
, COALESCE(at.artist_name, el.artist_credit_name) AS artist_name
, el.artist_mbid
FROM exploded_listens el
LEFT JOIN {cache_table} at
ON el.artist_mbid = at.artist_mbid
)
SELECT user_id
-- we group by lower(artist_name) and pick the first artist name for cases where
-- the artist name differs in case. for mapped listens the artist name from MB will
-- be used. for unmapped listens we can't know which case is correct so use any. note
-- that due to presence of artist mbid as the third group, mapped and unmapped listens
-- will always be separately grouped therefore first will work fine for unmapped
-- listens and doesn't matter for mapped ones.
, first(artist_name) AS artist_name
, artist_mbid
, count(*) AS listen_count
FROM listens_with_mb_data
GROUP BY user_id
, lower(artist_name)
, artist_mbid
""")
return result

def combine_aggregates(self, existing_aggregate, incremental_aggregate):
query = f"""
WITH intermediate_table AS (
SELECT user_id
, artist_name
, artist_mbid
, listen_count
FROM {existing_aggregate}
UNION ALL
SELECT user_id
, artist_name
, artist_mbid
, listen_count
FROM {incremental_aggregate}
)
SELECT user_id
, first(artist_name) AS artist_name
, artist_mbid
, sum(listen_count) as listen_count
FROM intermediate_table
GROUP BY user_id
, lower(artist_name)
, artist_mbid
"""
return run_query(query)

def get_top_n(self, final_aggregate, N):
query = f"""
WITH entity_count AS (
SELECT user_id
, count(*) AS artists_count
FROM {final_aggregate}
GROUP BY user_id
), ranked_stats AS (
SELECT user_id
, artist_name
, artist_mbid
, listen_count
, row_number() OVER (PARTITION BY user_id ORDER BY listen_count DESC) AS rank
FROM {final_aggregate}
), grouped_stats AS (
SELECT user_id
, sort_array(
collect_list(
struct(
listen_count
, artist_name
, artist_mbid
)
)
, false
) as artists
FROM ranked_stats
WHERE rank <= {N}
GROUP BY user_id
)
SELECT user_id
, artists_count
, artists
FROM grouped_stats
JOIN entity_count
USING (user_id)
"""
return run_query(query)
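To illustrate, not in the PR: a Spark-free sketch of what combine_aggregates does when the existing and incremental aggregates are merged, grouping on (user_id, lower(artist_name), artist_mbid), keeping the first artist name and summing listen counts. The rows are made-up sample data and the MBID is a placeholder.

existing = [
    {"user_id": 1, "artist_name": "Radiohead", "artist_mbid": "mbid-placeholder", "listen_count": 10},
]
incremental = [
    {"user_id": 1, "artist_name": "radiohead", "artist_mbid": "mbid-placeholder", "listen_count": 3},
    {"user_id": 1, "artist_name": "Unknown Artist", "artist_mbid": None, "listen_count": 2},
]

combined = {}
for row in existing + incremental:
    key = (row["user_id"], row["artist_name"].lower(), row["artist_mbid"])
    if key not in combined:
        combined[key] = {"artist_name": row["artist_name"], "listen_count": 0}  # first(artist_name)
    combined[key]["listen_count"] += row["listen_count"]                        # sum(listen_count)

for (user_id, _, artist_mbid), value in combined.items():
    print(user_id, value["artist_name"], artist_mbid, value["listen_count"])
# 1 Radiohead mbid-placeholder 13
# 1 Unknown Artist None 2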
117 changes: 117 additions & 0 deletions listenbrainz_spark/stats/incremental/user/daily_activity.py
@@ -0,0 +1,117 @@
import calendar
import itertools
import json
import logging
from typing import List

from pydantic import ValidationError
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

import listenbrainz_spark
from data.model.common_stat_spark import UserStatRecords
from data.model.user_daily_activity import DailyActivityRecord
from listenbrainz_spark.stats import run_query
from listenbrainz_spark.stats.incremental.user.entity import UserEntity

logger = logging.getLogger(__name__)


class DailyActivityUserEntity(UserEntity):
""" See base class IncrementalStats for documentation. """

def __init__(self, stats_range, database, message_type, from_date=None, to_date=None):
super().__init__(
entity="daily_activity", stats_range=stats_range, database=database,
message_type=message_type, from_date=from_date, to_date=to_date
)
self.setup_time_range()

def setup_time_range(self):
""" Genarate a dataframe containing hours of all days of the week. """
weekdays = [calendar.day_name[day] for day in range(0, 7)]
hours = [hour for hour in range(0, 24)]
time_range = itertools.product(weekdays, hours)
time_range_df = listenbrainz_spark.session.createDataFrame(time_range, schema=["day", "hour"])
time_range_df.createOrReplaceTempView("time_range")

def get_cache_tables(self) -> List[str]:
return []

def get_partial_aggregate_schema(self):
return StructType([
StructField("user_id", IntegerType(), nullable=False),
StructField("day", StringType(), nullable=False),
StructField("hour", IntegerType(), nullable=False),
StructField("listen_count", IntegerType(), nullable=False),
])

def aggregate(self, table, cache_tables):
result = run_query(f"""
SELECT user_id
, date_format(listened_at, 'EEEE') as day
, date_format(listened_at, 'H') as hour
, count(listened_at) AS listen_count
FROM {table}
GROUP BY user_id
, day
, hour
""")
return result

def combine_aggregates(self, existing_aggregate, incremental_aggregate):
query = f"""
WITH intermediate_table AS (
SELECT user_id
, day
, hour
, listen_count
FROM {existing_aggregate}
UNION ALL
SELECT user_id
, day
, hour
, listen_count
FROM {incremental_aggregate}
)
SELECT user_id
, day
, hour
, sum(listen_count) as listen_count
FROM intermediate_table
GROUP BY user_id
, day
, hour
"""
return run_query(query)

def get_top_n(self, final_aggregate, N):
query = f"""
SELECT user_id
, sort_array(
collect_list(
struct(
day
, hour
, COALESCE(listen_count, 0) AS listen_count
)
)
) AS daily_activity
FROM time_range
LEFT JOIN {final_aggregate}
USING (day, hour)
GROUP BY user_id
"""
return run_query(query)

def parse_one_user_stats(self, entry: dict):
try:
UserStatRecords[DailyActivityRecord](
user_id=entry["user_id"],
data=entry["daily_activity"]
)
return {
"user_id": entry["user_id"],
"data": entry["daily_activity"]
}
except ValidationError:
logger.error("Invalid entry in entity stats:", exc_info=True)
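Outside the PR: a small Spark-free illustration of the 7 x 24 time_range grid that setup_time_range registers, and of the zero filling the COALESCE in get_top_n is there for. The per-user join semantics of the real query are simplified here and the sample aggregate is made up.

import calendar
import itertools

# Same 7 x 24 grid that DailyActivityUserEntity.setup_time_range registers as "time_range".
time_range = list(itertools.product([calendar.day_name[day] for day in range(7)], range(24)))

# Pretend aggregate for one user: only two buckets actually have listens.
aggregate = {("Monday", 9): 4, ("Friday", 22): 7}

daily_activity = [
    {"day": day, "hour": hour, "listen_count": aggregate.get((day, hour), 0)}  # COALESCE(listen_count, 0)
    for day, hour in time_range
]
print(len(daily_activity))  # 168 buckets, one per weekday and hour
print([bucket for bucket in daily_activity if bucket["listen_count"] > 0])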