Skip to content

Commit

Permalink
try to prevent casing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
eloyfelix committed Jan 17, 2025
1 parent 5905470 commit 269f8d9
Showing 1 changed file with 58 additions and 36 deletions.
94 changes: 58 additions & 36 deletions cbl_migrator/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ def __init__(
for item in exclude_fields:
table, field = item.split(".")
if table not in self.exclude_fields:
self.exclude_fields[table] = []
self.exclude_fields[table].append(field)
self.exclude_fields[table.lower()] = []
self.exclude_fields[table.lower()].append(field.lower())

o_eng = create_engine(self.o_eng_conn)
metadata = MetaData()
metadata.reflect(o_eng)
no_pk = [
t
for t, table in metadata.tables.items()
table_name.lower()
for table_name, table in metadata.tables.items()
if not list(table.primary_key.columns)
]
self.exclude_tables = exclude_tables + no_pk
Expand Down Expand Up @@ -197,9 +197,13 @@ def __copy_schema(self):
insp = inspect(o_eng)

new_metadata_tables = {}
tables = filter(
lambda x: x[0] not in self.exclude_tables, metadata.tables.items()
)
# Filter tables, excluding those in self.exclude_tables
tables = [
(name, table)
for name, table in metadata.tables.items()
if name.lower() not in self.exclude_tables
]

for table_name, table in tables:
# Keep only PK constraints unless it's SQLite
keep_constraints = [
Expand All @@ -209,12 +213,14 @@ def __copy_schema(self):
]
if d_eng.name == "sqlite":
# Retain all constraints for SQLite except those involving excluded fields
excluded_fields = self.exclude_fields.get(table_name, [])
excluded_fields = self.exclude_fields.get(table_name.lower(), [])

# Unique constraints
uks = insp.get_unique_constraints(table_name)
for uk in uks:
if not any(col in excluded_fields for col in uk["column_names"]):
if not any(
col.lower() in excluded_fields for col in uk["column_names"]
):
uk_cols = [
c for c in table._columns if c.name in uk["column_names"]
]
Expand All @@ -228,7 +234,9 @@ def __copy_schema(self):
for cons in table.constraints
if isinstance(cons, ForeignKeyConstraint)
]:
if not any(col.name in excluded_fields for col in fk.columns):
if not any(
col.name.lower() in excluded_fields for col in fk.columns
):
keep_constraints.append(fk)

# Check constraints
Expand All @@ -246,9 +254,9 @@ def __copy_schema(self):
table.indexes = set()

new_metadata_cols = ColumnCollection()
excluded_fields = self.exclude_fields.get(table_name, [])
excluded_fields = self.exclude_fields.get(table_name.lower(), [])
for col in table._columns:
if col.name not in excluded_fields:
if col.name.lower() not in excluded_fields:
col = self.__fix_column_type(col, o_eng.name, d_eng.name)
col.autoincrement = False
new_metadata_cols.add(col)
Expand All @@ -270,16 +278,17 @@ def validate_migration(self):
d_metadata = MetaData()
d_metadata.reflect(d_eng)

o_tables = {
t: tbl
for t, tbl in o_metadata.tables.items()
if t not in self.exclude_tables
}
d_tables = {
t: tbl
for t, tbl in d_metadata.tables.items()
if t not in self.exclude_tables
}
# Get origin tables, excluding those in exclude_tables
o_tables = {}
for table_name, table in o_metadata.tables.items():
if table_name.lower() not in self.exclude_tables:
o_tables[table_name] = table

# Get destination tables, excluding those in exclude_tables
d_tables = {}
for table_name, table in d_metadata.tables.items():
if table_name.lower() not in self.exclude_tables:
d_tables[table_name] = table

if set(o_tables.keys()) != set(d_tables.keys()):
return False
Expand Down Expand Up @@ -310,17 +319,23 @@ def __copy_constraints(self):
metadata.reflect(o_eng)
insp = inspect(o_eng)

tables = filter(
lambda x: x[0] not in self.exclude_tables, metadata.tables.items()
)
# Filter tables, excluding those in self.exclude_tables
tables = [
(name, table)
for name, table in metadata.tables.items()
if name.lower() not in self.exclude_tables
]

for table_name, table in tables:
constraints_to_keep = []
excluded_fields = self.exclude_fields.get(table_name, [])
excluded_fields = self.exclude_fields.get(table_name.lower(), [])

# Unique constraints - skip if any column is excluded
uks = insp.get_unique_constraints(table_name)
for uk in uks:
if not any(col in excluded_fields for col in uk["column_names"]):
if not any(
col.lower() in excluded_fields for col in uk["column_names"]
):
uk_cols = [
c for c in table._columns if c.name in uk["column_names"]
]
Expand All @@ -344,7 +359,7 @@ def __copy_constraints(self):
if isinstance(cons, ForeignKeyConstraint)
]
for fk in fks:
if not any(col.name in excluded_fields for col in fk.columns):
if not any(col.name.lower() in excluded_fields for col in fk.columns):
constraints_to_keep.append(fk)

# Create constraints
Expand All @@ -367,11 +382,15 @@ def __copy_indexes(self):
metadata.reflect(o_eng)
insp = inspect(o_eng)

tables = filter(
lambda x: x[0] not in self.exclude_tables, metadata.tables.items()
)
# Filter tables, excluding those in self.exclude_tables
tables = [
(name, table)
for name, table in metadata.tables.items()
if name.lower() not in self.exclude_tables
]

for table_name, table in tables:
excluded_fields = self.exclude_fields.get(table_name, [])
excluded_fields = self.exclude_fields.get(table_name.lower(), [])
uks = insp.get_unique_constraints(table_name)
pk = insp.get_pk_constraint(table_name)

Expand All @@ -380,7 +399,7 @@ def __copy_indexes(self):
for idx in table.indexes
if idx.name not in [u["name"] for u in uks]
and idx.name != pk["name"]
and not any(col.name in excluded_fields for col in idx.columns)
and not any(col.name.lower() in excluded_fields for col in idx.columns)
]
for index in indexes_to_keep:
try:
Expand Down Expand Up @@ -438,11 +457,14 @@ def migrate(
metadata = MetaData()
metadata.reflect(d_eng)
insp = inspect(d_eng)
# Get all table names excluding the ones in self.exclude_tables
all_tables_and_fks = insp.get_sorted_table_and_fkc_names()
table_names = [
t
for t, _ in insp.get_sorted_table_and_fkc_names()
if t and t not in self.exclude_tables
table_name
for table_name, _ in all_tables_and_fks
if table_name and table_name.lower() not in self.exclude_tables
]

tables = [metadata.tables[t] for t in table_names]

processes = 1 if d_eng.name == "sqlite" else self.n_cores
Expand Down

0 comments on commit 269f8d9

Please sign in to comment.