Skip to content

Commit

Permalink
Merge pull request #261 from moj-analytical-services/sql_parse
Browse files Browse the repository at this point in the history
Parse SQL case expressions
  • Loading branch information
RobinL authored Jan 11, 2022
2 parents d13a84d + 71b3084 commit 229d756
Show file tree
Hide file tree
Showing 22 changed files with 691 additions and 381 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.1.0]

### Added

- `sql_expr` now added to tooltips on bayes factor chart, displaying the SQL expression for each comparison level
- Warnings to the user if they don't include a null level in their case expression, or if the custom columns they specify differ from the columns used in the case expression
- `splink` now parses `case_expression` to auto-populate `num_levels` or `col_name` or `custom_columns_used`. The user may still provide this information, but is no longer required to.

Note that Splink now has a dependency on `sqlglot`, a no-dependency SQL parser.

## [2.0.4]

### Added
Expand Down
4 changes: 2 additions & 2 deletions Dockerfile_testrunner
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
FROM mamonu/moj-spark-jovyan:baseenv

RUN pip install pytest pytest-cov poetry coveralls typeguard
RUN pip install --no-dependencies splink-data-generation==1.0.0
RUN pip install pytest pytest-cov poetry coveralls typeguard sqlglot
RUN pip install --no-dependencies splink-data-generation==1.0.1

ADD . /myfiles
WORKDIR /myfiles
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "splink"
version = "2.0.4"
version = "2.1.0"
description = "Implementation of Fellegi-Sunter's canonical model of record linkage in Apache Spark, including EM algorithm to estimate parameters"
authors = ["Robin Linacre <[email protected]>", "Sam Lindsay", "Theodore Manassis"]
license = "MIT"
Expand All @@ -12,11 +12,12 @@ readme = "README.md"
python = "^3.6"
jsonschema = "^3.2"
typeguard = "^2.10.0"
sqlglot = "^1.17.1"

[tool.poetry.dev-dependencies]
pytest = "^5.3"
pandas = "^1.0.0"
splink-data-generation = "^0.2.1"
splink-data-generation = "^1.0.1"

[build-system]
requires = ["poetry>=0.12"]
Expand Down
2 changes: 1 addition & 1 deletion splink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
):
"""Splink data linker
Provides easy access to the core user-facing functinoality of splink
Provides easy access to the core user-facing functionality of splink
Args:
settings (dict): splink settings dictionary
Expand Down
98 changes: 93 additions & 5 deletions splink/default_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings

from splink.settings import ComparisonColumn
from pyspark.sql.session import SparkSession

from copy import deepcopy
Expand All @@ -20,6 +20,12 @@
_add_as_gamma_to_case_statement,
)

from .parse_case_statement import (
parse_case_statement,
generate_sql_from_parsed_case_expr,
get_columns_used_from_sql_without_l_r_suffix,
)


def _normalise_prob_list(prob_array: list):
sum_list = sum(prob_array)
Expand Down Expand Up @@ -101,8 +107,12 @@ def _get_default_probabilities(m_or_u, levels):

def _complete_case_expression(col_settings, spark):

cc = ComparisonColumn(col_settings)
if cc.has_case_expression_or_comparison_levels:
return col_settings

default_case_statements = _get_default_case_statements_functions(spark)
levels = col_settings["num_levels"]
levels = cc.num_levels

if "custom_name" in col_settings:
col_name_for_case_fn = col_settings["custom_name"]
Expand Down Expand Up @@ -136,7 +146,8 @@ def _complete_probabilities(col_settings: dict, mu_probabilities: str):
"""

if mu_probabilities not in col_settings:
levels = col_settings["num_levels"]
cc = ComparisonColumn(col_settings)
levels = cc.num_levels
probs = _get_default_probabilities(mu_probabilities, levels)
col_settings[mu_probabilities] = probs

Expand All @@ -149,11 +160,72 @@ def _complete_tf_adjustment_weights(col_settings: dict):
f"All values of 'tf_adjustment_weights' must be between 0 and 1"
)
else:
weights = [0.0] * col_settings["num_levels"]
cc = ComparisonColumn(col_settings)

weights = [0.0] * cc.num_levels
weights[-1] = 1.0
col_settings["tf_adjustment_weights"] = weights


def _complete_comparison_levels(col_settings):
    """Ensure both 'comparison_levels' and 'case_expression' are populated.

    Whichever of the two keys is missing is derived from the other:
    'comparison_levels' by parsing the case expression, 'case_expression' by
    regenerating SQL from the parsed levels. The dictionary is modified in
    place and nothing is returned.

    A warning is emitted when the resulting comparison levels contain no
    "-1" key, since -1 is the level usually used for null values.

    Args:
        col_settings (dict): settings dictionary for one comparison column.
            At least one of 'case_expression' / 'comparison_levels' must
            already be present, otherwise a KeyError is raised.
    """
    if "comparison_levels" not in col_settings:
        col_settings["comparison_levels"] = parse_case_statement(
            col_settings["case_expression"]
        )

    if "case_expression" not in col_settings:
        col_settings["case_expression"] = generate_sql_from_parsed_case_expr(
            col_settings["comparison_levels"]
        )

    # Imported at call time, exactly as the original block does.
    from splink.settings import ComparisonColumn

    level_keys = ComparisonColumn(col_settings).comparison_levels_dict.keys()
    if "-1" in level_keys:
        return

    warnings.warn(
        "No -1 level found in case statement."
        " You usually want to use -1 as the level for the null value."
        " e.g. WHEN col_l is null or col_r is null then -1"
        f" Case statement is:\n {col_settings['case_expression']}."
    )


def _complete_col_name(col_settings):

if "custom_name" in col_settings:
return

if "col_name" in col_settings:
return

sql = generate_sql_from_parsed_case_expr(col_settings["comparison_levels"])
sql_cols = get_columns_used_from_sql_without_l_r_suffix(sql)
if len(sql_cols) == 1:
col_settings["col_name"] = sql_cols[0]
else:
col_settings["custom_name"] = "_".join(sql_cols)
return col_settings


def _complete_custom_columns(col_settings):

if "col_name" in col_settings:
return col_settings

if "custom_name" in col_settings:
sql = generate_sql_from_parsed_case_expr(col_settings["comparison_levels"])
sql_cols = get_columns_used_from_sql_without_l_r_suffix(sql)
if "columns_used" in col_settings:
if set(sql_cols) != set(col_settings["columns_used"]):
warnings.warn(
f"The columns used in the case statement are {sql_cols} but the columns "
f"specified in the settings dictionary are {col_settings['columns_used']}"
)
else:
col_settings["custom_columns_used"] = sql_cols
return col_settings


def complete_settings_dict(settings_dict: dict, spark: SparkSession):
"""Auto-populate any missing settings from the settings dictionary using the 'sensible defaults' that
are specified in the json schema (./splink/files/settings_jsonschema.json)
Expand Down Expand Up @@ -203,7 +275,6 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession):

# Populate non-existing keys from defaults
keys_for_defaults = [
"num_levels",
"data_type",
"term_frequency_adjustments",
"fix_u_probabilities",
Expand All @@ -215,10 +286,27 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession):
default = get_default_value_from_schema(key, is_column_setting=True)
col_settings[key] = default

# Populate default value for num levels only if case_expression or comparison_levels is not specified
skip_if_present = set(["case_expression", "comparison_levels", "num_levels"])
keys = set(col_settings.keys())
intersect = keys.intersection(skip_if_present)
if len(intersect) == 0:
default = get_default_value_from_schema(
"num_levels", is_column_setting=True
)
col_settings["num_levels"] = default

# No assignment needed because these functions modify the col_settings dictionary in place

_complete_case_expression(col_settings, spark)

_complete_comparison_levels(col_settings)
_complete_col_name(col_settings)
_complete_custom_columns(col_settings)

_complete_probabilities(col_settings, "m_probabilities")
_complete_probabilities(col_settings, "u_probabilities")

_complete_tf_adjustment_weights(col_settings)

return settings_dict
Expand Down
6 changes: 1 addition & 5 deletions splink/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
from typeguard import typechecked

from .charts import load_chart_definition, altair_if_installed_else_json
from .settings import complete_settings_dict, Settings
from .vertically_concat import vertically_concatenate_datasets
from .blocking import block_using_rules
from .gammas import add_gammas
from .estimate import _num_target_rows_to_rows_to_sample
from .settings import Settings


def _equal_spaced_buckets(num_buckets, extent):
Expand Down
2 changes: 1 addition & 1 deletion splink/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .gammas import add_gammas
from .maximisation_step import run_maximisation_step
from .model import Model
from .settings import complete_settings_dict
from .default_settings import complete_settings_dict
from .vertically_concat import vertically_concatenate_datasets

import warnings
Expand Down
Loading

0 comments on commit 229d756

Please sign in to comment.