Skip to content

Commit

Permalink
Refactored hash_columns with cleaner logic
Browse files Browse the repository at this point in the history
  • Loading branch information
wpfl-dbt committed Dec 8, 2024
1 parent c34f10d commit 67cdc17
Showing 1 changed file with 70 additions and 63 deletions.
133 changes: 70 additions & 63 deletions src/matchbox/common/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,80 +241,87 @@ def __hash__(self) -> int:
def hash_columns(cls, data: dict[str, Any]) -> "Source":
    """Shapes indices data from either the backend or TOML.

    Reflects the source table, drops the primary key column(s), and writes
    the resulting list of ``SourceColumn`` objects to ``data["db_columns"]``.

    Handles three scenarios, chosen by the value of ``data.get("index")``:
        1. No columns specified (missing or empty) - all columns except the
           primary key are indexed.
        2. Indices from database (a ``dict``) - a column is indexed when it
           matches one of the stored literal or alias hashes.
        3. Columns specified in TOML (a list of dicts) - uses the specified
           columns, with an optional ``{"literal": "*"}`` entry meaning
           "index every remaining column here".

    Args:
        data: Raw field values. Reads ``database``, ``db_schema``,
            ``db_table``, ``db_pk`` and, optionally, ``index``.

    Returns:
        The same ``data`` mapping, mutated in place with ``db_columns`` set.
    """
    # Initialise warehouse and get table metadata
    warehouse = (
        data["database"]
        if isinstance(data["database"], SourceWarehouse)
        else SourceWarehouse(**data["database"])
    )

    metadata = MetaData(schema=data["db_schema"])
    table = Table(data["db_table"], metadata, autoload_with=warehouse.engine)

    # Compare primary key names exactly. A bare ``name not in pk`` is a
    # substring test when ``pk`` is a string, which would wrongly drop a
    # column called "id" for a key called "uuid".
    # NOTE(review): assumes db_pk is a str or a collection of column names.
    db_pk = data["db_pk"]
    pk_names = {db_pk} if isinstance(db_pk, str) else set(db_pk)

    # Get all columns except primary key
    remote_columns = [
        SourceColumn(literal=col.name, type=str(col.type), indexed=False)
        for col in table.columns
        if col.name not in pk_names
    ]

    index_data = data.get("index")

    # Case 1: No columns specified - index everything
    if not index_data:
        data["db_columns"] = [
            SourceColumn(literal=col.literal.name, type=col.type, indexed=True)
            for col in remote_columns
        ]
        return data

    # Case 2: Indices from database
    if isinstance(index_data, dict):
        source_index = SourceIndex(**index_data)
        # Hoisted so the concatenation is not rebuilt for every column.
        known_hashes = source_index.literal + source_index.alias
        data["db_columns"] = [
            SourceColumn(
                literal=col.literal.name,
                type=col.type,
                indexed=col in known_hashes,
            )
            for col in remote_columns
        ]
        return data

    # Case 3: Columns from TOML
    local_columns = []
    star_index = None

    # Process TOML column specifications
    for column in index_data:
        if column["literal"] == "*":
            # Position among the explicitly listed columns. Unlike the raw
            # loop index, this stays correct if "*" appears more than once.
            star_index = len(local_columns)
            continue
        local_columns.append(SourceColumn(**column, indexed=True))

    # Match remote columns with local specifications
    indexed_columns = []
    non_indexed_columns = []

    for remote_col in remote_columns:
        # First local spec equal to this remote column, if any
        local_col = next(
            (candidate for candidate in local_columns if remote_col == candidate),
            None,
        )
        if local_col is None:
            non_indexed_columns.append(remote_col)
            continue
        if local_col.type is None:
            # TOML did not declare a type: adopt the reflected one
            local_col.type = remote_col.type
        indexed_columns.append(local_col)

    # Handle wildcard insertion
    if star_index is not None:
        # "*" promotes every column not named explicitly to indexed and
        # splices them in at the wildcard's position.
        for col in non_indexed_columns:
            col.indexed = True
        indexed_columns[star_index:star_index] = non_indexed_columns
        data["db_columns"] = indexed_columns
    else:
        data["db_columns"] = indexed_columns + non_indexed_columns

    return data

def to_table(self) -> Table:
Expand Down

0 comments on commit 67cdc17

Please sign in to comment.