From 67cdc171d3b7e8c8a8ccaa6ca8a2fccd92f9054c Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Sun, 8 Dec 2024 12:22:56 +0000 Subject: [PATCH] Refactored hash_columns with cleaner logic --- src/matchbox/common/db.py | 133 ++++++++++++++++++++------------------ 1 file changed, 70 insertions(+), 63 deletions(-) diff --git a/src/matchbox/common/db.py b/src/matchbox/common/db.py index ee864eb4..5d16bd5f 100644 --- a/src/matchbox/common/db.py +++ b/src/matchbox/common/db.py @@ -241,80 +241,87 @@ def __hash__(self) -> int: def hash_columns(cls, data: dict[str, Any]) -> "Source": """Shapes indices data from either the backend or TOML. - Handles: - * From TOML, no columns specified - * From TOML, some or all columns specified - * From the database, indices already present + Handles three scenarios: + 1. No columns specified - all columns except primary key are indexed + 2. Columns specified in TOML - uses specified columns with optional '*' + 3. Indices from database - uses existing column hash information """ - # Database setup - if isinstance(data["database"], SourceWarehouse): - warehouse = data["database"] - else: - warehouse = SourceWarehouse(**data["database"]) + # Initialise warehouse and get table metadata + warehouse = ( + data["database"] + if isinstance(data["database"], SourceWarehouse) + else SourceWarehouse(**data["database"]) + ) + metadata = MetaData(schema=data["db_schema"]) table = Table(data["db_table"], metadata, autoload_with=warehouse.engine) - # Column logic - # Get all locally specified columns, or remotely specified hashes - local_columns: list[SourceColumn] = [] - star_index: int | None = None - local_hashes: SourceIndex | None = None - index_data: dict | list | None = data.get("index") - if isinstance(index_data, dict): - # Came from Matchbox database - local_hashes = SourceIndex(**index_data) - else: - # Came from TOML - for column in index_data or []: - if column["literal"] == "*": - star_index = len(local_columns) - continue - local_columns.append(SourceColumn(**column, indexed=True)) - - # Get all remote columns using the user's creds and merge with local spec + # Get all columns except primary key remote_columns = [ SourceColumn(literal=col.name, type=str(col.type), indexed=False) for col in table.columns if col.name not in data["db_pk"] ] - db_columns: list[SourceColumn] = [] - db_indexed_columns: list[SourceColumn] = [] - db_non_indexed_columns: list[SourceColumn] = [] - - for remote_column in remote_columns: - if local_columns: - # Came from TOML, index and alias are configured from TOML - for local_column in local_columns: - if remote_column == local_column: - if local_column.type is None: - local_column.type = remote_column.type - db_indexed_columns.append(local_column) - break - else: - db_non_indexed_columns.append(remote_column) - elif local_hashes: - # Came from database, index is true when hashes match - if remote_column in local_hashes.literal + local_hashes.alias: - remote_column.indexed = True - db_columns.append(remote_column) - - if local_columns: - # Concatenate with TOML order, honouring star location (if present) - if star_index is not None: - db_columns = db_indexed_columns - for c in db_non_indexed_columns: - c.indexed = True - db_columns.insert(star_index, c) - else: - db_columns = db_indexed_columns + db_non_indexed_columns - - if not local_columns and not local_hashes: - # No columns specified, index all columns except the primary key - for col in remote_columns: + + index_data = data.get("index") + + # Case 1: No columns specified - index everything + if not index_data: + data["db_columns"] = [ + SourceColumn(literal=col.literal.name, type=col.type, indexed=True) + for col in remote_columns + ] + return data + + # Case 2: Columns from database + if isinstance(index_data, dict): + source_index = SourceIndex(**index_data) + data["db_columns"] = [ + SourceColumn( + literal=col.literal.name, + type=col.type, + indexed=col in source_index.literal + source_index.alias, + ) + for col in remote_columns + ] + return data + + # Case 3: Columns from TOML + local_columns = [] + star_index = None + + # Process TOML column specifications + for i, column in enumerate(index_data): + if column["literal"] == "*": + star_index = i + continue + local_columns.append(SourceColumn(**column, indexed=True)) + + # Match remote columns with local specifications + indexed_columns = [] + non_indexed_columns = [] + + for remote_col in remote_columns: + matched = False + for local_col in local_columns: + if remote_col == local_col: + if local_col.type is None: + local_col.type = remote_col.type + indexed_columns.append(local_col) + matched = True + break + if not matched: + non_indexed_columns.append(remote_col) + + # Handle wildcard insertion + if star_index is not None: + for col in non_indexed_columns: col.indexed = True - db_columns = remote_columns + indexed_columns[star_index:star_index] = non_indexed_columns + data["db_columns"] = indexed_columns + else: + data["db_columns"] = indexed_columns + non_indexed_columns - data["db_columns"] = db_columns return data def to_table(self) -> Table: