Adbc ingestion #62

Open · wants to merge 5 commits into main

Changes from 3 commits

2 changes: 1 addition & 1 deletion .gitignore
@@ -190,4 +190,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
"splink>=4.0.5,<4.1.0",
"sqlalchemy>=2.0.35",
"rich>=13.9.4",
"adbc-driver-postgresql>=1.4.0",
]

[project.optional-dependencies]
127 changes: 126 additions & 1 deletion src/matchbox/server/postgresql/utils/db.py
@@ -1,14 +1,22 @@
import contextlib
import cProfile
import io
import os
import pstats
from itertools import islice
from typing import Any, Callable, Iterable
from datetime import datetime

import pyarrow as pa
import adbc_driver_postgresql.dbapi
from adbc_driver_manager.dbapi import Connection as ADBCConnection
from adbc_driver_manager import DatabaseError as ADBCDatabaseError

from pg_bulk_ingest import Delete, Upsert, ingest
from sqlalchemy import Engine, Index, MetaData, Table, func
from sqlalchemy import Engine, Index, MetaData, Table, func, text
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import DeclarativeMeta, Session
from sqlalchemy.exc import DatabaseError as AlchemyDatabaseError

from matchbox.common.graph import (
ResolutionEdge,
@@ -162,3 +170,120 @@ def batch_ingest(
upsert=Upsert.IF_PRIMARY_KEY,
delete=Delete.OFF,
)


MB__POSTGRES__PASSWORD = os.environ["MB__POSTGRES__PASSWORD"]
MB__POSTGRES__PORT = os.environ["MB__POSTGRES__PORT"]
MB__POSTGRES__USER = os.environ["MB__POSTGRES__USER"]
MB__POSTGRES__DATABASE = os.environ["MB__POSTGRES__DATABASE"]
MB__POSTGRES__HOST = os.environ["MB__POSTGRES__HOST"]
MB__POSTGRES__SCHEMA = os.environ["MB__POSTGRES__DB_SCHEMA"]


POSTGRESQL_URI = f"postgresql://{MB__POSTGRES__USER}:{MB__POSTGRES__PASSWORD}@{MB__POSTGRES__HOST}:{MB__POSTGRES__PORT}/{MB__POSTGRES__DATABASE}"

def adbc_ingest_data(clusters: pa.Table, contains: pa.Table, probabilities: pa.Table, engine: Engine, resolution_id: int) -> bool:
    """Ingest clusters, contains and probabilities tables into PostgreSQL using ADBC.

    Returns True if both the bulk insert and the constraint rebuild succeed.
    """
    with engine.connect() as alchemy_conn:
        suffix = datetime.now().strftime("%Y%m%d%H%M%S")
        if _adbc_insert_data(clusters, contains, probabilities, suffix, alchemy_conn, resolution_id):
            return _create_adbc_table_constraints(MB__POSTGRES__SCHEMA, suffix, alchemy_conn)
        return False

def _create_adbc_table_constraints(db_schema: str, suffix: str, conn: Connection) -> bool:
    """Create primary keys, indexes and foreign keys on the suffixed staging tables, then swap them in for the live tables.

    Args:
        db_schema: name of the schema containing the tables
        suffix: timestamp suffix of the staging tables
        conn: SQLAlchemy connection used to run the DDL
    """
    # Clusters
    _run_query(f"ALTER TABLE {db_schema}.clusters_{suffix} ADD PRIMARY KEY (cluster_id)", conn)
    _run_query(f"""ALTER TABLE {db_schema}.probabilities_{suffix} ADD PRIMARY KEY (resolution, "cluster")""", conn)
    _run_query(f"CREATE UNIQUE INDEX cluster_hash_index_{suffix} ON {db_schema}.clusters_{suffix} USING btree (cluster_hash)", conn)
    # _run_query(f"CREATE UNIQUE INDEX clusters_adbc_clusters_is_{suffix} ON {db_schema}.clusters_{suffix} USING btree (cluster_id)", conn)
    _run_query(f"CREATE INDEX ix_clusters_id_gin_{suffix} ON {db_schema}.clusters_{suffix} USING gin (source_pk)", conn)
    _run_query(f"CREATE INDEX ix_mb_clusters_source_pk_{suffix} ON {db_schema}.clusters_{suffix} USING btree (source_pk)", conn)

    # Contains
    _run_query(f"CREATE UNIQUE INDEX ix_contains_child_parent_{suffix} ON {db_schema}.contains_{suffix} USING btree (child, parent)", conn)
    _run_query(f"CREATE UNIQUE INDEX ix_contains_parent_child_{suffix} ON {db_schema}.contains_{suffix} USING btree (parent, child)", conn)

    # Foreign keys
    _run_query(f"ALTER TABLE {db_schema}.clusters_{suffix} ADD CONSTRAINT clusters_dataset_fkey FOREIGN KEY (dataset) REFERENCES {db_schema}.sources(resolution_id)", conn)
    _run_query(f"""ALTER TABLE {db_schema}."contains_{suffix}" ADD CONSTRAINT contains_child_fkey FOREIGN KEY (child) REFERENCES {db_schema}.clusters_{suffix}(cluster_id) ON DELETE CASCADE""", conn)
    _run_query(f"""ALTER TABLE {db_schema}."contains_{suffix}" ADD CONSTRAINT contains_parent_fkey FOREIGN KEY (parent) REFERENCES {db_schema}.clusters_{suffix}(cluster_id) ON DELETE CASCADE""", conn)
    _run_query(f"""ALTER TABLE {db_schema}.probabilities_{suffix} ADD CONSTRAINT probabilities_cluster_fkey FOREIGN KEY ("cluster") REFERENCES {db_schema}.clusters_{suffix}(cluster_id) ON DELETE CASCADE""", conn)
    _run_query(f"ALTER TABLE {db_schema}.probabilities_{suffix} ADD CONSTRAINT probabilities_resolution_fkey FOREIGN KEY (resolution) REFERENCES {db_schema}.resolutions(resolution_id) ON DELETE CASCADE", conn)

    # Swap the staging tables in for the live ones in a single transaction
    _run_queries([
        f"DROP TABLE IF EXISTS {db_schema}.clusters",
        f"DROP TABLE IF EXISTS {db_schema}.contains",
        f"DROP TABLE IF EXISTS {db_schema}.probabilities",
        f"ALTER TABLE {db_schema}.clusters_{suffix} RENAME TO clusters",
        f"ALTER TABLE {db_schema}.contains_{suffix} RENAME TO contains",
        f"ALTER TABLE {db_schema}.probabilities_{suffix} RENAME TO probabilities",
    ], conn)
    return True

def _adbc_insert_data(clusters: pa.Table, contains: pa.Table, probabilities: pa.Table, suffix: str, alchemy_conn: Connection, resolution_id: int) -> bool:
    with adbc_driver_postgresql.dbapi.connect(POSTGRESQL_URI) as conn:
        try:
            _run_query(f"CREATE TABLE clusters_{suffix} AS SELECT * FROM clusters", alchemy_conn)
            _save_to_postgresql(
                table=clusters,
                conn=conn,
                schema=MB__POSTGRES__SCHEMA,
                table_name=f"clusters_{suffix}",
            )
            _run_query(f"CREATE TABLE contains_{suffix} AS SELECT * FROM contains", alchemy_conn)
            _save_to_postgresql(
                table=contains,
                conn=conn,
                schema=MB__POSTGRES__SCHEMA,
                table_name=f"contains_{suffix}",
            )
            _run_query(
                f"CREATE TABLE probabilities_{suffix} AS SELECT * FROM probabilities WHERE resolution != {resolution_id}",
                alchemy_conn,
            )
            _save_to_postgresql(
                table=probabilities,
                conn=conn,
                schema=MB__POSTGRES__SCHEMA,
                table_name=f"probabilities_{suffix}",
            )
            conn.commit()
            return True
        except (ADBCDatabaseError, AlchemyDatabaseError):
            return False

def _run_query(query: str, conn: Connection) -> None:
    """Execute a single statement and commit."""
    conn.execute(text(query))
    conn.commit()


def _run_queries(queries: list[str], conn: Connection) -> None:
    """Execute several statements in one transaction."""
    conn.begin()
    for query in queries:
        conn.execute(text(query))
    conn.commit()

def _save_to_postgresql(
    table: pa.Table, conn: ADBCConnection, schema: str, table_name: str
) -> None:
    """Save a PyArrow Table to PostgreSQL using ADBC bulk ingest."""
with conn.cursor() as cursor:
# Convert PyArrow Table to Arrow RecordBatchStream for efficient transfer
batch_reader = pa.RecordBatchReader.from_batches(
table.schema, table.to_batches()
)
cursor.adbc_ingest(
table_name=table_name,
data=batch_reader,
mode="append",
db_schema_name=schema,
)
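For reference, the ADBC bulk-ingest pattern that _save_to_postgresql wraps can also be exercised on its own. Below is a minimal standalone sketch, not part of this change; the connection string, table name and schema are placeholders.

# Standalone illustration of cursor.adbc_ingest; the DSN, table name and
# schema below are placeholders.
import adbc_driver_postgresql.dbapi
import pyarrow as pa

uri = "postgresql://user:password@localhost:5432/mydb"
table = pa.table({"id": [1, 2], "value": ["a", "b"]})

with adbc_driver_postgresql.dbapi.connect(uri) as conn:
    with conn.cursor() as cursor:
        # Stream the Arrow table into PostgreSQL without row-by-row inserts
        cursor.adbc_ingest(
            table_name="my_table",
            data=table.to_reader(),
            mode="create_append",  # create the table if it is missing, else append
            db_schema_name="public",
        )
    conn.commit()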
67 changes: 2 additions & 65 deletions src/matchbox/server/postgresql/utils/insert.py
@@ -25,7 +25,7 @@
Resolutions,
Sources,
)
from matchbox.server.postgresql.utils.db import batch_ingest, hash_to_hex_decode
from matchbox.server.postgresql.utils.db import adbc_ingest_data, batch_ingest, hash_to_hex_decode

logic_logger = logging.getLogger("mb_logic")

@@ -493,69 +493,6 @@ def insert_results(
resolution=resolution, results=results, engine=engine
)

with Session(engine) as session:
try:
# Clear existing probabilities for this resolution
session.execute(
delete(Probabilities).where(
Probabilities.resolution == resolution.resolution_id
)
)

session.commit()
logic_logger.info(f"[{resolution.name}] Removed old probabilities")

except SQLAlchemyError as e:
session.rollback()
logic_logger.error(
f"[{resolution.name}] Failed to clear old probabilities: {str(e)}"
)
raise

with engine.connect() as conn:
try:
logic_logger.info(
f"[{resolution.name}] Inserting {clusters.shape[0]:,} results objects"
)

batch_ingest(
records=[tuple(c.values()) for c in clusters.to_pylist()],
table=Clusters,
conn=conn,
batch_size=batch_size,
)

logic_logger.info(
f"[{resolution.name}] Successfully inserted {clusters.shape[0]} "
"objects into Clusters table"
)

batch_ingest(
records=[tuple(c.values()) for c in contains.to_pylist()],
table=Contains,
conn=conn,
batch_size=batch_size,
)

logic_logger.info(
f"[{resolution.name}] Successfully inserted {contains.shape[0]} "
"objects into Contains table"
)

batch_ingest(
records=[tuple(c.values()) for c in probabilities.to_pylist()],
table=Probabilities,
conn=conn,
batch_size=batch_size,
)

logic_logger.info(
f"[{resolution.name}] Successfully inserted "
f"{probabilities.shape[0]} objects into Probabilities table"
)

except SQLAlchemyError as e:
logic_logger.error(f"[{resolution.name}] Failed to insert data: {str(e)}")
raise
    adbc_ingest_data(
        clusters=clusters,
        contains=contains,
        probabilities=probabilities,
        engine=engine,
        resolution_id=resolution.resolution_id,
    )

logic_logger.info(f"[{resolution.name}] Insert operation complete!")
Empty file added test/unit/__init__.py
130 changes: 130 additions & 0 deletions test/unit/test_adbcingest.py
@@ -0,0 +1,130 @@
import unittest
from unittest.mock import MagicMock, patch

import pandas as pd
import pyarrow as pa
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import DatabaseError as AlchemyDatabaseError

from matchbox.server.postgresql.utils.db import (
    _adbc_insert_data,
    _run_queries,
    _run_query,
    _save_to_postgresql,
    adbc_ingest_data,
)


class TestAdbcIngestData(unittest.TestCase):
Collaborator: We test using pytest, though we do use unittest for mocks.
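A pytest-style version of the first case could look roughly like the sketch below. It is only an illustration: it keeps unittest.mock for patching as suggested, assumes the matchbox.server.postgresql.utils.db module path used by the imports above, and the MODULE constant is a hypothetical helper name.

# Hypothetical pytest-style sketch of test_adbc_ingest_data, still using
# unittest.mock for patching.
from unittest.mock import MagicMock, patch

import pandas as pd
import pyarrow as pa
import pytest
from sqlalchemy.engine import Engine

from matchbox.server.postgresql.utils.db import adbc_ingest_data

MODULE = "matchbox.server.postgresql.utils.db"  # illustrative constant


@pytest.mark.parametrize("insert_ok, expected", [(True, True), (False, False)])
def test_adbc_ingest_data(insert_ok, expected):
    table = pa.Table.from_pandas(pd.DataFrame({"column1": [1, 2], "column2": [3, 4]}))
    engine = MagicMock(spec=Engine)

    # Patch the two helpers so only the orchestration logic is exercised
    with patch(f"{MODULE}._adbc_insert_data", return_value=insert_ok):
        with patch(f"{MODULE}._create_adbc_table_constraints", return_value=True):
            assert adbc_ingest_data(table, table, table, engine, resolution_id=1) is expected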


    @patch('matchbox.server.postgresql.utils.db._create_adbc_table_constraints')
    @patch('matchbox.server.postgresql.utils.db._adbc_insert_data')
    @patch('matchbox.server.postgresql.utils.db.datetime')
    def test_adbc_ingest_data(self, mock_datetime, mock_adbc_insert_data, mock_create_adbc_table_constraints):
# Mock datetime
mock_datetime.now.return_value.strftime.return_value = "20250101123045"

# Create mock arguments
clusters = pa.Table.from_pandas(pd.DataFrame({"column1": [1, 2], "column2": [3, 4]}))
contains = pa.Table.from_pandas(pd.DataFrame({"column1": [1, 2], "column2": [3, 4]}))
probabilities = pa.Table.from_pandas(pd.DataFrame({"column1": [1, 2], "column2": [3, 4]}))
engine = MagicMock(spec=Engine)
resolution_id = 1

# Mock the engine connection context manager
mock_connection = engine.connect.return_value.__enter__.return_value

# Test when _adbc_insert_data returns True
mock_adbc_insert_data.return_value = True
mock_create_adbc_table_constraints.return_value = True
result = adbc_ingest_data(clusters, contains, probabilities, engine, resolution_id)
self.assertTrue(result)

# Test when _adbc_insert_data returns False
mock_adbc_insert_data.return_value = False
result = adbc_ingest_data(clusters, contains, probabilities, engine, resolution_id)
self.assertFalse(result)


    @patch('matchbox.server.postgresql.utils.db._save_to_postgresql')
    @patch('matchbox.server.postgresql.utils.db._run_query')
    @patch('matchbox.server.postgresql.utils.db.adbc_driver_postgresql.dbapi.connect')
    def test_adbc_insert_data(self, mock_connect, mock_run_query, mock_save_to_postgresql):
# Mock the connect method
mock_conn = mock_connect.return_value.__enter__.return_value

# Create mock arguments
clusters = pa.Table.from_pandas(pd.DataFrame({"column1": [1, 2], "column2": [3, 4]}))
contains = pa.Table.from_pandas(pd.DataFrame({"column1": [1, 2], "column2": [3, 4]}))
probabilities = pa.Table.from_pandas(pd.DataFrame({"column1": [1, 2], "column2": [3, 4]}))
suffix = "20250101123045"
alchemy_conn = MagicMock()
resolution_id = 1

        # Test when all queries and saves succeed
        mock_run_query.side_effect = [None, None, None]
        mock_save_to_postgresql.side_effect = [None, None, None]
        result = _adbc_insert_data(clusters, contains, probabilities, suffix, alchemy_conn, resolution_id)
        self.assertTrue(result)

        # Test when a query fails with a database error
        mock_run_query.side_effect = AlchemyDatabaseError("SELECT 1", {}, Exception("Query failed"))
        result = _adbc_insert_data(clusters, contains, probabilities, suffix, alchemy_conn, resolution_id)
        self.assertFalse(result)

        # Test when _save_to_postgresql fails with a database error
        mock_run_query.side_effect = [None, None, None]
        mock_save_to_postgresql.side_effect = [None, AlchemyDatabaseError("SELECT 1", {}, Exception("Save failed")), None]
        result = _adbc_insert_data(clusters, contains, probabilities, suffix, alchemy_conn, resolution_id)
        self.assertFalse(result)

    @patch('matchbox.server.postgresql.utils.db.pa.RecordBatchReader.from_batches')
def test_save_to_postgresql(self, mock_from_batches):
# Mock the from_batches method
mock_batch_reader = MagicMock()
mock_from_batches.return_value = mock_batch_reader

# Create mock arguments
table = pa.Table.from_pandas(pd.DataFrame({"column1": [1, 2], "column2": [3, 4]}))
conn = MagicMock()
schema = "test_schema"
table_name = "test_table"

# Mock the cursor context manager
mock_cursor = conn.cursor.return_value.__enter__.return_value

# Call the function
_save_to_postgresql(table, conn, schema, table_name)

# Verify the cursor method was called correctly
mock_cursor.adbc_ingest.assert_called_once_with(
table_name=table_name,
data=mock_batch_reader,
mode="append",
db_schema_name=schema,
)

    @patch('matchbox.server.postgresql.utils.db.text')
def test_run_query(self, mock_text):
# Create mock arguments
query = "SELECT * FROM test_table"
conn = MagicMock(spec=Connection)

# Call the function
_run_query(query, conn)

# Verify the execute method was called correctly
conn.execute.assert_called_once_with(mock_text(query))
conn.commit.assert_called_once()

    @patch('matchbox.server.postgresql.utils.db.text')
def test_run_queries(self, mock_text):
# Create mock arguments
queries = ["SELECT * FROM test_table", "DELETE FROM test_table"]
conn = MagicMock(spec=Connection)

# Call the function
_run_queries(queries, conn)

# Verify the execute method was called correctly for each query
conn.begin.assert_called_once()
self.assertEqual(conn.execute.call_count, len(queries))
for query in queries:
conn.execute.assert_any_call(mock_text(query))


if __name__ == '__main__':
unittest.main()