Mark "available" functions for recording #761

Draft: wants to merge 1 commit into base: main
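This draft wraps the adapter methods that are exposed to Jinja via @available with @auto_record_function from dbt_common.record, grouped under "Available", and marks BaseAdapter with @supports_replay. It also registers an AdapterExecuteSerializer so that the (AdapterResponse, agate.Table) pair returned by execute() can be written to a recording. A minimal sketch of the pattern follows; the decorator and Recorder names are taken from this diff, while the record/replay behavior described in the comments is an assumption about dbt_common.record rather than something verified here, and ExampleAdapter is a toy stand-in for BaseAdapter.

from dbt_common.record import Recorder, auto_record_function, supports_replay

@supports_replay
class ExampleAdapter:
    @auto_record_function("AdapterQuote", group="Available")
    def quote(self, identifier: str) -> str:
        # While recording, the call and its return value are presumably captured
        # under an auto-generated "AdapterQuote" record type; on replay, the
        # stored result would be returned instead of re-running this body.
        return f'"{identifier}"'

# Non-trivial return types need a serialization strategy registered with the
# Recorder, which this diff does for execute()'s return type:
#   Recorder.register_serialization_strategy(ExecuteReturn, AdapterExecuteSerializer())

The group="Available" label mirrors the existing @available decorator, so recorded calls can be filtered down to exactly the surface area that macros and materializations can reach.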
68 changes: 63 additions & 5 deletions dbt-adapters/src/dbt/adapters/base/impl.py
@@ -42,6 +42,7 @@
NotImplementedError,
UnexpectedNullError,
)
from dbt_common.record import auto_record_function, supports_replay, Recorder
from dbt_common.utils import (
AttrDict,
cast_to_str,
@@ -90,8 +91,14 @@
)
from dbt.adapters.protocol import AdapterConfig, MacroContextGeneratorCallable

if TYPE_CHECKING:
import agate
# if TYPE_CHECKING:
import agate

from dbt.adapters.record.serialization import AdapterExecuteSerializer

ExecuteReturn = Tuple[AdapterResponse, agate.Table]

Recorder.register_serialization_strategy(ExecuteReturn, AdapterExecuteSerializer())


GET_CATALOG_MACRO_NAME = "get_catalog"
Expand Down Expand Up @@ -215,6 +222,7 @@ class SnapshotStrategy(TypedDict):
hard_deletes: Optional[str]


@supports_replay
class BaseAdapter(metaclass=AdapterMeta):
"""The BaseAdapter provides an abstract base class for adapters.

@@ -378,13 +386,15 @@ def connection_named(
self.connections.query_header.reset()

@available.parse(_parse_callback_empty_table)
@auto_record_function("AdapterExecute", group="Available")
def execute(
self,
sql: str,
auto_begin: bool = False,
fetch: bool = False,
limit: Optional[int] = None,
) -> Tuple[AdapterResponse, "agate.Table"]:
) -> ExecuteReturn:

"""Execute the given SQL. This is a thin wrapper around
ConnectionManager.execute.

@@ -409,8 +419,10 @@ def validate_sql(self, sql: str) -> AdapterResponse:
"""
raise NotImplementedError("`validate_sql` is not implemented for this adapter!")

@auto_record_function("AdapterGetColumnSchemaFromQuery", group="Available")
@available.parse(lambda *a, **k: [])
def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:

"""Get a list of the Columns with names and data types from the given sql."""
_, cursor = self.connections.add_select_query(sql)
columns = [
@@ -422,8 +434,10 @@ def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
]
return columns

@auto_record_function("AdapterGetPartitionsMetadata", group="Available")
@available.parse(_parse_callback_empty_table)
def get_partitions_metadata(self, table: str) -> Tuple["agate.Table"]:

"""
TODO: Can we move this to dbt-bigquery?
Obtain partitions metadata for a BigQuery partitioned table.
@@ -571,8 +585,10 @@ def set_relations_cache(
self.cache.clear()
self._relations_cache_for_schemas(relation_configs, required_schemas)

@auto_record_function("AdapterCacheAdded", group="Available")
@available
def cache_added(self, relation: Optional[BaseRelation]) -> str:

"""Cache a new relation in dbt. It will show up in `list relations`."""
if relation is None:
name = self.nice_connection_name()
@@ -581,8 +597,10 @@ def cache_added(self, relation: Optional[BaseRelation]) -> str:
# so jinja doesn't render things
return ""

@auto_record_function("AdapterCacheDropped", group="Available")
@available
def cache_dropped(self, relation: Optional[BaseRelation]) -> str:

"""Drop a relation in dbt. It will no longer show up in
`list relations`, and any bound views will be dropped from the cache
"""
@@ -592,6 +610,7 @@ def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
self.cache.drop(relation)
return ""

@auto_record_function("AdapterCacheRenamed", group="Available")
@available
def cache_renamed(
self,
@@ -632,8 +651,10 @@ def list_schemas(self, database: str) -> List[str]:
"""Get a list of existing schemas in database"""
raise NotImplementedError("`list_schemas` is not implemented for this adapter!")

@auto_record_function("AdapterCheckSchemaExists", group="Available")
@available.parse(lambda *a, **k: False)
def check_schema_exists(self, database: str, schema: str) -> bool:

"""Check if a schema exists.

The default implementation of this is potentially unnecessarily slow,
@@ -646,6 +667,7 @@ def check_schema_exists(self, database: str, schema: str) -> bool:
###
# Abstract methods about relations
###
@auto_record_function("AdapterDropRelation", group="Available")
@abc.abstractmethod
@available.parse_none
def drop_relation(self, relation: BaseRelation) -> None:
@@ -655,12 +677,14 @@ def drop_relation(self, relation: BaseRelation) -> None:
"""
raise NotImplementedError("`drop_relation` is not implemented for this adapter!")

@auto_record_function("AdapterTruncateRelation", group="Available")
@abc.abstractmethod
@available.parse_none
def truncate_relation(self, relation: BaseRelation) -> None:
"""Truncate the given relation."""
raise NotImplementedError("`truncate_relation` is not implemented for this adapter!")

@auto_record_function("AdapterRenameRelation", group="Available")
@abc.abstractmethod
@available.parse_none
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
@@ -670,6 +694,7 @@ def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation
"""
raise NotImplementedError("`rename_relation` is not implemented for this adapter!")

@auto_record_function("AdapterGetColumnsInRelation", group="Available")
@abc.abstractmethod
@available.parse_list
def get_columns_in_relation(self, relation: BaseRelation) -> List[BaseColumn]:
@@ -682,8 +707,10 @@ def get_catalog_for_single_relation(self, relation: BaseRelation) -> Optional[Ca
"`get_catalog_for_single_relation` is not implemented for this adapter!"
)

@auto_record_function("AdapterGetColumnsInTable", group="Available")
@available.deprecated("get_columns_in_relation", lambda *a, **k: [])
def get_columns_in_table(self, schema: str, identifier: str) -> List[BaseColumn]:

"""DEPRECATED: Get a list of the columns in the given table."""
relation = self.Relation.create(
database=self.config.credentials.database,
@@ -724,6 +751,7 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> List[
###
# Methods about grants
###
@auto_record_function("AdapterStandardizeGrantsDict", group="Available")
@available
def standardize_grants_dict(self, grants_table: "agate.Table") -> dict:
"""Translate the result of `show grants` (or equivalent) to match the
@@ -738,6 +766,7 @@ def standardize_grants_dict(self, grants_table: "agate.Table") -> dict:
:return: A standardized dictionary matching the `grants` config
:rtype: dict
"""

grants_dict: Dict[str, List[str]] = {}
for row in grants_table:
grantee = row["grantee"]
@@ -751,10 +780,12 @@ def standardize_grants_dict(self, grants_table: "agate.Table") -> dict:
###
# Provided methods about relations
###
@auto_record_function("AdapterGetMissingColumns", group="Available")
@available.parse_list
def get_missing_columns(
self, from_relation: BaseRelation, to_relation: BaseRelation
) -> List[BaseColumn]:

"""Returns a list of Columns in from_relation that are missing from
to_relation.
"""
@@ -782,10 +813,12 @@ def get_missing_columns(

return [col for (col_name, col) in from_columns.items() if col_name in missing_columns]

@auto_record_function("AdapterValidSnapshotTarget", group="Available")
@available.parse_none
def valid_snapshot_target(
self, relation: BaseRelation, column_names: Optional[Dict[str, str]] = None
) -> None:

"""Ensure that the target relation is valid, by making sure it has the
expected columns.

@@ -814,10 +847,12 @@ def valid_snapshot_target(
if missing:
raise SnapshotTargetNotSnapshotTableError(missing)

@auto_record_function("AdapterAssertValidSnapshotTargetGivenStrategy", group="Available")
@available.parse_none
def assert_valid_snapshot_target_given_strategy(
self, relation: BaseRelation, column_names: Dict[str, str], strategy: SnapshotStrategy
) -> None:

# Assert everything we can with the legacy function.
self.valid_snapshot_target(relation, column_names)

@@ -836,10 +871,12 @@ def assert_valid_snapshot_target_given_strategy(
if missing:
raise SnapshotTargetNotSnapshotTableError(missing)

@auto_record_function("AdapterExpandTargetColumnTypes", group="Available")
@available.parse_none
def expand_target_column_types(
self, from_relation: BaseRelation, to_relation: BaseRelation
) -> None:

if not isinstance(from_relation, self.Relation):
raise MacroArgTypeError(
method_name="expand_target_column_types",
@@ -930,8 +967,10 @@ def _make_match(

return matches

@auto_record_function("AdapterGetRelation", group="Available")
@available.parse_none
def get_relation(self, database: str, schema: str, identifier: str) -> Optional[BaseRelation]:

relations_list = self.list_relations(database, schema)

matches = self._make_match(relations_list, database, schema, identifier)
@@ -949,9 +988,11 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[

return None

@auto_record_function("AdapterAlreadyExists", group="Available")
@available.deprecated("get_relation", lambda *a, **k: False)
def already_exists(self, schema: str, name: str) -> bool:
"""DEPRECATED: Return if a model already exists in the database"""

database = self.config.credentials.database
relation = self.get_relation(database, schema, name)
return relation is not None
@@ -960,12 +1001,14 @@ def already_exists(self, schema: str, name: str) -> bool:
# ODBC FUNCTIONS -- these should not need to change for every adapter,
# although some adapters may override them
###
@auto_record_function("AdapterCreateSchema", group="Available")
@abc.abstractmethod
@available.parse_none
def create_schema(self, relation: BaseRelation):
"""Create the given schema if it does not exist."""
raise NotImplementedError("`create_schema` is not implemented for this adapter!")

@auto_record_function("AdapterDropSchema", group="Available")
@abc.abstractmethod
@available.parse_none
def drop_schema(self, relation: BaseRelation):
@@ -974,11 +1017,13 @@ def drop_schema(self, relation: BaseRelation):

@available
@classmethod
@auto_record_function("AdapterQuote", group="Available")
@abc.abstractmethod
def quote(cls, identifier: str) -> str:
"""Quote the given identifier, as appropriate for the database."""
raise NotImplementedError("`quote` is not implemented for this adapter!")

@auto_record_function("AdapterQuoteAsConfigured", group="Available")
@available
def quote_as_configured(self, identifier: str, quote_key: str) -> str:
"""Quote or do not quote the given identifer as configured in the
@@ -987,6 +1032,7 @@ def quote_as_configured(self, identifier: str, quote_key: str) -> str:
The quote key should be one of 'database' (on bigquery, 'profile'),
'identifier', or 'schema', or it will be treated as if you set `True`.
"""

try:
key = ComponentName(quote_key)
except ValueError:
Expand All @@ -998,8 +1044,10 @@ def quote_as_configured(self, identifier: str, quote_key: str) -> str:
else:
return identifier

@auto_record_function("AdapterQuoteSeedColumn", group="Available")
@available
def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str:

quote_columns: bool = True
if isinstance(quote_config, bool):
quote_columns = quote_config
@@ -1102,7 +1150,9 @@ def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str:

@available
@classmethod
@auto_record_function("AdapterConvertType", group="Available")
def convert_type(cls, agate_table: "agate.Table", col_idx: int) -> Optional[str]:

return cls.convert_agate_type(agate_table, col_idx)

@classmethod
@@ -1612,6 +1662,8 @@ def valid_incremental_strategies(self):
"""
return ["append"]

"".format()

def builtin_incremental_strategies(self):
"""
List of possible builtin strategies for adapters
@@ -1704,7 +1756,9 @@ def render_column_constraint(cls, constraint: ColumnLevelConstraint) -> Optional

@available
@classmethod
def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]]) -> List:
@auto_record_function("AdapterRenderRawColumnConstraints", group="Available")
def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]]) -> List[str]:

rendered_column_constraints = []

for v in raw_columns.values():
@@ -1758,7 +1812,9 @@ def _parse_model_constraint(cls, raw_constraint: Dict[str, Any]) -> ModelLevelCo

@available
@classmethod
@auto_record_function("AdapterRenderRawModelConstraints", group="Available")
def render_raw_model_constraints(cls, raw_constraints: List[Dict[str, Any]]) -> List[str]:

return [c for c in map(cls.render_raw_model_constraint, raw_constraints) if c is not None]

@classmethod
@@ -1830,7 +1886,9 @@ def _get_adapter_specific_run_info(cls, config) -> Dict[str, Any]:

@available.parse_none
@classmethod
def get_hard_deletes_behavior(cls, config):
@auto_record_function("AdapterGetHardDeletesBehavior", group="Available")
def get_hard_deletes_behavior(cls, config: Dict[str, str]) -> str:

"""Check the hard_deletes config enum, and the legacy invalidate_hard_deletes
config flag in order to determine which behavior should be used for deleted
records in a snapshot. The default is to ignore them."""
24 changes: 24 additions & 0 deletions dbt-adapters/src/dbt/adapters/record/serialization.py
@@ -0,0 +1,24 @@
from typing import Tuple

from dbt.adapters.contracts.connection import AdapterResponse

import agate

from mashumaro.types import SerializationStrategy

class AdapterExecuteSerializer(SerializationStrategy):
def serialize(self, table: Tuple[AdapterResponse, agate.Table]):
adapter_response, agate_table = table
return {
"adapter_response": adapter_response.to_dict(),
"table": {
"column_names": agate_table.column_names,
"column_types": [t.__class__.__name__ for t in agate_table.column_types],
"rows": list(map(list, agate_table.rows))
}
}

def deserialize(self, data):
        # Unpack the dict written by serialize(); rebuilding the
        # (AdapterResponse, agate.Table) tuple is still TODO in this draft.
        adapter_response_dct = data["adapter_response"]
        agate_table_dct = data["table"]
return None
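In this draft, deserialize() is still a stub: it unpacks the payload and returns None. For reference only, the sketch below shows one way it might be completed; this is an assumption, not part of the PR. It presumes AdapterResponse can be rebuilt with from_dict (inherited from its dbtClassMixin base) and that the stored agate type names map back onto classes in agate.data_types.

import agate
from dbt.adapters.contracts.connection import AdapterResponse
from mashumaro.types import SerializationStrategy


class AdapterExecuteSerializerSketch(SerializationStrategy):
    # Hypothetical variant: serialize() would match the draft above unchanged.

    def deserialize(self, data):
        # Rebuild the AdapterResponse (assumes from_dict is provided by dbtClassMixin).
        adapter_response = AdapterResponse.from_dict(data["adapter_response"])
        table_dct = data["table"]
        # Map stored class names ("Text", "Number", ...) back to agate type instances.
        column_types = [getattr(agate.data_types, name)() for name in table_dct["column_types"]]
        table = agate.Table(
            table_dct["rows"],
            column_names=table_dct["column_names"],
            column_types=column_types,
        )
        return adapter_response, table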
