Skip to content

Commit

Permalink
An even nicer summary dataframe and markdown summary
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorKraevTransferwise committed Jan 10, 2025
1 parent 390dc97 commit 62ddfe8
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 24 deletions.
82 changes: 82 additions & 0 deletions wise_pizza/dataframe_with_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import logging

import pandas as pd

logger = logging.getLogger(__name__)


class DataFrameWithMetadata(pd.DataFrame):
    """A pandas DataFrame that carries a table name, a table description,
    and per-column descriptions in ``.attrs``, and renders them as part of
    ``to_markdown()`` output (useful for feeding tables to LLMs/reports).
    """

    def __init__(
        self,
        *args,
        name: str = None,
        description: str = None,
        column_descriptions: dict = None,
        **kwargs,
    ):
        """
        @param name: Optional human-readable table name.
        @param description: Optional table description.
        @param column_descriptions: Optional {column_name: description} mapping;
            entries whose key is not an actual column are silently dropped
            (with a warning if nothing matched).
        """
        super().__init__(*args, **kwargs)

        self.attrs["name"] = name or ""  # Store DataFrame name
        self.attrs["description"] = description or ""  # Store DataFrame description
        self.attrs["column_descriptions"] = {}

        if column_descriptions:
            # Keep only descriptions that refer to real columns
            matched = {
                k: v for k, v in column_descriptions.items() if k in self.columns
            }
            if matched:
                self.attrs["column_descriptions"] = matched
            else:
                logger.warning(
                    "None of the column descriptions provided matched the DataFrame columns"
                )

    def to_markdown(self, index: bool = True, **kwargs):
        """Render the table as markdown, prefixed with name/description and,
        when column descriptions exist, an extra row describing each column.

        Bug fix vs. the original: the description row is now inserted AFTER the
        markdown separator line (header | separator | descriptions | data...)
        and carries both leading and trailing pipes, so the table remains valid
        markdown; previously the row was wedged between the header and the
        separator and was missing its trailing pipe.
        """
        output = []
        # .get() rather than [] so frames whose attrs were not populated by
        # __init__ (e.g. results of pandas ops) don't raise KeyError.
        if self.attrs.get("name"):
            output.append(f"Table name: {self.attrs['name']}\n")

        if self.attrs.get("description"):
            output.append(f"Table description: {self.attrs['description']}\n")

        table_md = super().to_markdown(index=index, **kwargs)
        col_desc = self.attrs.get("column_descriptions")
        if not col_desc:
            output.append(table_md)
            return "\n".join(output)

        # One cell per rendered column; an empty leading cell covers the index.
        cells = ([""] if index else []) + [
            col_desc.get(col, "") for col in self.columns
        ]
        desc_row = "| " + " | ".join(cells) + " |"
        lines = table_md.split("\n")
        # to_markdown always emits header then separator, so index 2 is safe.
        lines.insert(2, desc_row)
        output.append("\n".join(lines))
        return "\n".join(output)

    def head(self, n: int = 5):
        """Like DataFrame.head, but preserves the metadata in .attrs."""
        out = DataFrameWithMetadata(super().head(n))
        # Copy rather than share, so mutating the head's attrs can't
        # silently change the parent frame's metadata.
        out.attrs = dict(self.attrs)
        return out


if __name__ == "__main__":
    # Demo: build a frame with metadata and render it two ways.
    demo = DataFrameWithMetadata(
        {"a": [1, 2], "b": [3, 4]},
        name="DataFrame Name",
        description="Description for the DataFrame",
        column_descriptions={
            "a": "Description for column a",
            "b": "Description for column b",
        },
    )

    print(demo.to_markdown())
    print(demo.to_markdown(index=False))
    print("yay!")
    # This would raise an error:
    # df = DescribedDataFrame({'a': [1]}, descriptions={'nonexistent': 'Description'})
18 changes: 18 additions & 0 deletions wise_pizza/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def explain_changes_in_average(
@param verbose: If set to a truish value, lots of debug info is printed to console
@return: A fitted object
"""

df1 = df1.copy()
df2 = df2.copy()

Expand Down Expand Up @@ -112,6 +113,9 @@ def explain_changes_in_average(
verbose=verbose,
)

if hasattr(df1, "attrs"):
sf.data_attrs = df1.attrs

if hasattr(sf, "pre_total"):
sf.pre_total = avg1
sf.post_total += avg1
Expand Down Expand Up @@ -216,6 +220,8 @@ def explain_changes_in_totals(
cluster_values=cluster_values,
verbose=verbose,
)
if hasattr(df1, "attrs"):
sf_size.data_attrs = df1.attrs

sf_avg = explain_levels(
df=df_avg.data,
Expand All @@ -232,6 +238,9 @@ def explain_changes_in_totals(
verbose=verbose,
)

if hasattr(df1, "attrs"):
sf_avg.data_attrs = df1.attrs

sf_size.final_size = final_size
sf_avg.final_size = final_size
sp = SlicerPair(sf_size, sf_avg)
Expand Down Expand Up @@ -282,6 +291,8 @@ def explain_changes_in_totals(
sf.size_name = size_name
sf.total_name = total_name
sf.average_name = average_name
if hasattr(df1, "attrs"):
sf.data_attrs = df1.attrs
return sf


Expand Down Expand Up @@ -352,6 +363,9 @@ def explain_levels(
cluster_values=cluster_values,
)

if hasattr(df, "attrs"):
sf.data_attrs = df.attrs

for s in sf.segments:
s["naive_avg"] += average
s["total"] += average * s["seg_size"]
Expand Down Expand Up @@ -414,6 +428,7 @@ def explain_timeseries(
assert (
solver == "tree"
), "Only the tree solver is supported for time series at the moment"
attrs = getattr(df, "attrs", None)
df = copy.copy(df)

# replace NaN values in numeric columns with zeros
Expand Down Expand Up @@ -531,6 +546,9 @@ def explain_timeseries(
n_jobs=n_jobs,
)

if hasattr(df, "attrs"):
sf.data_attrs = attrs

# TODO: insert back the normalized bits?
for s in sf.segments:
segment_def = s["segment"]
Expand Down
84 changes: 60 additions & 24 deletions wise_pizza/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import pandas as pd
from scipy.sparse import csc_matrix, diags

from wise_pizza.dataframe_with_metadata import DataFrameWithMetadata
from wise_pizza.solve.find_alpha import find_alpha
from wise_pizza.utils import clean_up_min_max
from wise_pizza.utils import clean_up_min_max, fill_string_na
from wise_pizza.make_matrix import sparse_dummy_matrix
from wise_pizza.cluster import make_clusters
from wise_pizza.preselect import HeuristicSelector
Expand All @@ -27,7 +28,8 @@ def _summary(obj) -> str:
{
k: v
for k, v in s.items()
if k in ["segment", "total", "seg_size", "naive_avg", "impact"]
if k
in ["segment", "total", "seg_size", "naive_avg", "impact", "avg_impact"]
}
for s in obj.segments
],
Expand Down Expand Up @@ -538,8 +540,13 @@ def predict(

@property
def nice_summary(self):

return nice_summary(
self.summary(), self.total_name, self.size_name, self.average_name
self.summary(),
self.total_name,
self.size_name,
self.average_name,
self.data_attrs if hasattr(self, "data_attrs") else None,
)

@property
Expand Down Expand Up @@ -597,37 +604,66 @@ def nice_summary(
total_name: str,
size_name: Optional[str] = None,
average_name: Optional[str] = None,
):
attrs: Optional[Dict[str, str]] = None,
) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
x = json.loads(x)
for xx in x["segments"]:
xx.update(xx["segment"])

df = pd.DataFrame(x["segments"]).rename(
columns={"seg_size": size_name, "total": total_name}
)
df["segment"] = df["segment"].apply(
lambda x: str(x).replace("'", "").replace("{", "").replace("}", "")
df = pd.DataFrame(x["segments"])

# These columns are pretty much self-explanatory, don't need descriptions

df = fill_string_na(df, "All")
df = df[[c for c in df.columns if c != "segment"] + ["segment"]]

if not average_name:
average_name = "average " + total_name.replace("total", "").replace(
"Total", ""
).replace("TOTAL", "").replace(" ", " ")

df.rename(
columns={
"seg_size": size_name + " of segment",
"total": total_name + " in segment",
"impact": "Segment impact on overall total",
"avg_impact": f"Segment impact on overall {average_name}",
"naive_avg": average_name + " over segment",
},
inplace=True,
)
if average_name is not None:
df = df.rename(columns={"naive_avg": average_name})

# TODO: more flexible formatting
for col in df.columns:
if col != "segment":
df[col] = df[col].astype(int)
out = {"summary": df, "clusters": x["relevant_clusters"]}
return out

if attrs and "column_descriptions" in attrs:
column_desc = {
k: v for k, v in attrs["column_descriptions"].items() if k in df.columns
}

df = DataFrameWithMetadata(df, column_descriptions=column_desc)

def markdown_summary(x: dict):
table = x["summary"].to_markdown(index=False)
out = df

out = f"""Key segment summary:
{table}"""
# TODO: cast cluster definitions to dataframe too
if "relevant_clusters" in x and x["relevant_clusters"]:
out = {"summary": df, "clusters": x["relevant_clusters"]}

if clusters := x["clusters"]:
out += f"\n\nDefinitions of clusters: {clusters}"
return out


def markdown_summary(x: Union[dict, pd.DataFrame]):
    """Render a summary as markdown text.

    Accepts either a DataFrame (rendered directly) or a dict with a
    "summary" DataFrame and optional "clusters" definitions, which are
    appended below the table.
    """
    if isinstance(x, pd.DataFrame):
        return x.to_markdown(index=False)

    if not isinstance(x, dict):
        raise ValueError("Invalid input, expected either a pd.DataFrame or a dict")

    out = x["summary"].to_markdown(index=False)
    clusters = x.get("clusters")
    if clusters:
        out += "\n\nDefinitions of clusters: \n"
        out += "".join(f"\n{k}: {v}" for k, v in clusters.items())
    return out


class SlicerPair:
def __init__(self, s1: SliceFinder, s2: SliceFinder):
self.s1 = s1
Expand Down
39 changes: 39 additions & 0 deletions wise_pizza/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,42 @@ def clean_up_min_max(min_nonzeros: int = None, max_nonzeros: int = None):

assert min_nonzeros <= max_nonzeros
return min_nonzeros, max_nonzeros


def fill_string_na(df, fill_value=""):
    """
    Fill NA values in string-typed columns of a DataFrame with a specified value.

    Parameters:
    -----------
    df : pandas.DataFrame
        The input DataFrame
    fill_value : str, default=''
        The value to use for filling NA values in string columns

    Returns:
    --------
    pandas.DataFrame
        A copy of the input DataFrame with NA values filled in string columns

    Notes:
    ------
    Bug fix: category-dtype columns are selected by this function, but
    ``Series.fillna`` raises on a Categorical when the fill value is not an
    existing category. We now add the fill value to the categories first.
    """
    # Create a copy of the DataFrame to avoid modifying the original
    df_filled = df.copy()

    # Get columns with string (object) or category dtype
    string_columns = df_filled.select_dtypes(include=["object", "category"]).columns

    # Fill NA values only in string columns
    for col in string_columns:
        s = df_filled[col]
        # Only categorical Series expose the .cat accessor; filling a
        # Categorical with a value outside its categories raises, so
        # register the fill value as a category first.
        if hasattr(s, "cat") and fill_value not in s.cat.categories:
            s = s.cat.add_categories([fill_value])
        df_filled[col] = s.fillna(fill_value)

    return df_filled


# Example usage:
# import pandas as pd
# df = pd.DataFrame({
# 'text': ['hello', None, 'world'],
# 'numbers': [1, 2, None],
# 'more_text': [None, 'test', 'data']
# })
# filled_df = fill_string_na(df, fill_value='MISSING')

0 comments on commit 62ddfe8

Please sign in to comment.