Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add static typing with mypy #130

Merged
merged 4 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ jobs:
- name: Check lock file
run: poetry lock --check

- name: Run pre-commit hooks
run: poetry run pre-commit run -a
- name: Run code quality checks
run: poetry run make check

test:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ install: ## Install the Poetry environment
check: ## Run code quality checks
@echo "Running pre-commit hooks"
@poetry run pre-commit run -a
@poetry run mypy chispa

.PHONY: test
test: ## Run unit tests
Expand Down
29 changes: 15 additions & 14 deletions chispa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import os
import sys
from glob import glob
from typing import Callable

from pyspark.sql import DataFrame

# Add PySpark to the library path based on the value of SPARK_HOME if pyspark is not already in our path
try:
from pyspark import context # noqa: F401
except ImportError:
# We need to add PySpark, try use findspark, or failback to the "manually" find it
try:
import findspark
import findspark # type: ignore[import-untyped]

findspark.init()
except ImportError:
Expand Down Expand Up @@ -46,28 +49,26 @@


class Chispa:
def __init__(self, formats: FormattingConfig | None = None, default_output=None):
def __init__(self, formats: FormattingConfig | None = None) -> None:
SemyonSinchenko marked this conversation as resolved.
Show resolved Hide resolved
if not formats:
self.formats = FormattingConfig()
elif isinstance(formats, FormattingConfig):
self.formats = formats
else:
self.formats = FormattingConfig._from_arbitrary_dataclass(formats)

self.default_outputs = default_output

def assert_df_equality(
self,
df1,
df2,
ignore_nullable=False,
transforms=None,
allow_nan_equality=False,
ignore_column_order=False,
ignore_row_order=False,
underline_cells=False,
ignore_metadata=False,
):
df1: DataFrame,
df2: DataFrame,
ignore_nullable: bool = False,
transforms: list[Callable] | None = None, # type: ignore[type-arg]
allow_nan_equality: bool = False,
ignore_column_order: bool = False,
ignore_row_order: bool = False,
underline_cells: bool = False,
ignore_metadata: bool = False,
) -> None:
return assert_df_equality(
df1,
df2,
Expand Down
2 changes: 1 addition & 1 deletion chispa/bcolors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class bcolors:
Bold = "\033[1m"
Underline = "\033[4m"

def __init__(self):
def __init__(self) -> None:
warnings.warn("The `bcolors` class is deprecated and will be removed in a future version.", DeprecationWarning)


Expand Down
25 changes: 13 additions & 12 deletions chispa/column_comparer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from prettytable import PrettyTable
from pyspark.sql import DataFrame

from chispa.formatting import blue

Expand All @@ -11,12 +12,12 @@ class ColumnsNotEqualError(Exception):
pass


def assert_column_equality(df, col_name1, col_name2):
elements = df.select(col_name1, col_name2).collect()
colName1Elements = list(map(lambda x: x[0], elements))
colName2Elements = list(map(lambda x: x[1], elements))
if colName1Elements != colName2Elements:
zipped = list(zip(colName1Elements, colName2Elements))
def assert_column_equality(df: DataFrame, col_name1: str, col_name2: str) -> None:
rows = df.select(col_name1, col_name2).collect()
col_name_1_elements = [x[0] for x in rows]
col_name_2_elements = [x[1] for x in rows]
if col_name_1_elements != col_name_2_elements:
zipped = list(zip(col_name_1_elements, col_name_2_elements))
t = PrettyTable([col_name1, col_name2])
for elements in zipped:
if elements[0] == elements[1]:
Expand All @@ -26,18 +27,18 @@ def assert_column_equality(df, col_name1, col_name2):
raise ColumnsNotEqualError("\n" + t.get_string())


def assert_approx_column_equality(df, col_name1, col_name2, precision):
elements = df.select(col_name1, col_name2).collect()
colName1Elements = list(map(lambda x: x[0], elements))
colName2Elements = list(map(lambda x: x[1], elements))
def assert_approx_column_equality(df: DataFrame, col_name1: str, col_name2: str, precision: float) -> None:
rows = df.select(col_name1, col_name2).collect()
col_name_1_elements = [x[0] for x in rows]
col_name_2_elements = [x[1] for x in rows]
all_rows_equal = True
zipped = list(zip(colName1Elements, colName2Elements))
zipped = list(zip(col_name_1_elements, col_name_2_elements))
t = PrettyTable([col_name1, col_name2])
for elements in zipped:
first = blue(str(elements[0]))
second = blue(str(elements[1]))
# when one is None and the other isn't, they're not equal
if (elements[0] is None and elements[1] is not None) or (elements[0] is not None and elements[1] is None):
if (elements[0] is None) != (elements[1] is None):
all_rows_equal = False
t.add_row([str(elements[0]), str(elements[1])])
# when both are None, they're equal
Expand Down
55 changes: 30 additions & 25 deletions chispa/dataframe_comparer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from functools import reduce
from typing import Callable

from pyspark.sql import DataFrame

from chispa.formatting import FormattingConfig
from chispa.row_comparer import are_rows_approx_equal, are_rows_equal_enhanced
Expand All @@ -18,17 +21,17 @@ class DataFramesNotEqualError(Exception):


def assert_df_equality(
df1,
df2,
ignore_nullable=False,
transforms=None,
allow_nan_equality=False,
ignore_column_order=False,
ignore_row_order=False,
underline_cells=False,
ignore_metadata=False,
df1: DataFrame,
df2: DataFrame,
ignore_nullable: bool = False,
transforms: list[Callable] | None = None, # type: ignore[type-arg]
allow_nan_equality: bool = False,
ignore_column_order: bool = False,
ignore_row_order: bool = False,
underline_cells: bool = False,
ignore_metadata: bool = False,
formats: FormattingConfig | None = None,
):
) -> None:
if not formats:
formats = FormattingConfig()
elif not isinstance(formats, FormattingConfig):
Expand All @@ -48,7 +51,7 @@ def assert_df_equality(
df1.collect(),
df2.collect(),
are_rows_equal_enhanced,
[True],
{"allow_nan_equality": True},
underline_cells=underline_cells,
formats=formats,
)
Expand All @@ -61,7 +64,7 @@ def assert_df_equality(
)


def are_dfs_equal(df1, df2):
def are_dfs_equal(df1: DataFrame, df2: DataFrame) -> bool:
if df1.schema != df2.schema:
return False
if df1.collect() != df2.collect():
Expand All @@ -70,16 +73,16 @@ def are_dfs_equal(df1, df2):


def assert_approx_df_equality(
df1,
df2,
precision,
ignore_nullable=False,
transforms=None,
allow_nan_equality=False,
ignore_column_order=False,
ignore_row_order=False,
df1: DataFrame,
df2: DataFrame,
precision: float,
ignore_nullable: bool = False,
transforms: list[Callable] | None = None, # type: ignore[type-arg]
allow_nan_equality: bool = False,
ignore_column_order: bool = False,
ignore_row_order: bool = False,
formats: FormattingConfig | None = None,
):
) -> None:
if not formats:
formats = FormattingConfig()
elif not isinstance(formats, FormattingConfig):
Expand All @@ -99,10 +102,12 @@ def assert_approx_df_equality(
df1.collect(),
df2.collect(),
are_rows_approx_equal,
[precision, allow_nan_equality],
formats,
{"precision": precision, "allow_nan_equality": allow_nan_equality},
formats=formats,
)
elif allow_nan_equality:
assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], formats)
assert_generic_rows_equality(
df1.collect(), df2.collect(), are_rows_equal_enhanced, {"allow_nan_equality": True}, formats=formats
)
else:
assert_basic_rows_equality(df1.collect(), df2.collect(), formats)
assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats)
2 changes: 1 addition & 1 deletion chispa/default_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class DefaultFormats:
mismatched_cells: list[str] = field(default_factory=lambda: ["red", "underline"])
matched_cells: list[str] = field(default_factory=lambda: ["blue"])

def __post_init__(self):
def __post_init__(self) -> None:
warnings.warn(
"DefaultFormats is deprecated. Use `chispa.formatting.FormattingConfig` instead.", DeprecationWarning
)
2 changes: 1 addition & 1 deletion chispa/formatting/format_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ def format_string(input_string: str, format: Format) -> str:
return formatted_string


def blue(string: str):
def blue(string: str) -> str:
return Color.LIGHT_BLUE + string + Color.LIGHT_RED
9 changes: 6 additions & 3 deletions chispa/formatting/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Format:
style: list[Style] | None = None

@classmethod
def from_dict(cls, format_dict: dict) -> Format:
def from_dict(cls, format_dict: dict[str, str | list[str]]) -> Format:
"""
Create a Format instance from a dictionary.

Expand All @@ -72,7 +72,10 @@ def from_dict(cls, format_dict: dict) -> Format:
if invalid_keys:
raise ValueError(f"Invalid keys in format dictionary: {invalid_keys}. Valid keys are {valid_keys}")

color = cls._get_color_enum(format_dict.get("color"))
if isinstance(format_dict.get("color"), list):
raise TypeError("The value for key 'color' should be a string, not a list!")
color = cls._get_color_enum(format_dict.get("color")) # type: ignore[arg-type]

style = format_dict.get("style")
if isinstance(style, str):
styles = [cls._get_style_enum(style)]
Expand All @@ -81,7 +84,7 @@ def from_dict(cls, format_dict: dict) -> Format:
else:
styles = None

return cls(color=color, style=styles)
return cls(color=color, style=styles) # type: ignore[arg-type]

@classmethod
def from_list(cls, values: list[str]) -> Format:
Expand Down
10 changes: 5 additions & 5 deletions chispa/formatting/formatting_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class FormattingConfig:

def __init__(
self,
mismatched_rows: Format | dict = Format(Color.RED),
matched_rows: Format | dict = Format(Color.BLUE),
mismatched_cells: Format | dict = Format(Color.RED, [Style.UNDERLINE]),
matched_cells: Format | dict = Format(Color.BLUE),
mismatched_rows: Format | dict[str, str | list[str]] = Format(Color.RED),
matched_rows: Format | dict[str, str | list[str]] = Format(Color.BLUE),
mismatched_cells: Format | dict[str, str | list[str]] = Format(Color.RED, [Style.UNDERLINE]),
matched_cells: Format | dict[str, str | list[str]] = Format(Color.BLUE),
):
"""
Initializes the FormattingConfig with given or default formatting.
Expand All @@ -46,7 +46,7 @@ def __init__(
self.mismatched_cells: Format = self._parse_format(mismatched_cells)
self.matched_cells: Format = self._parse_format(matched_cells)

def _parse_format(self, format: Format | dict) -> Format:
def _parse_format(self, format: Format | dict[str, str | list[str]]) -> Format:
if isinstance(format, Format):
return format
elif isinstance(format, dict):
Expand Down
8 changes: 5 additions & 3 deletions chispa/number_helpers.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

import math
from decimal import Decimal
from typing import Any


def isnan(x):
def isnan(x: Any) -> bool:
try:
return math.isnan(x)
except TypeError:
return False


def nan_safe_equality(x, y) -> bool:
def nan_safe_equality(x: int | float, y: int | float | Decimal) -> bool:
return (x == y) or (isnan(x) and isnan(y))


def nan_safe_approx_equality(x, y, precision) -> bool:
def nan_safe_approx_equality(x: int | float, y: int | float, precision: float | Decimal) -> bool:
return (abs(x - y) <= precision) or (isnan(x) and isnan(y))
8 changes: 4 additions & 4 deletions chispa/row_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def are_rows_equal(r1: Row, r2: Row) -> bool:
return r1 == r2


def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool:
def are_rows_equal_enhanced(r1: Row | None, r2: Row | None, allow_nan_equality: bool) -> bool:
if r1 is None and r2 is None:
return True
if (r1 is None and r2 is not None) or (r2 is None and r1 is not None):
if r1 is None or r2 is None:
return False
d1 = r1.asDict()
d2 = r2.asDict()
Expand All @@ -27,10 +27,10 @@ def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool:
return r1 == r2


def are_rows_approx_equal(r1: Row, r2: Row, precision: float, allow_nan_equality=False) -> bool:
def are_rows_approx_equal(r1: Row | None, r2: Row | None, precision: float, allow_nan_equality: bool = False) -> bool:
if r1 is None and r2 is None:
return True
if (r1 is None and r2 is not None) or (r2 is None and r1 is not None):
if r1 is None or r2 is None:
return False
d1 = r1.asDict()
d2 = r2.asDict()
Expand Down
Loading
Loading