Skip to content

Commit

Permalink
Exclude fields (#20)
Browse files Browse the repository at this point in the history
* add exclude fields parameter
  • Loading branch information
eloyfelix authored Jan 17, 2025
1 parent f45056f commit 11a6adc
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 76 deletions.
109 changes: 64 additions & 45 deletions cbl_migrator/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,6 @@
from .logs import logger


def insert_rows_in_dest(table, data, d_eng):
    """
    Insert the rows of a result set into the destination table.

    Args:
        table: SQLAlchemy Table the rows are inserted into.
        data: SQLAlchemy result object exposing ``.keys()`` and ``.all()``.
        d_eng: Destination engine; the insert runs in one transaction.
    """
    # Materialize the result exactly once: ``.all()`` consumes the cursor,
    # so the original double call saw an empty result on the second read
    # and built an empty parameter list.
    keys = data.keys()
    rows = data.all()
    if not rows:
        # Nothing to insert; skip an executemany with an empty list.
        return
    with d_eng.begin() as conn:
        conn.execute(
            table.insert(),
            [dict(zip(keys, row)) for row in rows],
        )


def chunked_copy_single_pk(table, pk, last_id, chunk_size, o_eng, d_eng):
"""
Copies table data in chunks, assuming a single PK column.
Expand Down Expand Up @@ -133,18 +119,27 @@ class DbMigrator:
o_conn_string (str): Origin DB connection string.
d_conn_string (str): Destination DB connection string.
exclude (list[str]): List of tables to exclude from migration.
exclude_fields (list[str]): List of fields to exclude in format 'table.field'.
n_cores (int): Number of processes used for data copying.
Methods:
migrate: Executes the migration pipeline.
"""

def __init__(self, o_conn_string, d_conn_string, exclude=None, n_workers=4):
if exclude is None:
exclude = []
def __init__(
self,
o_conn_string,
d_conn_string,
exclude_tables=None,
exclude_fields=None,
n_workers=4,
):
if exclude_tables is None:
exclude_tables = []
if exclude_fields is None:
exclude_fields = []
self.o_eng_conn = o_conn_string
self.d_eng_conn = d_conn_string
self.n_cores = n_workers
self.exclude_fields = {f.split(".")[0]: f.split(".")[1] for f in exclude_fields}
print(self.exclude_fields)

o_eng = create_engine(self.o_eng_conn)
metadata = MetaData()
Expand All @@ -154,7 +149,7 @@ def __init__(self, o_conn_string, d_conn_string, exclude=None, n_workers=4):
for t, table in metadata.tables.items()
if not list(table.primary_key.columns)
]
self.exclude = exclude + no_pk
self.exclude_tables = exclude_tables + no_pk

def __fix_column_type(self, col, o_eng, d_eng):
"""
Expand Down Expand Up @@ -190,7 +185,9 @@ def __copy_schema(self):
insp = inspect(o_eng)

new_metadata_tables = {}
tables = filter(lambda x: x[0] not in self.exclude, metadata.tables.items())
tables = filter(
lambda x: x[0] not in self.exclude_tables, metadata.tables.items()
)
for table_name, table in tables:
# Keep only PK constraints unless it's SQLite
keep_constraints = [
Expand Down Expand Up @@ -223,10 +220,12 @@ def __copy_schema(self):
table.indexes = set()

new_metadata_cols = ColumnCollection()
excluded_fields = self.exclude_fields.get(table_name, [])
for col in table._columns:
col = self.__fix_column_type(col, o_eng.name, d_eng.name)
col.autoincrement = False
new_metadata_cols.add(col)
if col.name not in excluded_fields:
col = self.__fix_column_type(col, o_eng.name, d_eng.name)
col.autoincrement = False
new_metadata_cols.add(col)
table.columns = new_metadata_cols.as_readonly()
new_metadata_tables[table_name] = table

Expand All @@ -246,10 +245,14 @@ def validate_migration(self):
d_metadata.reflect(d_eng)

o_tables = {
t: tbl for t, tbl in o_metadata.tables.items() if t not in self.exclude
t: tbl
for t, tbl in o_metadata.tables.items()
if t not in self.exclude_tables
}
d_tables = {
t: tbl for t, tbl in d_metadata.tables.items() if t not in self.exclude
t: tbl
for t, tbl in d_metadata.tables.items()
if t not in self.exclude_tables
}

if set(o_tables.keys()) != set(d_tables.keys()):
Expand All @@ -272,41 +275,51 @@ def validate_migration(self):

def __copy_constraints(self):
"""
Migrates constraints to the destination DB (UK, CK, FK).
Migrates constraints to the destination DB (UK, CK, FK), skipping those
that involve excluded fields.
"""
o_eng = create_engine(self.o_eng_conn)
d_eng = create_engine(self.d_eng_conn)
metadata = MetaData()
metadata.reflect(o_eng)
insp = inspect(o_eng)

tables = filter(lambda x: x[0] not in self.exclude, metadata.tables.items())
tables = filter(
lambda x: x[0] not in self.exclude_tables, metadata.tables.items()
)
for table_name, table in tables:
constraints_to_keep = []
excluded_fields = self.exclude_fields.get(table_name, [])

# Unique constraints
# Unique constraints - skip if any column is excluded
uks = insp.get_unique_constraints(table_name)
for uk in uks:
uk_cols = [c for c in table._columns if c.name in uk["column_names"]]
uuk = UniqueConstraint(*uk_cols, name=uk["name"])
uuk._set_parent(table)
constraints_to_keep.append(uuk)
if not any(col in excluded_fields for col in uk["column_names"]):
uk_cols = [
c for c in table._columns if c.name in uk["column_names"]
]
uuk = UniqueConstraint(*uk_cols, name=uk["name"])
uuk._set_parent(table)
constraints_to_keep.append(uuk)

# Check constraints
# Check constraints - skip if any column is excluded
ccs = [
cons for cons in table.constraints if isinstance(cons, CheckConstraint)
]
for cc in ccs:
cc.sqltext = TextClause(str(cc.sqltext).replace('"', ""))
constraints_to_keep.append(cc)
if not any(col in str(cc.sqltext) for col in excluded_fields):
cc.sqltext = TextClause(str(cc.sqltext).replace('"', ""))
constraints_to_keep.append(cc)

# Foreign keys
# Foreign keys - skip if any column is excluded
fks = [
cons
for cons in table.constraints
if isinstance(cons, ForeignKeyConstraint)
]
constraints_to_keep.extend(fks)
for fk in fks:
if not any(col.name in excluded_fields for col in fk.columns):
constraints_to_keep.append(fk)

# Create constraints
for cons in constraints_to_keep:
Expand All @@ -319,23 +332,29 @@ def __copy_constraints(self):
def __copy_indexes(self):
"""
Creates indexes in the destination DB, skipping those
already defined via unique or primary constraints.
already defined via unique or primary constraints and
those involving excluded fields.
"""
o_eng = create_engine(self.o_eng_conn)
d_eng = create_engine(self.d_eng_conn)
metadata = MetaData()
metadata.reflect(o_eng)
insp = inspect(o_eng)

tables = filter(lambda x: x[0] not in self.exclude, metadata.tables.items())
tables = filter(
lambda x: x[0] not in self.exclude_tables, metadata.tables.items()
)
for table_name, table in tables:
excluded_fields = self.exclude_fields.get(table_name, [])
uks = insp.get_unique_constraints(table_name)
pk = insp.get_pk_constraint(table_name)

indexes_to_keep = [
idx
for idx in table.indexes
if idx.name not in [u["name"] for u in uks] and idx.name != pk["name"]
if idx.name not in [u["name"] for u in uks]
and idx.name != pk["name"]
and not any(col.name in excluded_fields for col in idx.columns)
]
for index in indexes_to_keep:
try:
Expand Down Expand Up @@ -385,12 +404,12 @@ def migrate(
# Fill tables with data
if copy_data:
metadata = MetaData()
metadata.reflect(o_eng)
insp = inspect(o_eng)
metadata.reflect(d_eng)
insp = inspect(d_eng)
table_names = [
t
for t, _ in insp.get_sorted_table_and_fkc_names()
if t and t not in self.exclude
if t and t not in self.exclude_tables
]
tables = [metadata.tables[t] for t in table_names]

Expand Down
87 changes: 57 additions & 30 deletions cbl_migrator/test/test_migration.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from sqlalchemy import MetaData, create_engine, inspect, insert
from .schema import Base, Compound, CompoundStructure, CompoundProperties
from .. import DbMigrator
import unittest
import pytest
import random
import os


molblock = """
SciTegic01111613442D
Expand Down Expand Up @@ -84,15 +85,17 @@
]


class TestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestCase, self).__init__(*args, **kwargs)

class TestMigration:
@pytest.fixture(autouse=True)
def setup(self):
    """Point each test at the sqlite files and remove them afterwards."""
    self.origin = "sqlite:///origin.db"
    self.dest = "sqlite:///dest.db"
    yield
    # Teardown: delete the DB files after each test. Guard the removal so a
    # test that fails before creating a file doesn't raise FileNotFoundError
    # here and mask the original failure.
    for db_file in ("origin.db", "dest.db"):
        if os.path.exists(db_file):
            os.remove(db_file)

def __gen_test_data(self):
# create schema
engine = create_engine(self.origin)
Base.metadata.create_all(bind=engine)

Expand Down Expand Up @@ -148,58 +151,82 @@ def __get_tables_insp(self):

return o_tables, d_tables, o_insp, d_insp

def test_a_migration(self):
def test_01_verify_migration(self):
self.__gen_test_data()
migrator = DbMigrator(self.origin, self.dest)
self.assertTrue(migrator.migrate(chunk_size=10))
assert migrator.migrate(chunk_size=10) is True

def test_b_uks(self):
def test_02_verify_unique_constraints(self):
o_tables, _, o_insp, d_insp = self.__get_tables_insp()

for table_name, _ in o_tables.items():
o_uks = o_insp.get_unique_constraints(table_name)
d_uks = d_insp.get_unique_constraints(table_name)
self.assertEqual(o_uks, d_uks)
assert o_uks == d_uks

def test_c_idxs(self):
def test_03_verify_indexes(self):
o_tables, d_tables, _, _ = self.__get_tables_insp()

for table_name, table in o_tables.items():
self.assertEqual(
[index.name for index in table.indexes],
[index.name for index in d_tables[table_name].indexes],
)
assert [index.name for index in table.indexes] == [
index.name for index in d_tables[table_name].indexes
]

def test_d_pks(self):
def test_04_verify_primary_keys(self):
o_tables, d_tables, _, _ = self.__get_tables_insp()

for table_name, table in o_tables.items():
self.assertEqual(
[col.name for col in table.primary_key.columns],
[col.name for col in d_tables[table_name].primary_key.columns],
)
assert [col.name for col in table.primary_key.columns] == [
col.name for col in d_tables[table_name].primary_key.columns
]

def test_e_fks(self):
def test_05_verify_foreign_keys(self):
o_tables, _, o_insp, d_insp = self.__get_tables_insp()

for table_name, _ in o_tables.items():
o_fks = o_insp.get_foreign_keys(table_name)
d_fks = d_insp.get_foreign_keys(table_name)
self.assertEqual(o_fks, d_fks)
assert o_fks == d_fks

def test_f_cks(self):
def test_06_verify_check_constraints(self):
o_tables, _, o_insp, d_insp = self.__get_tables_insp()

for table_name, _ in o_tables.items():
o_cks = o_insp.get_check_constraints(table_name)
d_cks = d_insp.get_check_constraints(table_name)
self.assertEqual(o_cks, d_cks)
assert o_cks == d_cks

@classmethod
def tearDownClass(cls):
os.remove("origin.db")
os.remove("dest.db")
def test_07_skip_table(self):
"""Test migration skipping compound_properties table"""
self.__gen_test_data()
migrator = DbMigrator(
self.origin, self.dest, exclude_tables=["compound_properties"]
)
assert migrator.migrate(chunk_size=10) is True

# Verify compound_properties was skipped
d_eng = create_engine(self.dest)
d_metadata = MetaData()
d_metadata.reflect(d_eng)
assert "compound_properties" not in d_metadata.tables

if __name__ == "__main__":
unittest.main()
# Verify other tables were migrated
assert "compound" in d_metadata.tables
assert "compound_structure" in d_metadata.tables

def test_08_skip_column(self):
    """Migration with exclude_fields drops only the named column.

    Excludes ``compound_properties.logp`` and verifies that the
    destination table is created without it while a non-excluded
    column survives.
    """
    self.__gen_test_data()
    migrator = DbMigrator(
        self.origin, self.dest, exclude_fields=["compound_properties.logp"]
    )
    assert migrator.migrate(chunk_size=10) is True

    # Reflect the destination schema and inspect the resulting column set.
    d_eng = create_engine(self.dest)
    d_metadata = MetaData()
    d_metadata.reflect(d_eng)
    props_table = d_metadata.tables["compound_properties"]
    # Excluded field must be gone; a non-excluded one must remain.
    # (Removed leftover debug print of the column collection.)
    assert "logp" not in [column.name for column in props_table.columns]
    assert "mw" in props_table.columns
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "cbl_migrator"
version = "0.3.7"
version = "0.3.8"
description = "Migrates Oracle dbs to PostgreSQL, MySQL and SQLite"
readme = "README.md"
license = { text = "MIT" }
Expand Down

0 comments on commit 11a6adc

Please sign in to comment.