Commit 2d12264

All unit tests passing locally

Will Langdale committed Nov 1, 2024
1 parent e2ba431 commit 2d12264

Showing 5 changed files with 92 additions and 56 deletions.
11 changes: 6 additions & 5 deletions src/matchbox/server/postgresql/utils/results.py
@@ -19,6 +19,7 @@
     ModelsFrom,
     Probabilities,
 )
+from matchbox.server.postgresql.utils.query import hash_to_hex_decode


 class SourceInfo(NamedTuple):
@@ -67,9 +68,9 @@ def _get_source_info(engine: Engine, model_hash: bytes) -> SourceInfo:
         left = session.get(Models, left_hash)
         right = session.get(Models, right_hash) if right_hash else None

-        left_ancestors = {left_hash} | {hash for hash in left.ancestors}
+        left_ancestors = {left_hash} | {m.hash for m in left.ancestors}
         if right:
-            right_ancestors = {right_hash} | {hash for hash in right.ancestors}
+            right_ancestors = {right_hash} | {m.hash for m in right.ancestors}
         else:
             right_ancestors = None
@@ -180,7 +181,7 @@ def get_model_probabilities(engine: Engine, model: Models) -> ProbabilityResults
     # First get all clusters this model assigned probabilities to
     model_clusters = (
         select(Probabilities.cluster)
-        .where(Probabilities.model == model.hash)
+        .where(Probabilities.model == hash_to_hex_decode(model.hash))
         .cte("model_clusters")
     )
@@ -199,7 +200,7 @@ def get_model_probabilities(engine: Engine, model: Models) -> ProbabilityResults
             Probabilities,
             and_(
                 Probabilities.cluster == Contains.parent,
-                Probabilities.model == model.hash,
+                Probabilities.model == hash_to_hex_decode(model.hash),
             ),
         )
         .where(~Contains.child.in_(select(model_parents)))
@@ -235,7 +236,7 @@ def get_model_probabilities(engine: Engine, model: Models) -> ProbabilityResults
             Probabilities,
             and_(
                 Probabilities.cluster == Contains.parent,
-                Probabilities.model == model.hash,
+                Probabilities.model == hash_to_hex_decode(model.hash),
             ),
         )
         .group_by(Contains.parent)
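Two fixes are worth unpacking in this file. First, the old comprehension {hash for hash in left.ancestors} built a set of the ORM rows themselves (the loop variable merely shadows the builtin hash), so later membership tests against bytes hashes could never match; the fix collects each ancestor's .hash column. Second, raw bytes comparisons against Probabilities.model are now routed through hash_to_hex_decode. Its body isn't shown in this commit; a minimal sketch of the assumed behaviour, rendering the bytes as a server-side decode('<hex>', 'hex') expression so PostgreSQL compares BYTEA to BYTEA:

    # A sketch of the assumed helper, not the committed implementation:
    from sqlalchemy import func

    def hash_to_hex_decode(model_hash: bytes):
        """Assumption: render bytes as a server-side decode('<hex>', 'hex') call."""
        return func.decode(model_hash.hex(), "hex")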
12 changes: 5 additions & 7 deletions test/client/test_dedupers.py
@@ -1,6 +1,5 @@
 import pytest
 from matchbox import make_model, query
-from matchbox.helpers import selector
 from matchbox.server.models import Source, SourceWarehouse
 from matchbox.server.postgresql import MatchboxPostgres
 from pandas import DataFrame
@@ -41,7 +40,10 @@ def test_dedupers(

     db_add_indexed_data(backend=matchbox_postgres, warehouse_data=warehouse_data)

-    df: DataFrame = request.getfixturevalue(fx_data.fixture)
+    select: dict[Source, list[str]]
+    df: DataFrame
+
+    select, df = request.getfixturevalue(fx_data.fixture)

     fields = list(fx_data.fields.keys())
@@ -115,11 +117,7 @@ def test_dedupers(
     model.truth = 0.0

     clusters = query(
-        selector=selector(
-            table=fx_data.source,
-            fields=list(fx_data.fields.values()),
-            engine=warehouse.engine,
-        ),
+        selector=select,
         backend=matchbox_postgres,
         return_type="pandas",
         model=deduper_name,
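The deduper test no longer rebuilds a selector from fx_data: the fixture now hands back the selector alongside the cleaned data, so the final query uses exactly the dict[Source, list[str]] that produced the input. A minimal sketch of the new contract (fixture name taken from the diff; the comments describe the assumed shapes):

    # Hypothetical consumption of the reworked fixture in a test body:
    select, df = request.getfixturevalue("query_clean_crn")
    # select: dict[Source, list[str]] -- reused verbatim as query(selector=select, ...)
    # df: DataFrame -- the cleaned rows the deduper runs against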
89 changes: 60 additions & 29 deletions test/client/test_linkers.py
@@ -1,6 +1,6 @@
 import pytest
 from matchbox import make_model, query
-from matchbox.helpers import selector, selectors
+from matchbox.helpers import selectors
 from matchbox.server.models import Source, SourceWarehouse
 from matchbox.server.postgresql import MatchboxPostgres
 from pandas import DataFrame
@@ -51,8 +51,13 @@ def test_linkers(
         request=request,
     )

-    df_l: DataFrame = request.getfixturevalue(fx_data.fixture_l)
-    df_r: DataFrame = request.getfixturevalue(fx_data.fixture_r)
+    select_l: dict[Source, list[str]]
+    select_r: dict[Source, list[str]]
+    df_l: DataFrame
+    df_r: DataFrame
+
+    select_l, df_l = request.getfixturevalue(fx_data.fixture_l)
+    select_r, df_r = request.getfixturevalue(fx_data.fixture_r)

     fields_l = list(fx_data.fields_l.keys())
     fields_r = list(fx_data.fields_r.keys())
@@ -102,11 +107,10 @@ def test_linkers(
         right_source=fx_data.source_r,
     )

-    linked = model.run()
+    results = model.run()

-    linked_df = linked.to_df()
-
-    linked_df_with_source = linked.inspect_with_source(
+    linked_df = results.probabilities.to_df()
+    linked_df_with_source = results.probabilities.inspect_with_source(
         left_data=df_l,
         left_key="hash",
         right_data=df_r,
@@ -120,29 +124,60 @@ def test_linkers(
     for field_l, field_r in zip(fields_l, fields_r, strict=True):
         assert linked_df_with_source[field_l].equals(linked_df_with_source[field_r])

-    # 3. Linked probabilities are inserted correctly
+    # 3. Correct number of clusters are resolved

-    linked.to_matchbox(backend=matchbox_postgres)
+    clusters_links_df = results.clusters.to_df()
+    clusters_links_df_with_source = results.clusters.inspect_with_source(
+        left_data=df_l,
+        left_key="hash",
+        right_data=df_r,
+        right_key="hash",
+    )

-    model = matchbox_postgres.get_model(model=linker_name)
-    assert model.probabilities.count() == fx_data.tgt_prob_n
+    assert isinstance(clusters_links_df, DataFrame)
+    assert clusters_links_df.parent.nunique() == fx_data.tgt_clus_n

-    # 4. Correct number of clusters are resolved and inserted correctly
+    assert isinstance(clusters_links_df_with_source, DataFrame)
+    for field_l, field_r in zip(fields_l, fields_r, strict=True):
+        # When we enrich the ClusterResults in a deduplication job, every child
+        # hash will match something in the source data, because we're only using
+        # one dataset. NaNs are therefore impossible.
+        # When we enrich the ClusterResults in a link job, some child hashes
+        # will match something in the left data, and others in the right data.
+        # NaNs are therefore guaranteed.
+        # We therefore coalesce by parent to unique joined values, which
+        # we can expect to equal the target cluster number, and have matching
+        # rows of data
+        def unique_non_null(s):
+            return s.dropna().unique()
+
+        cluster_vals = (
+            clusters_links_df_with_source.filter(["parent", field_l, field_r])
+            .groupby("parent")
+            .agg(
+                {
+                    field_l: unique_non_null,
+                    field_r: unique_non_null,
+                }
+            )
+            .explode(column=[field_l, field_r])
+            .reset_index()
+        )
+
+        assert cluster_vals[field_l].equals(cluster_vals[field_r])
+        assert cluster_vals.parent.nunique() == fx_data.tgt_clus_n
+        assert cluster_vals.shape[0] == fx_data.tgt_clus_n
+
+    # 4. Probabilities and clusters are inserted correctly
+
+    results.to_matchbox(backend=matchbox_postgres)
+
+    model = matchbox_postgres.get_model(model=linker_name)
+    assert model.probabilities.dataframe.shape[0] == fx_data.tgt_prob_n

     model.truth = 0.0

-    l_r_selector = selectors(
-        selector(
-            table=fx_data.source_l,
-            fields=list(fx_data.fields_l.values()),
-            engine=warehouse.engine,
-        ),
-        selector(
-            table=fx_data.source_r,
-            fields=list(fx_data.fields_r.values()),
-            engine=warehouse.engine,
-        ),
-    )
+    l_r_selector = selectors(select_l, select_r)

     clusters = query(
         selector=l_r_selector,
@@ -152,8 +187,4 @@ def test_linkers(
     )

     assert isinstance(clusters, DataFrame)
-    assert clusters.hash.nunique() == fx_data.tgt_clus_n
-
-    model = matchbox_postgres.get_model(model=linker_name)
-
-    assert model.clusters.count() == fx_data.unique_n
+    assert clusters.hash.nunique() == fx_data.unique_n
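The coalescing step in check 3 is the subtlest part of this change, so a toy run may help. In a link job each row only populates its own side after inspect_with_source, so every cluster carries NaNs on the opposite side; grouping by parent and keeping the unique non-null values per side lines the two columns back up. A self-contained sketch with made-up data (column names and values are illustrative, not from the fixtures):

    import pandas as pd

    df = pd.DataFrame(
        {
            "parent": ["p1", "p1", "p2", "p2"],
            "name_l": ["Alpha Ltd", None, "Beta plc", None],  # rows from the left source
            "name_r": [None, "Alpha Ltd", None, "Beta plc"],  # rows from the right source
        }
    )

    def unique_non_null(s):
        return s.dropna().unique()

    cluster_vals = (
        df.filter(["parent", "name_l", "name_r"])
        .groupby("parent")
        .agg({"name_l": unique_non_null, "name_r": unique_non_null})
        .explode(column=["name_l", "name_r"])
        .reset_index()
    )
    # One row per parent, NaNs gone, and name_l == name_r on every row --
    # which is exactly what the three assertions in the test check.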
30 changes: 18 additions & 12 deletions test/fixtures/data.py
@@ -116,7 +116,8 @@ def cdms_companies(all_companies: DataFrame) -> DataFrame:
 @pytest.fixture(scope="function")
 def query_clean_crn(
     matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source]
-) -> DataFrame:
+) -> tuple[dict[Source, list[str]], DataFrame]:
+    """Fixture for CRN data, and the selector used to get it."""
     # Select
     crn_wh = warehouse_data[0]
     select_crn = selector(
@@ -140,13 +141,14 @@ def query_clean_crn(
     crn_cleaned = process(data=crn, pipeline=cleaner_crn)

-    return crn_cleaned
+    return select_crn, crn_cleaned


 @pytest.fixture(scope="function")
 def query_clean_duns(
     matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source]
-) -> DataFrame:
+) -> tuple[dict[Source, list[str]], DataFrame]:
+    """Fixture for DUNS data, and the selector used to get it."""
     # Select
     duns_wh = warehouse_data[1]
     select_duns = selector(
@@ -170,13 +172,14 @@ def query_clean_duns(
     duns_cleaned = process(data=duns, pipeline=cleaner_duns)

-    return duns_cleaned
+    return select_duns, duns_cleaned


 @pytest.fixture(scope="function")
 def query_clean_cdms(
     matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source]
-) -> DataFrame:
+) -> tuple[dict[Source, list[str]], DataFrame]:
+    """Fixture for CDMS data, and the selector used to get it."""
     # Select
     cdms_wh = warehouse_data[2]
     select_cdms = selector(
@@ -192,13 +195,14 @@ def query_clean_cdms(
     )

     # No cleaning needed, see original data
-    return cdms
+    return select_cdms, cdms


 @pytest.fixture(scope="function")
 def query_clean_crn_deduped(
     matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source]
-) -> DataFrame:
+) -> tuple[dict[Source, list[str]], DataFrame]:
+    """Fixture for cleaned, deduped CRN data, and the selector used to get it."""
     # Select
     crn_wh = warehouse_data[0]
     select_crn = selector(
@@ -222,13 +226,14 @@ def query_clean_crn_deduped(
     crn_cleaned = process(data=crn, pipeline=cleaner_crn)

-    return crn_cleaned
+    return select_crn, crn_cleaned


 @pytest.fixture(scope="function")
 def query_clean_duns_deduped(
     matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source]
-) -> DataFrame:
+) -> tuple[dict[Source, list[str]], DataFrame]:
+    """Fixture for cleaned, deduped DUNS data, and the selector used to get it."""
     # Select
     duns_wh = warehouse_data[1]
     select_duns = selector(
@@ -252,13 +257,14 @@ def query_clean_duns_deduped(
     duns_cleaned = process(data=duns, pipeline=cleaner_duns)

-    return duns_cleaned
+    return select_duns, duns_cleaned


 @pytest.fixture(scope="function")
 def query_clean_cdms_deduped(
     matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source]
-) -> DataFrame:
+) -> tuple[dict[Source, list[str]], DataFrame]:
+    """Fixture for cleaned, deduped CDMS data, and the selector used to get it."""
     # Select
     cdms_wh = warehouse_data[2]
     select_cdms = selector(
@@ -274,4 +280,4 @@ def query_clean_cdms_deduped(
     )

     # No cleaning needed, see original data
-    return cdms
+    return select_cdms, cdms
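These per-source dicts are what test_linkers later combines via selectors(select_l, select_r). That helper's body isn't part of this commit; given the dict[Source, list[str]] annotations, a plausible reading is a plain merge — an assumption, not the confirmed implementation:

    # Assumed behaviour of matchbox.helpers.selectors -- a sketch only:
    def selectors(*selects: dict[Source, list[str]]) -> dict[Source, list[str]]:
        """Merge per-source selectors into one selector for query()."""
        merged: dict[Source, list[str]] = {}
        for select in selects:
            merged.update(select)
        return merged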
6 changes: 3 additions & 3 deletions test/fixtures/db.py
@@ -71,7 +71,7 @@ def _db_add_dedupe_models_and_data(
     for fx_data in dedupe_data:
         for fx_deduper in dedupe_models:
-            df = request.getfixturevalue(fx_data.fixture)
+            _, df = request.getfixturevalue(fx_data.fixture)

             deduper_name = f"{fx_deduper.name}_{fx_data.source}"
             deduper_settings = fx_deduper.build_settings(fx_data)
@@ -137,8 +137,8 @@ def _db_add_link_models_and_data(
     for fx_data in link_data:
         for fx_linker in link_models:
-            df_l = request.getfixturevalue(fx_data.fixture_l)
-            df_r = request.getfixturevalue(fx_data.fixture_r)
+            _, df_l = request.getfixturevalue(fx_data.fixture_l)
+            _, df_r = request.getfixturevalue(fx_data.fixture_r)

             linker_name = f"{fx_linker.name}_{fx_data.source_l}_{fx_data.source_r}"
             linker_settings = fx_linker.build_settings(fx_data)
