From c490dcd92ee50eef2e1f3771f0ca33f32b546216 Mon Sep 17 00:00:00 2001
From: Mike Gouline <1960272+gouline@users.noreply.github.com>
Date: Fri, 20 Dec 2024 18:10:40 +1100
Subject: [PATCH] refactor: exposure extraction context (#300)

---
 dbtmetabase/_exposures.py | 276 ++++++++++++++++++++------------------
 1 file changed, 146 insertions(+), 130 deletions(-)

diff --git a/dbtmetabase/_exposures.py b/dbtmetabase/_exposures.py
index 99f6094..036902c 100644
--- a/dbtmetabase/_exposures.py
+++ b/dbtmetabase/_exposures.py
@@ -13,7 +13,9 @@
     MutableSequence,
     Optional,
     Sequence,
+    Set,
     Tuple,
+    Union,
 )
 from urllib.parse import unquote
 
@@ -125,44 +127,41 @@ def dbname(details: Mapping) -> str:
                 uid=collection["id"],
                 models=("card", "dashboard"),
             ):
+                exposure = self.__Exposure(
+                    model=item["model"],
+                    uid=item["id"],
+                    label="Exposure [Unresolved Name]",
+                )
+
                 if (
                     exclude_unverified
-                    and item["model"] == "card"
+                    and exposure.model == "card"
                     and item.get("moderated_status") != "verified"
                 ):
                     _logger.debug("Skipping unverified card '%s'", item["name"])
                     continue
 
-                depends = set()
-                native_query = ""
-                header = ""
-                average_query_time = None
-                last_used_at = None
-
                 entity: Mapping
-                if item["model"] == "card":
+                if exposure.model == "card":
                     card_entity = self.metabase.find_card(uid=item["id"])
                     if card_entity is None:
                         _logger.info("Card '%s' not found, skipping", item["id"])
                         continue
 
                     entity = card_entity
-                    header = (
+                    exposure.header = (
                         f"Visualization: {entity.get('display', 'Unknown').title()}"
                     )
 
-                    result = self.__extract_card_exposures(ctx, entity)
-                    depends.update(result["depends"])
-                    native_query = result["native_query"]
+                    self.__exposure_card(ctx, exposure, entity)
 
-                    average_query_time_ms = entity.get("average_query_time")
-                    if average_query_time_ms:
+                    if average_query_time_ms := entity.get("average_query_time"):
                         average_query_time_s = average_query_time_ms / 1000
-                        average_query_time = f"{(average_query_time_s // 60):.0f}:{(average_query_time_s % 60):06.3f}"
+                        exposure.average_query_time = f"{(average_query_time_s // 60):.0f}:{(average_query_time_s % 60):06.3f}"
 
-                    last_used_at = entity.get("last_used_at")
+                    exposure.last_used_at = entity.get("last_used_at")
 
-                elif item["model"] == "dashboard":
+                elif exposure.model == "dashboard":
                     dashboard_entity = self.metabase.find_dashboard(uid=item["id"])
                     if dashboard_entity is None:
                         _logger.info("Dashboard '%s' not found, skipping", item["id"])
@@ -173,37 +172,35 @@ def dbname(details: Mapping) -> str:
                     if not cards:
                         continue
 
-                    header = f"Dashboard Cards: {len(cards)}"
+                    exposure.header = f"Dashboard Cards: {len(cards)}"
                     for card_ref in cards:
                         card = card_ref.get("card", {})
                         if "id" not in card:
                             continue
 
                         if card := self.metabase.find_card(uid=card["id"]):
-                            result = self.__extract_card_exposures(ctx, card)
-                            depends.update(result["depends"])
+                            self.__exposure_card(ctx, exposure, card)
                 else:
                     _logger.warning("Unexpected collection item '%s'", item["model"])
                     continue
 
-                name = entity.get("name", "Exposure [Unresolved Name]")
-                _logger.info("Processing %s '%s'", item["model"], name)
+                exposure.label = entity.get("name") or exposure.label
+                exposure.description = entity.get("description") or exposure.description
+                exposure.created_at = entity["created_at"]
+                _logger.info("Processing %s '%s'", exposure.model, exposure.label)
 
-                creator_name = None
-                creator_email = None
                 if "creator" in entity:
-                    creator_email = entity["creator"]["email"]
-                    creator_name = entity["creator"]["common_name"]
+                    exposure.creator_email = entity["creator"]["email"]
+                    exposure.creator_name = entity["creator"]["common_name"]
                 elif "creator_id" in entity:
-                    creator = self.metabase.find_user(uid=entity["creator_id"])
-                    if creator:
-                        creator_name = creator.get("common_name")
-                        creator_email = creator.get("email")
+                    if creator := self.metabase.find_user(uid=entity["creator_id"]):
+                        exposure.creator_name = creator.get("common_name", "")
+                        exposure.creator_email = creator.get("email", "")
 
-                label = name
-                name = safe_name(name)
-                count = counts.get(name, 0)
-                counts[name] = count + 1
+                exposure.name = safe_name(exposure.label)
+                count = counts.get(exposure.name, 0)
+                counts[exposure.name] = count + 1
+                exposure.name = exposure.name + (f"_{count}" if count > 0 else "")
 
                 exposures.append(
                     {
@@ -211,22 +208,22 @@ def dbname(details: Mapping) -> str:
                         "type": item["model"],
                         "collection": collection_slug,
                         "body": self.__format_exposure(
-                            model=item["model"],
-                            uid=item["id"],
-                            name=name + (f"_{count}" if count > 0 else ""),
-                            label=label,
-                            header=header,
-                            description=entity.get("description", ""),
-                            created_at=entity["created_at"],
-                            creator_name=creator_name or "",
-                            creator_email=creator_email or "",
-                            last_used_at=last_used_at,
-                            average_query_time=average_query_time,
-                            native_query=native_query,
+                            model=exposure.model,
+                            uid=exposure.uid,
+                            name=exposure.name,
+                            label=exposure.label,
+                            header=exposure.header,
+                            description=exposure.description,
+                            created_at=exposure.created_at,
+                            creator_name=exposure.creator_name,
+                            creator_email=exposure.creator_email,
+                            last_used_at=exposure.last_used_at,
+                            average_query_time=exposure.average_query_time,
+                            native_query=exposure.native_query,
                             depends_on=sorted(
                                 [
                                     ctx.model_refs[depend.lower()]
-                                    for depend in depends
+                                    for depend in exposure.depends
                                     if depend.lower() in ctx.model_refs
                                 ]
                             ),
@@ -239,98 +236,99 @@ def dbname(details: Mapping) -> str:
 
         return exposures
 
-    def __extract_card_exposures(self, ctx: __Context, card: Mapping) -> Mapping:
+    def __exposure_card(self, ctx: __Context, exposure: __Exposure, card: Mapping):
         """Extracts exposures from Metabase questions."""
 
-        depends = set()
-        native_query = ""
+        dataset_query = card.get("dataset_query", {})
+        card_type = dataset_query.get("type")
+        if card_type == "query":
+            self.__exposure_query(ctx, exposure, card)
+        elif card_type == "native":
+            self.__exposure_native(ctx, exposure, card)
+        else:
+            _logger.warning("Unsupported card type '%s'", card_type)
+
+    def __exposure_query(self, ctx: __Context, exposure: __Exposure, card: Mapping):
+        """Extracts exposures from Metabase GUI queries."""
+
+        dataset_query = card.get("dataset_query", {})
+        query = dataset_query.get("query", {})
+
+        query_source: Union[str, int] = query.get("source-table", card.get("table_id"))
+        if isinstance(query_source, str) and query_source.startswith("card__"):
+            # Question based on another question
+            source_card_uid = query_source.split("__")[-1]
+            if source_card := self.metabase.find_card(uid=source_card_uid):
+                self.__exposure_card(ctx, exposure, source_card)
+
+        elif isinstance(query_source, int) and query_source in ctx.table_names:
+            # Question based on table
+            source_table = ctx.table_names[query_source].lower()
+            exposure.depends.add(source_table)
+            _logger.info("Extracted model '%s' from card", source_table)
+
+        # Find models in joins
+        for join in query.get("joins", []):
+            join_source: Union[str, int] = join.get("source-table")
+            if isinstance(join_source, str) and join_source.startswith("card__"):
+                # Question based on another question
+                source_card_uid = join_source.split("__")[-1]
+                if source_card := self.metabase.find_card(uid=source_card_uid):
+                    self.__exposure_card(ctx, exposure, source_card)
 
-        query = card.get("dataset_query", {})
-        if query.get("type") == "query":
-            # Metabase GUI derived query
-            query_source = query.get("query", {}).get(
-                "source-table", card.get("table_id")
-            )
+                continue
 
-            if str(query_source).startswith("card__"):
-                # Handle questions based on other questions
-                if source_card := self.metabase.find_card(
-                    uid=query_source.split("__")[-1]
-                ):
-                    result = self.__extract_card_exposures(ctx, source_card)
-                    depends.update(result["depends"])
-
-            elif query_source in ctx.table_names:
-                # Normal question
-                source_table = ctx.table_names.get(query_source)
-                if source_table:
-                    source_table = source_table.lower()
-                    _logger.info("Extracted model '%s' from card", source_table)
-                    depends.add(source_table)
-
-            # Find models exposed through joins
-            for join in query.get("query", {}).get("joins", []):
-                join_source = join.get("source-table")
-
-                if str(join_source).startswith("card__"):
-                    # Handle questions based on other questions
-                    if source_card := self.metabase.find_card(
-                        uid=join_source.split("__")[-1]
-                    ):
-                        result = self.__extract_card_exposures(ctx, source_card)
-                        depends.update(result["depends"])
+            elif isinstance(join_source, int) and join_source in ctx.table_names:
+                # Joined model parsed
+                joined_table = ctx.table_names[join_source].lower()
+                exposure.depends.add(joined_table)
+                _logger.info("Extracted model '%s' from join", joined_table)
 
-                    continue
+    def __exposure_native(self, ctx: __Context, exposure: __Exposure, card: Mapping):
+        """Extracts exposures from Metabase native queries."""
 
-                # Joined model parsed
-                joined_table = ctx.table_names.get(join_source)
-                if joined_table:
-                    joined_table = joined_table.lower()
-                    _logger.info("Extracted model '%s' from join", joined_table)
-
-        elif query.get("type") == "native":
-            # Metabase native query
-            native_query = query["native"].get("query")
-            ctes: MutableSequence[str] = []
-
-            # Parse common table expressions for exclusion
-            for matched_cte in re.findall(_CTE_PARSER, native_query):
-                ctes.extend(group.lower() for group in matched_cte if group)
-
-            # Parse SQL for exposures through FROM or JOIN clauses
-            for sql_ref in re.findall(_EXPOSURE_PARSER, native_query):
-                sql_ref = sql_ref.strip("`")  # BigQuery uses backticks `dataset.table`
-
-                # DATABASE.schema.table -> [database, schema, table]
-                parsed_model_path = [s.strip('"').lower() for s in sql_ref.split(".")]
-
-                # Scrub CTEs (qualified sql_refs can not reference CTEs)
-                if parsed_model_path[-1] in ctes and "." not in sql_ref:
-                    continue
+        dataset_query = card.get("dataset_query", {})
+        database = dataset_query["database"]
+        native_query = dataset_query["native"]["query"]
 
-                # Missing schema -> use default schema
-                if len(parsed_model_path) < 2:
-                    parsed_model_path.insert(0, DEFAULT_SCHEMA.lower())
-                # Missing database -> use query's database
-                if len(parsed_model_path) < 3:
-                    database_name = ctx.database_names.get(query["database"], "")
-                    parsed_model_path.insert(0, database_name.lower())
+        # Parse common table expressions for exclusion
+        ctes: MutableSequence[str] = []
+        for matched_cte in re.findall(_CTE_PARSER, native_query):
+            ctes.extend(group.lower() for group in matched_cte if group)
 
-                # Fully-qualified database.schema.table
-                parsed_model = ".".join(parsed_model_path)
+        # Parse SQL for exposures through FROM or JOIN clauses
+        for sql_ref in re.findall(_EXPOSURE_PARSER, native_query):
+            sql_ref = sql_ref.strip("`")  # BigQuery uses backticks `dataset.table`
 
-                # Verify this is one of our parsed refable models so exposures dont break the DAG
-                if not ctx.model_refs.get(parsed_model):
-                    continue
+            # DATABASE.schema.table -> [database, schema, table]
+            parsed_model_path = [s.strip('"').lower() for s in sql_ref.split(".")]
 
-                if parsed_model:
-                    _logger.info("Extracted model '%s' from native query", parsed_model)
-                    depends.add(parsed_model)
+            # Scrub CTEs (qualified sql_refs can not reference CTEs)
+            if parsed_model_path[-1] in ctes and "." not in sql_ref:
+                continue
 
-        return {
-            "depends": depends,
-            "native_query": native_query,
-        }
+            # Missing schema -> use default schema
+            if len(parsed_model_path) < 2:
+                parsed_model_path.insert(0, DEFAULT_SCHEMA.lower())
+            # Missing database -> use query's database
+            if len(parsed_model_path) < 3:
+                database_name = ctx.database_names.get(database, "")
+                parsed_model_path.insert(0, database_name.lower())
+
+            # Fully-qualified database.schema.table
+            parsed_model = ".".join(parsed_model_path)
+
+            # Verify this is one of our parsed refable models so exposures dont break the DAG
+            if not ctx.model_refs.get(parsed_model):
+                continue
+
+            if parsed_model:
+                exposure.depends.add(parsed_model)
+                _logger.info("Extracted model '%s' from native query", parsed_model)
+
+        if exposure.model != "dashboard":
+            # Only include SQL for query exposures
+            exposure.native_query = native_query
 
     def __format_exposure(
         self,
@@ -374,6 +372,8 @@ def __format_exposure(
             # Format query into markdown code block
            native_query = "\n".join(x for x in native_query.split("\n") if x.strip())
             native_query = f"#### Query\n\n```\n{native_query}\n```\n\n"
+        else:
+            native_query = ""
 
         metadata = (
             "#### Metadata\n\n"
@@ -451,5 +451,21 @@ def __write_exposures(
     @dc.dataclass
     class __Context:
         model_refs: Mapping[str, str]
-        database_names: Mapping[str, str]
-        table_names: Mapping[str, str]
+        database_names: Mapping[int, str]
+        table_names: Mapping[int, str]
+
+    @dc.dataclass
+    class __Exposure:
+        model: str
+        uid: str
+        label: str
+        name: str = ""
+        description: str = ""
+        created_at: str = ""
+        header: str = ""
+        creator_name: str = ""
+        creator_email: str = ""
+        average_query_time: Optional[str] = None
+        last_used_at: Optional[str] = None
+        native_query: Optional[str] = None
+        depends: Set[str] = dc.field(default_factory=set)
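
Below the patch, a minimal sketch (not part of the patch itself, assuming hypothetical names) of the accumulator pattern the refactor adopts: a mutable per-item record is threaded through the card helpers and filled in as cards are walked, instead of merging dicts returned by __extract_card_exposures. The names Exposure, walk_card, TABLE_NAMES, and CARDS stand in for the private __Exposure dataclass, the __exposure_* methods, and the Metabase lookups in the diff.

# Illustrative sketch only -- not part of the patch. Hypothetical names mirror
# the private __Exposure dataclass and __exposure_* helpers added above.
import dataclasses as dc
from typing import Mapping, Optional, Set

@dc.dataclass
class Exposure:
    model: str  # "card" or "dashboard"
    uid: str
    depends: Set[str] = dc.field(default_factory=set)
    native_query: Optional[str] = None

# Stand-ins for __Context.table_names and the Metabase card lookup.
TABLE_NAMES = {7: "analytics.orders"}
CARDS = {
    "1": {"dataset_query": {"type": "query", "query": {"source-table": 7}}},
    "2": {"dataset_query": {"type": "query", "query": {"source-table": "card__1"}}},
}

def walk_card(exposure: Exposure, card: Mapping) -> None:
    """Accumulate dependencies onto the exposure instead of returning a dict."""
    query = card.get("dataset_query", {}).get("query", {})
    source = query.get("source-table")
    if isinstance(source, str) and source.startswith("card__"):
        # Question built on another question: recurse into the source card.
        if parent := CARDS.get(source.split("__")[-1]):
            walk_card(exposure, parent)
    elif isinstance(source, int) and source in TABLE_NAMES:
        exposure.depends.add(TABLE_NAMES[source].lower())

exposure = Exposure(model="card", uid="2")
walk_card(exposure, CARDS["2"])
print(sorted(exposure.depends))  # ['analytics.orders']

Passing the record down the call chain is what lets a dashboard's cards, nested source cards, and joins all contribute to a single depends set before __format_exposure renders it.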