Very close to having extraction working
Will Langdale committed Oct 18, 2024
1 parent 899bbce commit 72c7e61
Showing 15 changed files with 297 additions and 400 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
"matplotlib>=3.9.2",
"pandas>=2.2.3",
"pg-bulk-ingest>=0.0.54",
"psycopg2-binary>=2.9.9",
"psycopg2>=2.9.10",
"pyarrow>=17.0.0",
"pydantic-settings>=2.5.2",
"pydantic>=2.9.2",
8 changes: 6 additions & 2 deletions src/matchbox/common/results.py
@@ -8,7 +8,7 @@
columns_to_value_ordered_hash,
list_to_value_ordered_hash,
)
from matchbox.server.base import MatchboxDBAdapter
from matchbox.server.base import MatchboxDBAdapter, inject_backend
from matchbox.server.models import Cluster, Probability
from pandas import DataFrame, concat
from pydantic import BaseModel, ConfigDict, model_validator
@@ -125,6 +125,7 @@ def to_df(self) -> DataFrame:

return df

@inject_backend
def to_records(self, backend: MatchboxDBAdapter | None) -> list[Probability]:
"""Returns the results as a list of records suitable for insertion.
@@ -163,6 +164,7 @@ def to_records(self, backend: MatchboxDBAdapter | None) -> list[Probability]:
].to_numpy()
]

@inject_backend
def to_matchbox(self, backend: MatchboxDBAdapter) -> None:
"""Writes the results to the Matchbox database."""
backend.insert_model(
@@ -175,7 +177,7 @@ def to_matchbox(self, backend: MatchboxDBAdapter) -> None:
model = backend.get_model(model=self.run_name)

model.insert_probabilities(
probabilites=self.to_records(),
probabilities=self.to_records(backend=backend),
probability_type="links" if self.left != self.right else "deduplications",
batch_size=backend.settings.batch_size,
)
@@ -237,12 +239,14 @@ def to_records(self) -> list[Cluster]:
parent_child_pairs = self.dataframe[["parent", "child"]].values
return [Cluster(parent=row[0], child=row[1]) for row in parent_child_pairs]

@inject_backend
def to_matchbox(self, backend: MatchboxDBAdapter) -> None:
"""Writes the results to the Matchbox database."""
model = backend.get_model(model=self.run_name)
model.insert_clusters(
clusters=self.to_records(),
batch_size=backend.settings.batch_size,
cluster_type="links" if self.left != self.right else "deduplications",
)


3 changes: 2 additions & 1 deletion src/matchbox/helpers/visualisation.py
@@ -2,9 +2,10 @@
from matplotlib.figure import Figure
from rustworkx.visualization import mpl_draw

from matchbox.server.base import MatchboxDBAdapter
from matchbox.server.base import MatchboxDBAdapter, inject_backend


@inject_backend
def draw_model_tree(backend: MatchboxDBAdapter) -> Figure:
"""
Draws the model subgraph.
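With the decorator applied, the helper no longer needs the backend passed in by hand. A minimal usage sketch, assuming a default backend is configured via MB__BACKEND_TYPE; the output filename is illustrative:

from matchbox.helpers.visualisation import draw_model_tree

fig = draw_model_tree()  # backend injected from the configured environment
fig.savefig("model_tree.png")  # draw_model_tree returns a matplotlib Figure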
12 changes: 11 additions & 1 deletion src/matchbox/server/base.py
@@ -1,3 +1,4 @@
import inspect
from abc import ABC, abstractmethod
from enum import StrEnum
from functools import wraps
@@ -196,13 +197,22 @@ def inject_backend(func: Callable) -> Callable:
Used to allow user-facing functions to access the backend without needing to
pass it in. The backend is defined by the MB__BACKEND_TYPE environment variable.
Can be used for both functions and methods.
If the user specifies a backend, it will be used instead of the injection.
"""

@wraps(func)
def _inject_backend(*args, backend: "MatchboxDBAdapter | None" = None, **kwargs):
if backend is None:
backend = BackendManager.get_backend()
return func(backend, *args, **kwargs)

sig = inspect.signature(func)
params = list(sig.parameters.values())

if params and params[0].name in ("self", "cls"):
return func(args[0], backend, *args[1:], **kwargs)
else:
return func(backend, *args, **kwargs)

return _inject_backend
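A minimal sketch of how the updated wrapper behaves for both call styles. The names show_batch_size, Reporter, and my_adapter are illustrative, not part of the codebase; backend.settings.batch_size is used only because it already appears elsewhere in this diff:

from matchbox.server.base import MatchboxDBAdapter, inject_backend


@inject_backend
def show_batch_size(backend: MatchboxDBAdapter) -> int:
    # Plain function: the resolved backend arrives as the first positional argument.
    return backend.settings.batch_size


class Reporter:
    @inject_backend
    def batch_size(self, backend: MatchboxDBAdapter) -> int:
        # Method: `self` stays in first position; the backend is injected right after it.
        return backend.settings.batch_size


show_batch_size()        # backend resolved via MB__BACKEND_TYPE
Reporter().batch_size()  # the signature check keeps `self` in place
# show_batch_size(backend=my_adapter)  # an explicit backend (e.g. a test double) bypasses injection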
2 changes: 1 addition & 1 deletion src/matchbox/server/postgresql/adapter.py
@@ -284,7 +284,7 @@ def insert_model(
)
insert_deduper(
model=model,
deduplicates=deduplicates,
deduplicates=str(deduplicates),
description=description,
engine=MBDB.get_engine(),
)
15 changes: 10 additions & 5 deletions src/matchbox/server/postgresql/utils/db.py
@@ -9,7 +9,7 @@
from pg_bulk_ingest import Delete, Upsert, ingest
from sqlalchemy import Engine, Table
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import Session
from sqlalchemy.orm import DeclarativeMeta, Session

from matchbox.server.postgresql.data import SourceDataset
from matchbox.server.postgresql.models import Models, ModelsFrom
@@ -87,25 +87,30 @@ def batched(iterable: Iterable, n: int) -> Iterable:


def data_to_batch(
records: list[dict], table: Table, batch_size: int
records: list[tuple], table: Table, batch_size: int
) -> Callable[[str], Tuple[Any]]:
"""Constructs a batches function for any dataframe and table."""

def _batches() -> Iterable[Tuple[None, None, Iterable[Tuple[Table, dict]]]]:
def _batches(
high_watermark, # noqa ARG001 required for pg_bulk_ingest
) -> Iterable[Tuple[None, None, Iterable[Tuple[Table, tuple]]]]:
for batch in batched(records, batch_size):
yield None, None, ((table, t) for t in batch)

return _batches


def batch_ingest(
records: list[dict],
table: Table,
records: list[tuple],
table: Table | DeclarativeMeta,
conn: Connection,
batch_size: int,
) -> None:
"""Batch ingest records into a database table."""

if isinstance(table, DeclarativeMeta):
table = table.__table__

fn_batch = data_to_batch(
records=records,
table=table,
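A self-contained sketch of the tuple-based batching shape above, with a throwaway table and simplified helpers standing in for the real ones; it only illustrates that each batch now pairs the Table with plain tuples and that the batch callable accepts a high-watermark argument:

from itertools import islice
from typing import Iterable

from sqlalchemy import Column, Integer, MetaData, String, Table

metadata = MetaData()
demo = Table("demo", metadata, Column("id", Integer, primary_key=True), Column("name", String))


def batched(iterable: Iterable, n: int) -> Iterable:
    # Simplified stand-in for the project's batched() helper.
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk


def data_to_batch(records, table, batch_size):
    def _batches(high_watermark):  # pg_bulk_ingest passes the previous watermark
        for batch in batched(records, batch_size):
            yield None, None, ((table, row) for row in batch)

    return _batches


rows = [(1, "alpha"), (2, "beta"), (3, "gamma")]
for hw, summary, pairs in data_to_batch(rows, demo, batch_size=2)(None):
    for table, row in pairs:
        print(table.name, row)  # e.g. "demo (1, 'alpha')", one batch at a time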
51 changes: 6 additions & 45 deletions src/matchbox/server/postgresql/utils/insert.py
@@ -152,33 +152,17 @@ def insert_probabilities(
)

# Upsert nodes
def probability_to_node(probability: Probability) -> dict:
return {
"sha1": probability.sha1,
"left": probability.left,
"right": probability.right,
}

batch_ingest(
records=[probability_to_node(prob) for prob in probabilities],
records=[(prob.left, prob.right, prob.sha1) for prob in probabilities],
table=NodesTable,
conn=conn,
batch_size=batch_size,
)

# Insert probabilities
def probability_to_probability(
probability: Probability, model_hash: bytes
) -> dict:
return {
"ddupe" if is_deduper else "link": probability.sha1,
"model": model_hash,
"probability": probability.probability,
}

batch_ingest(
records=[
probability_to_probability(prob, model_hash) for prob in probabilities
(prob.sha1, model_hash, prob.probability) for prob in probabilities
],
table=ProbabilitiesTable,
conn=conn,
@@ -243,47 +227,24 @@ def insert_clusters(
)

# Upsert cluster nodes
def cluster_to_cluster(cluster: Cluster) -> dict:
"""Prepares a Cluster for the Clusters table."""
return {
"sha1": cluster.parent,
}

batch_ingest(
records=list({cluster_to_cluster(cluster) for cluster in clusters}),
records=list({(cluster.parent,) for cluster in clusters}),
table=Clusters,
conn=conn,
batch_size=batch_size,
)

# Insert cluster contains
def cluster_to_cluster_contains(cluster: Cluster) -> dict:
"""Prepares a Cluster for the Contains tables."""
return {
"parent": cluster.parent,
"child": cluster.child,
}

batch_ingest(
records=[cluster_to_cluster_contains(cluster) for cluster in clusters],
records=[(cluster.parent, cluster.child) for cluster in clusters],
table=Contains,
conn=conn,
batch_size=batch_size,
)

# Insert cluster proposed by
def cluster_to_cluster_association(cluster: Cluster, model_hash: bytes) -> dict:
"""Prepares a Cluster for the cluster association table."""
return {
"parent": model_hash,
"child": cluster.parent,
}

# Insert cluster proposed by model
batch_ingest(
records=[
cluster_to_cluster_association(cluster, model_hash)
for cluster in clusters
],
records=[(model_hash, cluster.parent) for cluster in clusters],
table=clusters_association,
conn=conn,
batch_size=batch_size,
4 changes: 3 additions & 1 deletion src/matchbox/server/postgresql/utils/selector.py
@@ -312,7 +312,9 @@ def query(

mb_hashes = sql_to_df(hash_query, engine, return_type="arrow")

raw_data = source.to_arrow(fields=fields, pks=mb_hashes["id"].to_pylist())
raw_data = source.to_arrow(
fields=set([source.db_pk] + fields), pks=mb_hashes["id"].to_pylist()
)

# Tablename plus column SQLAlchemy label style
right_key = f"{source.db_schema}_{source.db_table}_{source.db_pk}"