diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..38d40684 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,22 @@ +# Pull Request Description + +## Changes Made +- [List the main changes you've made] + +## Reason for Changes +[Explain why you've made these changes] + +## Testing Done +[Describe the testing you've done to validate your changes] + +## Screenshots (if applicable) +[Add screenshots here if your changes include visual elements] + +## Checklist: +- [ ] My code follows the style guidelines of this project +- [ ] I have performed a self-review of my own code +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation +- [ ] My changes generate no new warnings +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes \ No newline at end of file diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 00000000..1d880c77 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,24 @@ +name: Unit tests +on: [push, pull_request] + +jobs: + uv-example: + name: python + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v2 + + - name: Set up Python + run: uv python install + + - name: Set up PostgreSQL + run: | + docker compose up db -d --wait + + - name: Run pytest + run: | + uv run pytest \ No newline at end of file diff --git a/Makefile b/Makefile index bea77d94..b9609f02 100644 --- a/Makefile +++ b/Makefile @@ -35,10 +35,7 @@ format: ## Run Python tests test: - @echo ">>> Dropping and recreating the test database" - docker-compose exec db psql -U testuser -c "DROP DATABASE IF EXISTS testdb;" - docker-compose exec db psql -U testuser -c "CREATE DATABASE testdb;" - @echo ">>> Running tests" + docker compose up db -d --wait uv run pytest diff --git a/README.md b/README.md index 51c74003..d35d1aa2 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,17 @@ -# π Company matching framework +# 🔥 Matchbox (née Company Matching Framework) -A match orchestration framework to allow the comparison, validation, and orchestration of the best match methods for the company matching job. +Record matching is a chore. We aim to: -We envisage this forming one of three repos in the Company Matching Framework: +* Make it an iterative, collaborative, measurable problem +* Allow organisations to know they have matching records without having to share the data +* Allow matching pipelines to run iteratively -* `company-matching-framework`, this repo. A Python library for creating data linkage and deduplication pipelines over a shared relational database -* `company-matching-framework-dash`, or https://matching.data.trade.gov.uk/. A dashboard for verifying links and deduplications, and comparing the performance metrics of different approaches. Uses `company-matching-framework` -* `company-matching-framework-pipeline`. The live pipeline of matching and deduping methods, running in production. Uses `company-matching-framework` +## Development -## Coverage +This project is managed by [uv](https://docs.astral.sh/uv/), linted and formatted with [ruff](https://docs.astral.sh/ruff/), and tested with [pytest](https://docs.pytest.org/en/stable/).
-* [Companies House](https://data.trade.gov.uk/datasets/a777d199-53a4-4d0a-bbbb-1559a86f8c4c#companies-house-company-data) -* [Data Hub companies](https://data.trade.gov.uk/datasets/32918f3e-a727-42e6-8359-9efc61c93aa4#data-hub-companies-master) -* [Export Wins](https://data.trade.gov.uk/datasets/0738396f-d1fd-46f1-a53f-5d8641d032af#export-wins-master-datasets) -* [HMRC UK exporters](https://data.trade.gov.uk/datasets/76fb2db3-ab32-4af8-ae87-d41d36b31265#uk-exporters) +Task running is done with [make](https://www.gnu.org/software/make/). To see all available commands: -## Quickstart - -Clone the repo, then run: - -```bash -. setup.sh -``` - -Create a `.env` with your development schema to write tables into. Copy the sample with `cp .env.sample .env` then fill it in. - -* `SCHEMA` is where any tables the service creates will be written by default - -To set up the database in your specificed schema run: - -```bash -make cmf +```console +make ``` - -## Usage - -See [the aspirational README](references/README_aspitational.md) for how we envisage the finished version of this Python library will be used. - -## Release metrics - -π Coming soon! - --------- - -
Project based on the cookiecutter data science project template.
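To illustrate the development loop these changes set up, here is a minimal sketch combining the Docker Compose `db` service and the new `make test` target from the diff above; `uv sync` is uv's usual dependency-install step and is an assumption here rather than part of this diff:

```console
# start the Postgres service defined in docker-compose.yml and wait for it to report healthy
docker compose up db -d --wait

# install project dependencies, then run the test suite
# (the new `make test` target wraps the compose and pytest steps)
uv sync
uv run pytest
```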
diff --git a/docker-compose.yml b/docker-compose.yml index 161e3f53..0dcda60b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,8 +1,7 @@ -version: '3.8' - services: db: - image: postgres:13 + image: postgres:14 + restart: always environment: POSTGRES_USER: testuser POSTGRES_PASSWORD: testpassword diff --git a/pyproject.toml b/pyproject.toml index 7eeaa202..6584bc43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "pydantic>=2.9.2", "python-dotenv>=1.0.1", "rustworkx>=0.15.1", - "splink>=4.0.3", + "splink<4", "sqlalchemy>=2.0.35", "tomli>=2.0.1", ] @@ -94,6 +94,8 @@ skip-magic-trailing-comma = false line-ending = "auto" [tool.pytest.ini_options] +testpaths = ["test"] +pythonpath = ["."] addopts = "-s -vv --cov=cmf test/ --log-disable=pg_bulk_ingest" log_cli = false log_cli_level = "INFO" diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/fixtures/__init__.py b/test/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/fixtures/data.py b/test/fixtures/data.py index ece88a00..0ffd3d08 100644 --- a/test/fixtures/data.py +++ b/test/fixtures/data.py @@ -6,8 +6,9 @@ import numpy as np import pandas as pd import pytest -import testing.postgresql from dotenv import find_dotenv, load_dotenv +from pandas import DataFrame +from sqlalchemy.engine import Engine import cmf.locations as loc from cmf import process, query @@ -19,11 +20,9 @@ LOGGER = logging.getLogger(__name__) -CMF_POSTGRES = testing.postgresql.PostgresqlFactory(cache_initialized_db=True) - @pytest.fixture(scope="session") -def all_companies(): +def all_companies() -> DataFrame: """ Raw, correct company data. Uses UUID as ID to replicate Data Workspace. 1,000 entries. @@ -36,7 +35,7 @@ def all_companies(): @pytest.fixture(scope="session") -def crn_companies(all_companies): +def crn_companies(all_companies: DataFrame) -> DataFrame: """ Company data split into CRN version. @@ -64,7 +63,7 @@ def crn_companies(all_companies): @pytest.fixture(scope="session") -def duns_companies(all_companies): +def duns_companies(all_companies: DataFrame) -> DataFrame: """ Company data split into DUNS version. @@ -87,7 +86,7 @@ def duns_companies(all_companies): @pytest.fixture(scope="session") -def cdms_companies(all_companies): +def cdms_companies(all_companies: DataFrame) -> DataFrame: """ Company data split into CDMS version. 
@@ -111,17 +110,15 @@ def cdms_companies(all_companies): @pytest.fixture(scope="function") -def query_clean_crn(db_engine): +def query_clean_crn(db_engine: Engine) -> DataFrame: # Select select_crn = selector( table=f"{os.getenv('SCHEMA')}.crn", fields=["crn", "company_name"], - engine=db_engine[1], + engine=db_engine, ) - crn = query( - selector=select_crn, model=None, return_type="pandas", engine=db_engine[1] - ) + crn = query(selector=select_crn, model=None, return_type="pandas", engine=db_engine) # Clean col_prefix = f"{os.getenv('SCHEMA')}_crn_" @@ -136,16 +133,16 @@ def query_clean_crn(db_engine): @pytest.fixture(scope="function") -def query_clean_duns(db_engine): +def query_clean_duns(db_engine: Engine) -> DataFrame: # Select select_duns = selector( table=f"{os.getenv('SCHEMA')}.duns", fields=["duns", "company_name"], - engine=db_engine[1], + engine=db_engine, ) duns = query( - selector=select_duns, model=None, return_type="pandas", engine=db_engine[1] + selector=select_duns, model=None, return_type="pandas", engine=db_engine ) # Clean @@ -161,16 +158,16 @@ def query_clean_duns(db_engine): @pytest.fixture(scope="function") -def query_clean_cdms(db_engine): +def query_clean_cdms(db_engine: Engine) -> DataFrame: # Select select_cdms = selector( table=f"{os.getenv('SCHEMA')}.cdms", fields=["crn", "cdms"], - engine=db_engine[1], + engine=db_engine, ) cdms = query( - selector=select_cdms, model=None, return_type="pandas", engine=db_engine[1] + selector=select_cdms, model=None, return_type="pandas", engine=db_engine ) # No cleaning needed, see original data @@ -178,19 +175,19 @@ def query_clean_cdms(db_engine): @pytest.fixture(scope="function") -def query_clean_crn_deduped(db_engine): +def query_clean_crn_deduped(db_engine: Engine) -> DataFrame: # Select select_crn = selector( table=f"{os.getenv('SCHEMA')}.crn", fields=["crn", "company_name"], - engine=db_engine[1], + engine=db_engine, ) crn = query( selector=select_crn, model=f"naive_{os.getenv('SCHEMA')}.crn", return_type="pandas", - engine=db_engine[1], + engine=db_engine, ) # Clean @@ -206,19 +203,19 @@ def query_clean_crn_deduped(db_engine): @pytest.fixture(scope="function") -def query_clean_duns_deduped(db_engine): +def query_clean_duns_deduped(db_engine: Engine) -> DataFrame: # Select select_duns = selector( table=f"{os.getenv('SCHEMA')}.duns", fields=["duns", "company_name"], - engine=db_engine[1], + engine=db_engine, ) duns = query( selector=select_duns, model=f"naive_{os.getenv('SCHEMA')}.duns", return_type="pandas", - engine=db_engine[1], + engine=db_engine, ) # Clean @@ -234,19 +231,19 @@ def query_clean_duns_deduped(db_engine): @pytest.fixture(scope="function") -def query_clean_cdms_deduped(db_engine): +def query_clean_cdms_deduped(db_engine: Engine) -> DataFrame: # Select select_cdms = selector( table=f"{os.getenv('SCHEMA')}.cdms", fields=["crn", "cdms"], - engine=db_engine[1], + engine=db_engine, ) cdms = query( selector=select_cdms, model=f"naive_{os.getenv('SCHEMA')}.cdms", return_type="pandas", - engine=db_engine[1], + engine=db_engine, ) # No cleaning needed, see original data diff --git a/test/fixtures/db.py b/test/fixtures/db.py index 2b6ab62e..2e7e1cdf 100644 --- a/test/fixtures/db.py +++ b/test/fixtures/db.py @@ -2,19 +2,17 @@ import logging import os import random -from test.fixtures.models import DedupeTestParams, LinkTestParams, ModelTestParams -from typing import Generator, Callable -from pandas import DataFrame +from typing import Callable, Generator import pytest -import docker -from _pytest.fixtures import 
FixtureFunction, FixtureRequest +from _pytest.fixtures import FixtureRequest from dotenv import find_dotenv, load_dotenv +from pandas import DataFrame from sqlalchemy import MetaData, create_engine, inspect, text from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.engine import Engine from sqlalchemy.orm import Session from sqlalchemy.schema import CreateSchema -from sqlalchemy.engine import Engine from cmf import make_deduper, make_linker, to_clusters from cmf.admin import add_dataset @@ -34,30 +32,14 @@ clusters_association, ) +from .models import DedupeTestParams, LinkTestParams, ModelTestParams + dotenv_path = find_dotenv() load_dotenv(dotenv_path) LOGGER = logging.getLogger(__name__) -@pytest.fixture(scope="session") -def postgresql() -> docker.models.containers.Container: - client = docker.from_env() - container = client.containers.run( - "postgres:latest", - environment={ - "POSTGRES_USER": "test", - "POSTGRES_PASSWORD": "test", - "POSTGRES_DB": "test_db", - }, - ports={"5432/tcp": 5432}, - detach=True, - ) - yield container - container.stop() - container.remove() - - @pytest.fixture def db_clear_all() -> Callable[[Engine], None]: """ @@ -68,8 +50,8 @@ def db_clear_all() -> Callable[[Engine], None]: def _db_clear_all(db_engine: Engine) -> None: db_metadata = MetaData(schema=os.getenv("SCHEMA")) - db_metadata.reflect(bind=db_engine[1]) - with Session(db_engine[1]) as session: + db_metadata.reflect(bind=db_engine) + with Session(db_engine) as session: for table in reversed(db_metadata.sorted_tables): LOGGER.info(f"{table}") session.execute(table.delete()) @@ -87,7 +69,7 @@ def db_clear_data() -> Callable[[Engine], None]: """ def _db_clear_data(db_engine: Engine) -> None: - with Session(db_engine[1]) as session: + with Session(db_engine) as session: session.query(SourceData).delete() session.query(SourceDataset).delete() session.commit() @@ -106,7 +88,7 @@ def db_clear_models() -> Callable[[Engine], None]: """ def _db_clear_models(db_engine: Engine) -> None: - with Session(db_engine[1]) as session: + with Session(db_engine) as session: session.query(LinkProbabilities).delete() session.query(Links).delete() session.query(DDupeProbabilities).delete() @@ -135,7 +117,7 @@ def db_add_data( """ def _db_add_data(db_engine: Engine) -> None: - with db_engine[1].connect() as conn: + with db_engine.connect() as conn: # Insert data crn_companies.to_sql( "crn", @@ -196,7 +178,7 @@ def db_add_models() -> Callable[[Engine], None]: """ def _db_add_models(db_engine: Engine) -> None: - with Session(db_engine[1]) as session: + with Session(db_engine) as session: # Two Dedupers and two Linkers dd_m1 = Models( sha1=hashlib.sha1("dd_m1".encode()).digest(), @@ -345,8 +327,8 @@ def _db_add_dedupe_models_and_data( df, results=deduped, key="data_sha1", threshold=0 ) - deduped.to_cmf(engine=db_engine[1]) - clustered.to_cmf(engine=db_engine[1]) + deduped.to_cmf(engine=db_engine) + clustered.to_cmf(engine=db_engine) return _db_add_dedupe_models_and_data @@ -423,33 +405,26 @@ def _db_add_link_models_and_data( df_l, df_r, results=linked, key="cluster_sha1", threshold=0 ) - linked.to_cmf(engine=db_engine[1]) - clustered.to_cmf(engine=db_engine[1]) + linked.to_cmf(engine=db_engine) + clustered.to_cmf(engine=db_engine) return _db_add_link_models_and_data @pytest.fixture(scope="session") def db_engine( - postgresql: docker.models.containers.Container, - db_add_data: Callable[[Engine], None], - db_add_models: Callable[[Engine], None], + db_add_data: Callable[[Engine], None], db_add_models: 
Callable[[Engine], None] ) -> Generator[Engine, None, None]: """ - Yield engine to mock in-memory database. + Yield engine to Docker container database. """ load_dotenv(find_dotenv()) engine = create_engine( - url="postgresql://test:test@localhost:5432/test_db", + url="postgresql://testuser:testpassword@localhost:5432/testdb", connect_args={"sslmode": "disable", "client_encoding": "utf8"}, ) with engine.connect() as conn: - # Install relevant extensions - conn.execute(text('create extension "uuid-ossp";')) - conn.execute(text("create extension pgcrypto;")) - conn.commit() - # Create CMF schema if not inspect(conn).has_schema(os.getenv("SCHEMA")): conn.execute(CreateSchema(os.getenv("SCHEMA"))) @@ -472,11 +447,19 @@ def db_engine( @pytest.fixture(scope="session", autouse=True) -def cleanup(postgresql, request): - """Cleanup the PostgreSQL database when we're done.""" +def cleanup(db_engine, request): + """Cleanup the PostgreSQL database by dropping all tables when we're done.""" def teardown(): - postgresql.stop() - postgresql.remove() + with db_engine.connect() as conn: + inspector = inspect(conn) + for table_name in inspector.get_table_names(schema=os.getenv("SCHEMA")): + conn.execute( + text( + f'DROP TABLE IF EXISTS "{os.getenv("SCHEMA")}".' + f'"{table_name}" CASCADE;' + ) + ) + conn.commit() request.addfinalizer(teardown) diff --git a/test/test_db.py b/test/test_db.py index 17352808..08d420b5 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -1,12 +1,6 @@ import itertools import logging import os -from test.fixtures.models import ( - dedupe_data_test_params, - dedupe_model_test_params, - link_data_test_params, - link_model_test_params, -) from dotenv import find_dotenv, load_dotenv from sqlalchemy import MetaData, Table, delete, insert, inspect, text @@ -25,6 +19,13 @@ clusters_association, ) +from .fixtures.models import ( + dedupe_data_test_params, + dedupe_model_test_params, + link_data_test_params, + link_model_test_params, +) + dotenv_path = find_dotenv() load_dotenv(dotenv_path) @@ -35,7 +36,7 @@ def test_database(db_engine): """ Test the database contains all the tables we expect. """ - tables = set(inspect(db_engine[1]).get_table_names(schema=os.getenv("SCHEMA"))) + tables = set(inspect(db_engine).get_table_names(schema=os.getenv("SCHEMA"))) to_check = { "crn", "duns", @@ -59,7 +60,7 @@ def test_database(db_engine): assert tables == to_check - with Session(db_engine[1]) as session: + with Session(db_engine) as session: server_encoding = session.execute(text("show server_encoding;")).scalar() client_encoding = session.execute(text("show client_encoding;")).scalar() @@ -70,7 +71,7 @@ def test_add_data(db_engine): """ Test all datasets were inserted. """ - with Session(db_engine[1]) as session: + with Session(db_engine) as session: inserted_tables = session.query(SourceDataset.db_table).all() inserted_tables = {t[0] for t in inserted_tables} expected_tables = {"crn", "duns", "cdms"} @@ -83,7 +84,7 @@ def test_inserted_data(db_engine, crn_companies, duns_companies, cdms_companies) Test all data was inserted. Note we drop duplicates because they're rolled up to arrays. 
""" - with Session(db_engine[1]) as session: + with Session(db_engine) as session: inserted_rows = session.query(SourceData).count() raw_rows = ( crn_companies.drop(columns=["id"]).drop_duplicates().shape[0] @@ -111,7 +112,7 @@ def test_insert_data(db_engine, crn_companies, duns_companies, cdms_companies): {"id": 3004, "company_name": "Eidel", "crn": "01HJ0TY5CRET0YPB0WF2R0DFEB"}, {"id": 3005, "company_name": "Zoozzy", "crn": "01HJ0TY5CRHDX0NX5RSBJWSSKF"}, ] - with Session(db_engine[1]) as session: + with Session(db_engine) as session: # Reflect the table and insert the data db_metadata = MetaData(schema=os.getenv("SCHEMA")) crn_table = Table( @@ -130,7 +131,7 @@ def test_insert_data(db_engine, crn_companies, duns_companies, cdms_companies): "table": "crn", "id": "id", }, - db_engine[1], + db_engine, ) # Test SourceData now contains 5 more rows @@ -154,7 +155,7 @@ def test_model_cluster_association(db_engine, db_clear_models, db_add_models): db_add_models(db_engine) # Model has six clusters - with Session(db_engine[1]) as session: + with Session(db_engine) as session: m = session.query(Models).filter_by(name="l_m1").first() clusters_in_db = session.query(Clusters).count() creates_in_db = session.query(clusters_association).count() @@ -175,7 +176,7 @@ def test_model_cluster_association(db_engine, db_clear_models, db_add_models): session.commit() # Model creates no clusters but clusters still exist - with Session(db_engine[1]) as session: + with Session(db_engine) as session: m = session.query(Models).filter_by(name="l_m1").first() clusters_in_db = session.query(Clusters).count() creates_in_db = session.query(clusters_association).count() @@ -198,7 +199,7 @@ def test_model_ddupe_association(db_engine, db_clear_models, db_add_models): db_add_models(db_engine) # Model proposes deduplications across six data nodes, 6**2 - with Session(db_engine[1]) as session: + with Session(db_engine) as session: m = session.query(Models).filter_by(name="dd_m1").first() ddupes_in_db = session.query(Dedupes).count() ddupe_probs_in_db = session.query(DDupeProbabilities).count() @@ -219,7 +220,7 @@ def test_model_ddupe_association(db_engine, db_clear_models, db_add_models): session.commit() # Model proposes no deduplications but dedupes still exist - with Session(db_engine[1]) as session: + with Session(db_engine) as session: m = session.query(Models).filter_by(name="dd_m1").first() ddupes_in_db = session.query(Dedupes).count() ddupe_probs_in_db = session.query(DDupeProbabilities).count() @@ -241,7 +242,7 @@ def test_model_link_association(db_engine, db_clear_models, db_add_models): db_add_models(db_engine) # Model proposes links across six cluster nodes, 6**2 - with Session(db_engine[1]) as session: + with Session(db_engine) as session: m = session.query(Models).filter_by(name="l_m1").first() links_in_db = session.query(Links).count() link_probs_in_db = session.query(LinkProbabilities).count() @@ -262,7 +263,7 @@ def test_model_link_association(db_engine, db_clear_models, db_add_models): session.commit() # Model proposes no linkings but links still exist - with Session(db_engine[1]) as session: + with Session(db_engine) as session: m = session.query(Models).filter_by(name="l_m1").first() links_in_db = session.query(Links).count() link_probs_in_db = session.query(LinkProbabilities).count() @@ -278,14 +279,14 @@ def test_db_delete( """ Test that the clearing test functions works. 
""" - with Session(db_engine[1]) as session: + with Session(db_engine) as session: data_before = session.query(SourceData).count() models_before = session.query(Models).count() db_clear_models(db_engine) db_clear_data(db_engine) - with Session(db_engine[1]) as session: + with Session(db_engine) as session: data_after = session.query(SourceData).count() models_after = session.query(Models).count() @@ -315,7 +316,7 @@ def test_add_dedupers_and_data( test_param.source: test_param for test_param in dedupe_data_test_params } - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model_list = session.query(Models).all() assert len(model_list) == len(dedupe_data_test_params) @@ -361,7 +362,7 @@ def test_add_linkers_and_data( request=request, ) - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model_list = session.query(Models).filter(Models.deduplicates == None).all() # NoQA E711 assert len(model_list) == len(link_data_test_params) @@ -371,7 +372,7 @@ def test_add_linkers_and_data( ): linker_name = f"{fx_linker.name}_{fx_data.source_l}_{fx_data.source_r}" - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model = session.query(Models).filter(Models.name == linker_name).first() assert session.scalar(model.links_count()) == fx_data.tgt_prob_n diff --git a/test/test_dedupers.py b/test/test_dedupers.py index 43075fd1..83401fc0 100644 --- a/test/test_dedupers.py +++ b/test/test_dedupers.py @@ -1,5 +1,3 @@ -from test.fixtures.models import dedupe_data_test_params, dedupe_model_test_params - import pytest from pandas import DataFrame from sqlalchemy.orm import Session @@ -7,6 +5,8 @@ from cmf import make_deduper, to_clusters from cmf.data import Models +from .fixtures.models import dedupe_data_test_params, dedupe_model_test_params + @pytest.mark.parametrize("fx_data", dedupe_data_test_params) @pytest.mark.parametrize("fx_deduper", dedupe_model_test_params) @@ -84,9 +84,9 @@ def test_dedupers( # 3. Deduplicated probabilities are inserted correctly - deduped.to_cmf(engine=db_engine[1]) + deduped.to_cmf(engine=db_engine) - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model = session.query(Models).filter_by(name=deduper_name).first() assert session.scalar(model.dedupes_count()) == fx_data.tgt_prob_n @@ -126,9 +126,9 @@ def test_dedupers( # 5. 
Resolved clusters are inserted correctly - clusters_all.to_cmf(engine=db_engine[1]) + clusters_all.to_cmf(engine=db_engine) - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model = session.query(Models).filter_by(name=deduper_name).first() assert session.scalar(model.creates_count()) == fx_data.unique_n diff --git a/test/test_helpers.py b/test/test_helpers.py index 7c098300..70e855f4 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -1,11 +1,5 @@ import logging import os -from test.fixtures.models import ( - dedupe_data_test_params, - dedupe_model_test_params, - link_data_test_params, - link_model_test_params, -) from dotenv import find_dotenv, load_dotenv from matplotlib.figure import Figure @@ -33,6 +27,13 @@ selectors, ) +from .fixtures.models import ( + dedupe_data_test_params, + dedupe_model_test_params, + link_data_test_params, + link_model_test_params, +) + dotenv_path = find_dotenv() load_dotenv(dotenv_path) @@ -41,10 +42,10 @@ def test_selectors(db_engine): select_crn = selector( - table=f"{os.getenv('SCHEMA')}.crn", fields=["id", "crn"], engine=db_engine[1] + table=f"{os.getenv('SCHEMA')}.crn", fields=["id", "crn"], engine=db_engine ) select_duns = selector( - table=f"{os.getenv('SCHEMA')}.duns", fields=["id", "duns"], engine=db_engine[1] + table=f"{os.getenv('SCHEMA')}.duns", fields=["id", "duns"], engine=db_engine ) select_crn_duns = selectors(select_crn, select_duns) @@ -54,14 +55,14 @@ def test_selectors(db_engine): def test_single_table_no_model_query(db_engine): """Tests query() on a single table. No point of truth to derive clusters""" select_crn = selector( - table=f"{os.getenv('SCHEMA')}.crn", fields=["id", "crn"], engine=db_engine[1] + table=f"{os.getenv('SCHEMA')}.crn", fields=["id", "crn"], engine=db_engine ) df_crn_sample = query( selector=select_crn, model=None, return_type="pandas", - engine=db_engine[1], + engine=db_engine, limit=10, ) @@ -69,7 +70,7 @@ def test_single_table_no_model_query(db_engine): assert df_crn_sample.shape[0] == 10 df_crn_full = query( - selector=select_crn, model=None, return_type="pandas", engine=db_engine[1] + selector=select_crn, model=None, return_type="pandas", engine=db_engine ) assert df_crn_full.shape[0] == 3000 @@ -83,15 +84,15 @@ def test_single_table_no_model_query(db_engine): def test_multi_table_no_model_query(db_engine): """Tests query() on multiple tables. 
No point of truth to derive clusters""" select_crn = selector( - table=f"{os.getenv('SCHEMA')}.crn", fields=["id", "crn"], engine=db_engine[1] + table=f"{os.getenv('SCHEMA')}.crn", fields=["id", "crn"], engine=db_engine ) select_duns = selector( - table=f"{os.getenv('SCHEMA')}.duns", fields=["id", "duns"], engine=db_engine[1] + table=f"{os.getenv('SCHEMA')}.duns", fields=["id", "duns"], engine=db_engine ) select_crn_duns = selectors(select_crn, select_duns) df_crn_duns_full = query( - selector=select_crn_duns, model=None, return_type="pandas", engine=db_engine[1] + selector=select_crn_duns, model=None, return_type="pandas", engine=db_engine ) assert df_crn_duns_full.shape[0] == 3500 @@ -136,14 +137,14 @@ def test_single_table_with_model_query( select_crn = selector( table=f"{os.getenv('SCHEMA')}.crn", fields=["crn", "company_name"], - engine=db_engine[1], + engine=db_engine, ) crn = query( selector=select_crn, model=f"naive_{os.getenv('SCHEMA')}.crn", return_type="pandas", - engine=db_engine[1], + engine=db_engine, ) assert isinstance(crn, DataFrame) @@ -188,10 +189,10 @@ def test_multi_table_with_model_query( ) select_crn = selector( - table=f"{os.getenv('SCHEMA')}.crn", fields=["crn"], engine=db_engine[1] + table=f"{os.getenv('SCHEMA')}.crn", fields=["crn"], engine=db_engine ) select_duns = selector( - table=f"{os.getenv('SCHEMA')}.duns", fields=["duns"], engine=db_engine[1] + table=f"{os.getenv('SCHEMA')}.duns", fields=["duns"], engine=db_engine ) select_crn_duns = selectors(select_crn, select_duns) @@ -199,7 +200,7 @@ def test_multi_table_with_model_query( selector=select_crn_duns, model=linker_name, return_type="pandas", - engine=db_engine[1], + engine=db_engine, ) assert isinstance(crn_duns, DataFrame) @@ -228,11 +229,11 @@ def test_process(db_engine): select_name = selector( table=f"{os.getenv('SCHEMA')}.crn", fields=["crn", "company_name"], - engine=db_engine[1], + engine=db_engine, ) df_name = query( - selector=select_name, model=None, return_type="pandas", engine=db_engine[1] + selector=select_name, model=None, return_type="pandas", engine=db_engine ) cleaner_name = cleaner( @@ -261,7 +262,7 @@ def test_comparisons(): def test_draw_model_tree(db_engine): - plt = draw_model_tree(db_engine[1]) + plt = draw_model_tree(db_engine) assert isinstance(plt, Figure) @@ -298,7 +299,7 @@ def test_model_deletion( deduper_to_delete = f"naive_{os.getenv('SCHEMA')}.crn" total_models = len(dedupe_data_test_params) + len(link_data_test_params) - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model_list_pre_delete = session.query(Models).all() assert len(model_list_pre_delete) == total_models @@ -317,9 +318,9 @@ def test_model_deletion( assert link_prob_count_pre_delete > 0 # Perform deletion - delete_model(deduper_to_delete, engine=db_engine[1], certain=True) + delete_model(deduper_to_delete, engine=db_engine, certain=True) - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model_list_post_delete = session.query(Models).all() # Deletes deduper and parent linkers: 3 models gone assert len(model_list_post_delete) == len(model_list_pre_delete) - 3 diff --git a/test/test_linkers.py b/test/test_linkers.py index 98779e54..6df08616 100644 --- a/test/test_linkers.py +++ b/test/test_linkers.py @@ -1,10 +1,3 @@ -from test.fixtures.models import ( - dedupe_data_test_params, - dedupe_model_test_params, - link_data_test_params, - link_model_test_params, -) - import pytest from pandas import DataFrame from sqlalchemy.orm import Session @@ -12,6 +5,13 @@ 
from cmf import make_linker, to_clusters from cmf.data import Models +from .fixtures.models import ( + dedupe_data_test_params, + dedupe_model_test_params, + link_data_test_params, + link_model_test_params, +) + @pytest.mark.parametrize("fx_data", link_data_test_params) @pytest.mark.parametrize("fx_linker", link_model_test_params) @@ -116,9 +116,9 @@ def test_linkers( # 3. Linked probabilities are inserted correctly - linked.to_cmf(engine=db_engine[1]) + linked.to_cmf(engine=db_engine) - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model = session.query(Models).filter_by(name=linker_name).first() assert session.scalar(model.links_count()) == fx_data.tgt_prob_n @@ -214,9 +214,9 @@ def unique_non_null(s): # 5. Resolved clusters are inserted correctly - clusters_all.to_cmf(engine=db_engine[1]) + clusters_all.to_cmf(engine=db_engine) - with Session(db_engine[1]) as session: + with Session(db_engine) as session: model = session.query(Models).filter_by(name=linker_name).first() assert session.scalar(model.creates_count()) == fx_data.unique_n diff --git a/uv.lock b/uv.lock index c4b123ce..98fb7a01 100644 --- a/uv.lock +++ b/uv.lock @@ -241,7 +241,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.9.2" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "rustworkx", specifier = ">=0.15.1" }, - { name = "splink", specifier = ">=4.0.3" }, + { name = "splink", specifier = "<4" }, { name = "sqlalchemy", specifier = ">=2.0.35" }, { name = "tomli", specifier = ">=2.0.1" }, ] @@ -1022,6 +1022,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3c/c3/a21d017f41c4d7603c0aa895ad781ea24fa7c9cc412056aee119eb326883/pg_force_execute-0.0.11-py3-none-any.whl", hash = "sha256:250587c0f4c51a2997454442a0f39c2ab4113dc70ebae2015f1556f080595e4a", size = 4492 }, ] +[[package]] +name = "phonetics" +version = "1.0.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/a5/d1b6dbcbb05477aa5f0c5e73a7d68c6d23ab098af4461072f00999ed573a/phonetics-1.0.5.tar.gz", hash = "sha256:16263948c82fce1e257964b2ab4adc953f995e0fa7e2e60e6ba336d77a7235ba", size = 8848 } + [[package]] name = "pillow" version = "10.4.0" @@ -1596,7 +1602,7 @@ wheels = [ [[package]] name = "splink" -version = "4.0.3" +version = "3.9.15" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "altair" }, @@ -1604,13 +1610,13 @@ dependencies = [ { name = "igraph" }, { name = "jinja2" }, { name = "jsonschema" }, - { name = "numpy" }, { name = "pandas" }, + { name = "phonetics" }, { name = "sqlglot" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3c/31/95e23817eaa146a03ef3b7be790157faec844452be9b282fdd1c7c463eb6/splink-4.0.3.tar.gz", hash = "sha256:ee774fcb9f51f6a2ab10ff6ead57be2c5ee1f8b9d4140ee261786359f87e5d4a", size = 3649837 } +sdist = { url = "https://files.pythonhosted.org/packages/c9/fe/457c18d9f54e6b34ddd4a908ee90e61b8287006560d36952afec7cae45d9/splink-3.9.15.tar.gz", hash = "sha256:d52a4f2e48567b502621924cbd909f88c6cb88b32442d575deb4c16bbbb2ccad", size = 3655727 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/2b/0ee804bbdccbb47f1a1694d2878a478e6f192b78e47fbbfac2c05787b23b/splink-4.0.3-py3-none-any.whl", hash = "sha256:d990c9552a8601c0d61232d428810730698e005e4ebfef16eaf5ce0d581050d2", size = 3711863 }, + { url = "https://files.pythonhosted.org/packages/96/8a/99cf732fb1a6aac4535e0c4a641c1159a5f1d2fa9bce6452c52f078c7ba5/splink-3.9.15-py3-none-any.whl", hash = 
"sha256:1b8f557743e633c785fa6da4030821d0cd1ccf03336d662db59f809956f4ec87", size = 3713845 }, ] [[package]]