diff --git a/dlt/common/destination/client.py b/dlt/common/destination/client.py new file mode 100644 index 0000000000..28de638c84 --- /dev/null +++ b/dlt/common/destination/client.py @@ -0,0 +1,632 @@ +from abc import ABC, abstractmethod +import dataclasses + +from types import TracebackType +from typing import ( + Optional, + NamedTuple, + Literal, + Sequence, + Iterable, + Type, + List, + ContextManager, + Dict, + Any, + TypeVar, +) +from typing_extensions import Annotated +import datetime # noqa: 251 + +from dlt.common import logger, pendulum +from dlt.common.configuration.specs.base_configuration import extract_inner_hint +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.destination.utils import verify_schema_capabilities, verify_supported_data_types +from dlt.common.exceptions import TerminalException +from dlt.common.metrics import LoadJobMetrics +from dlt.common.normalizers.naming import NamingConvention + +from dlt.common.schema import Schema, TSchemaTables +from dlt.common.schema.typing import ( + C_DLT_LOAD_ID, + TLoaderReplaceStrategy, +) +from dlt.common.schema.utils import fill_hints_from_parent_and_clone_table + +from dlt.common.configuration import configspec, NotResolved +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.destination.exceptions import ( + DestinationSchemaTampered, + DestinationTransientException, +) +from dlt.common.schema.exceptions import UnknownTableException +from dlt.common.storages import FileStorage +from dlt.common.storages.load_storage import ParsedLoadJobFileName +from dlt.common.storages.load_package import LoadJobInfo, TPipelineStateDoc +from dlt.common.typing import is_optional_type + +TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration") + +DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}" + + +class StorageSchemaInfo(NamedTuple): + version_hash: str + schema_name: str + version: int + engine_version: str + inserted_at: datetime.datetime + schema: str + + @classmethod + def from_normalized_mapping( + cls, normalized_doc: Dict[str, Any], naming_convention: NamingConvention + ) -> "StorageSchemaInfo": + """Instantiate this class from mapping where keys are normalized according to given naming convention + + Args: + normalized_doc: Mapping with normalized keys (e.g. {Version: ..., SchemaName: ...}) + naming_convention: Naming convention that was used to normalize keys + + Returns: + StorageSchemaInfo: Instance of this class + """ + return cls( + version_hash=normalized_doc[naming_convention.normalize_identifier("version_hash")], + schema_name=normalized_doc[naming_convention.normalize_identifier("schema_name")], + version=normalized_doc[naming_convention.normalize_identifier("version")], + engine_version=normalized_doc[naming_convention.normalize_identifier("engine_version")], + inserted_at=normalized_doc[naming_convention.normalize_identifier("inserted_at")], + schema=normalized_doc[naming_convention.normalize_identifier("schema")], + ) + + def to_normalized_mapping(self, naming_convention: NamingConvention) -> Dict[str, Any]: + """Convert this instance to mapping where keys are normalized according to given naming convention + + Args: + naming_convention: Naming convention that should be used to normalize keys + + Returns: + Dict[str, Any]: Mapping with normalized keys (e.g. {Version: ..., SchemaName: ...}) + """ + return { + naming_convention.normalize_identifier(key): value + for key, value in self._asdict().items() + } + + +@dataclasses.dataclass +class StateInfo: + version: int + engine_version: int + pipeline_name: str + state: str + created_at: datetime.datetime + version_hash: Optional[str] = None + _dlt_load_id: Optional[str] = None + + def as_doc(self) -> TPipelineStateDoc: + doc: TPipelineStateDoc = dataclasses.asdict(self) # type: ignore[assignment] + if self._dlt_load_id is None: + doc.pop(C_DLT_LOAD_ID) # type: ignore[misc] + if self.version_hash is None: + doc.pop("version_hash") + return doc + + @classmethod + def from_normalized_mapping( + cls, normalized_doc: Dict[str, Any], naming_convention: NamingConvention + ) -> "StateInfo": + """Instantiate this class from mapping where keys are normalized according to given naming convention + + Args: + normalized_doc: Mapping with normalized keys (e.g. {Version: ..., PipelineName: ...}) + naming_convention: Naming convention that was used to normalize keys + + Returns: + StateInfo: Instance of this class + """ + return cls( + version=normalized_doc[naming_convention.normalize_identifier("version")], + engine_version=normalized_doc[naming_convention.normalize_identifier("engine_version")], + pipeline_name=normalized_doc[naming_convention.normalize_identifier("pipeline_name")], + state=normalized_doc[naming_convention.normalize_identifier("state")], + created_at=normalized_doc[naming_convention.normalize_identifier("created_at")], + version_hash=normalized_doc.get(naming_convention.normalize_identifier("version_hash")), + _dlt_load_id=normalized_doc.get(naming_convention.normalize_identifier(C_DLT_LOAD_ID)), + ) + + +@configspec +class DestinationClientConfiguration(BaseConfiguration): + destination_type: Annotated[str, NotResolved()] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) # which destination to load data to + credentials: Optional[CredentialsConfiguration] = None + destination_name: Optional[str] = ( + None # name of the destination, if not set, destination_type is used + ) + environment: Optional[str] = None + + def fingerprint(self) -> str: + """Returns a destination fingerprint which is a hash of selected configuration fields. ie. host in case of connection string""" + return "" + + def __str__(self) -> str: + """Return displayable destination location""" + return str(self.credentials) + + def on_resolved(self) -> None: + self.destination_name = self.destination_name or self.destination_type + + @classmethod + def credentials_type( + cls, config: "DestinationClientConfiguration" = None + ) -> Type[CredentialsConfiguration]: + """Figure out credentials type, using hint resolvers for dynamic types + + For correct type resolution of filesystem, config should have bucket_url populated + """ + key = "credentials" + type_ = cls.get_resolvable_fields()[key] + if key in cls.__hint_resolvers__ and config is not None: + try: + # Type hint for this field is created dynamically + type_ = cls.__hint_resolvers__[key](config) + except Exception: + # we suppress failed hint resolutions + pass + return extract_inner_hint(type_) + + +@configspec +class DestinationClientDwhConfiguration(DestinationClientConfiguration): + """Configuration of a destination that supports datasets/schemas""" + + dataset_name: Annotated[str, NotResolved()] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) # dataset cannot be resolved + """dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix""" + default_schema_name: Annotated[Optional[str], NotResolved()] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) + """name of default schema to be used to name effective dataset to load data to""" + replace_strategy: TLoaderReplaceStrategy = "truncate-and-insert" + """How to handle replace disposition for this destination, can be classic or staging""" + staging_dataset_name_layout: str = "%s_staging" + """Layout for staging dataset, where %s is replaced with dataset name. placeholder is optional""" + enable_dataset_name_normalization: bool = True + """Whether to normalize the dataset name. Affects staging dataset as well.""" + + def _bind_dataset_name( + self: TDestinationDwhClient, dataset_name: str, default_schema_name: str = None + ) -> TDestinationDwhClient: + """Binds the dataset and default schema name to the configuration + + This method is intended to be used internally. + """ + self.dataset_name = dataset_name + self.default_schema_name = default_schema_name + return self + + def normalize_dataset_name(self, schema: Schema) -> str: + """Builds full db dataset (schema) name out of configured dataset name and schema name: {dataset_name}_{schema.name}. The resulting name is normalized. + + If default schema name is None or equals schema.name, the schema suffix is skipped. + """ + dataset_name = self._make_dataset_name(schema.name) + if not dataset_name: + return dataset_name + else: + return ( + schema.naming.normalize_table_identifier(dataset_name) + if self.enable_dataset_name_normalization + else dataset_name + ) + + def normalize_staging_dataset_name(self, schema: Schema) -> str: + """Builds staging dataset name out of dataset_name and staging_dataset_name_layout.""" + if "%s" in self.staging_dataset_name_layout: + # staging dataset name is never empty, otherwise table names must clash + dataset_name = self._make_dataset_name(schema.name) + # fill the placeholder + dataset_name = self.staging_dataset_name_layout % (dataset_name or "") + else: + # no placeholder, then layout is a full name. so you can have a single staging dataset + dataset_name = self.staging_dataset_name_layout + + return ( + schema.naming.normalize_table_identifier(dataset_name) + if self.enable_dataset_name_normalization + else dataset_name + ) + + @classmethod + def needs_dataset_name(cls) -> bool: + """Checks if configuration requires dataset name to be present. Empty datasets are allowed + ie. for schema-less destinations like weaviate or clickhouse + """ + fields = cls.get_resolvable_fields() + dataset_name_type = fields["dataset_name"] + return not is_optional_type(dataset_name_type) + + def _make_dataset_name(self, schema_name: str) -> str: + if not schema_name: + raise ValueError("schema_name is None or empty") + + # if default schema is None then suffix is not added + if self.default_schema_name is not None and schema_name != self.default_schema_name: + return (self.dataset_name or "") + "_" + schema_name + return self.dataset_name + + +@configspec +class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): + """Configuration of a staging destination, able to store files with desired `layout` at `bucket_url`. + + Also supports datasets and can act as standalone destination. + """ + + as_staging_destination: bool = False + bucket_url: str = None + # layout of the destination files + layout: str = DEFAULT_FILE_LAYOUT + + +@configspec +class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfiguration): + """Configuration of a destination that can take data from staging destination""" + + staging_config: Optional[DestinationClientStagingConfiguration] = None + """configuration of the staging, if present, injected at runtime""" + truncate_tables_on_staging_destination_before_load: bool = True + """If dlt should truncate the tables on staging destination before loading data.""" + + +TLoadJobState = Literal["ready", "running", "failed", "retry", "completed"] + + +class LoadJob(ABC): + """ + A stateful load job, represents one job file + """ + + def __init__(self, file_path: str) -> None: + self._file_path = file_path + self._file_name = FileStorage.get_file_name_from_file_path(file_path) + # NOTE: we only accept a full filepath in the constructor + assert self._file_name != self._file_path + self._parsed_file_name = ParsedLoadJobFileName.parse(self._file_name) + self._started_at: pendulum.DateTime = None + self._finished_at: pendulum.DateTime = None + + def job_id(self) -> str: + """The job id that is derived from the file name and does not changes during job lifecycle""" + return self._parsed_file_name.job_id() + + def file_name(self) -> str: + """A name of the job file""" + return self._file_name + + def job_file_info(self) -> ParsedLoadJobFileName: + return self._parsed_file_name + + @abstractmethod + def state(self) -> TLoadJobState: + """Returns current state. Should poll external resource if necessary.""" + pass + + @abstractmethod + def exception(self) -> str: + """The exception associated with failed or retry states""" + pass + + def metrics(self) -> Optional[LoadJobMetrics]: + """Returns job execution metrics""" + return LoadJobMetrics( + self._parsed_file_name.job_id(), + self._file_path, + self._parsed_file_name.table_name, + self._started_at, + self._finished_at, + self.state(), + None, + ) + + +class RunnableLoadJob(LoadJob, ABC): + """Represents a runnable job that loads a single file + + Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". + Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. + In "running" state, the loader component periodically gets the state via `status()` method. When terminal state is reached, load job is discarded and not called again. + `exception` method is called to get error information in "failed" and "retry" states. + + The `__init__` method is responsible to put the Job in "running" state. It may raise `LoadClientTerminalException` and `LoadClientTransientException` to + immediately transition job into "failed" or "retry" state respectively. + """ + + def __init__(self, file_path: str) -> None: + """ + File name is also a job id (or job id is deterministically derived) so it must be globally unique + """ + # ensure file name + super().__init__(file_path) + self._state: TLoadJobState = "ready" + self._exception: BaseException = None + + # variables needed by most jobs, set by the loader in set_run_vars + self._schema: Schema = None + self._load_table: PreparedTableSchema = None + self._load_id: str = None + self._job_client: "JobClientBase" = None + + def set_run_vars(self, load_id: str, schema: Schema, load_table: PreparedTableSchema) -> None: + """ + called by the loader right before the job is run + """ + self._load_id = load_id + self._schema = schema + self._load_table = load_table + + @property + def load_table_name(self) -> str: + return self._load_table["name"] + + def run_managed( + self, + job_client: "JobClientBase", + ) -> None: + """ + wrapper around the user implemented run method + """ + from dlt.common.runtime import signals + + # only jobs that are not running or have not reached a final state + # may be started + assert self._state in ("ready", "retry") + self._job_client = job_client + + # filepath is now moved to running + try: + self._state = "running" + self._started_at = pendulum.now() + self._job_client.prepare_load_job_execution(self) + self.run() + self._state = "completed" + except (TerminalException, AssertionError) as e: + self._state = "failed" + self._exception = e + logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}") + except (DestinationTransientException, Exception) as e: + self._state = "retry" + self._exception = e + logger.exception( + f"Transient exception in job {self.job_id()} in file {self._file_path}" + ) + finally: + self._finished_at = pendulum.now() + # sanity check + assert self._state in ("completed", "retry", "failed") + if self._state != "retry": + # wake up waiting threads + signals.wake_all() + + @abstractmethod + def run(self) -> None: + """ + run the actual job, this will be executed on a thread and should be implemented by the user + exception will be handled outside of this function + """ + raise NotImplementedError() + + def state(self) -> TLoadJobState: + """Returns current state. Should poll external resource if necessary.""" + return self._state + + def exception(self) -> str: + """The exception associated with failed or retry states""" + return str(self._exception) + + +class FollowupJobRequest: + """Base class for follow up jobs that should be created""" + + @abstractmethod + def new_file_path(self) -> str: + """Path to a newly created temporary job file. If empty, no followup job should be created""" + pass + + +class HasFollowupJobs: + """Adds a trait that allows to create single or table chain followup jobs""" + + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRequest]: + """Return list of jobs requests for jobs that should be created. `final_state` is state to which this job transits""" + return [] + + +class JobClientBase(ABC): + def __init__( + self, + schema: Schema, + config: DestinationClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + self.schema = schema + self.config = config + self.capabilities = capabilities + + @abstractmethod + def initialize_storage(self, truncate_tables: Optional[Iterable[str]] = None) -> None: + """Prepares storage to be used ie. creates database schema or file system folder. Truncates requested tables.""" + pass + + @abstractmethod + def is_storage_initialized(self) -> bool: + """Returns if storage is ready to be read/written.""" + pass + + @abstractmethod + def drop_storage(self) -> None: + """Brings storage back into not initialized state. Typically data in storage is destroyed.""" + pass + + def verify_schema( + self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None + ) -> List[PreparedTableSchema]: + """Verifies schema before loading, returns a list of verified loaded tables.""" + if exceptions := verify_schema_capabilities( + self.schema, + self.capabilities, + self.config.destination_type, + warnings=False, + ): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] + + prepared_tables = [ + self.prepare_load_table(table_name) + for table_name in set( + list(only_tables or []) + self.schema.data_table_names(seen_data_only=True) + ) + ] + if exceptions := verify_supported_data_types( + prepared_tables, + new_jobs, + self.capabilities, + self.config.destination_type, + warnings=False, + ): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] + return prepared_tables + + def update_stored_schema( + self, + only_tables: Iterable[str] = None, + expected_update: TSchemaTables = None, + ) -> Optional[TSchemaTables]: + """Updates storage to the current schema. + + Implementations should not assume that `expected_update` is the exact difference between destination state and the self.schema. This is only the case if + destination has single writer and no other processes modify the schema. + + Args: + only_tables (Sequence[str], optional): Updates only listed tables. Defaults to None. + expected_update (TSchemaTables, optional): Update that is expected to be applied to the destination + Returns: + Optional[TSchemaTables]: Returns an update that was applied at the destination. + """ + # make sure that schema being saved was not modified from the moment it was loaded from storage + version_hash = self.schema.version_hash + if self.schema.is_modified: + raise DestinationSchemaTampered( + self.schema.name, version_hash, self.schema.stored_version_hash + ) + return expected_update + + def prepare_load_table(self, table_name: str) -> PreparedTableSchema: + """Prepares a table schema to be loaded by filling missing hints and doing other modifications requires by given destination.""" + try: + return fill_hints_from_parent_and_clone_table(self.schema.tables, self.schema.tables[table_name]) # type: ignore[return-value] + + except KeyError: + raise UnknownTableException(self.schema.name, table_name) + + @abstractmethod + def create_load_job( + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + """Creates a load job for a particular `table` with content in `file_path`. Table is already prepared to be loaded.""" + pass + + def prepare_load_job_execution( # noqa: B027, optional override + self, job: RunnableLoadJob + ) -> None: + """Prepare the connected job client for the execution of a load job (used for query tags in sql clients)""" + pass + + def should_truncate_table_before_load(self, table_name: str) -> bool: + return self.prepare_load_table(table_name)["write_disposition"] == "replace" + + def create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[PreparedTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[FollowupJobRequest]: + """Creates a list of followup jobs that should be executed after a table chain is completed. Tables are already prepared to be loaded.""" + return [] + + @abstractmethod + def complete_load(self, load_id: str) -> None: + """Marks the load package with `load_id` as completed in the destination. Before such commit is done, the data with `load_id` is invalid.""" + pass + + @abstractmethod + def __enter__(self) -> "JobClientBase": + pass + + @abstractmethod + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: + pass + + +class WithStateSync(ABC): + @abstractmethod + def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: + """ + Retrieves newest schema with given name from destination storage + If no name is provided, the newest schema found is retrieved. + """ + pass + + @abstractmethod + def get_stored_schema_by_hash(self, version_hash: str) -> StorageSchemaInfo: + """retrieves the stored schema by hash""" + pass + + @abstractmethod + def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: + """Loads compressed state from destination storage""" + pass + + +class WithStagingDataset(ABC): + """Adds capability to use staging dataset and request it from the loader""" + + @abstractmethod + def should_load_data_to_staging_dataset(self, table_name: str) -> bool: + return False + + @abstractmethod + def with_staging_dataset(self) -> ContextManager["JobClientBase"]: + """Executes job client methods on staging dataset""" + return self # type: ignore + + +class SupportsStagingDestination(ABC): + """Adds capability to support a staging destination for the load""" + + def should_load_data_to_staging_dataset_on_staging_destination(self, table_name: str) -> bool: + """If set to True, and staging destination is configured, the data will be loaded to staging dataset on staging destination + instead of a regular dataset on staging destination. Currently it is used by Athena Iceberg which uses staging dataset + on staging destination to copy data to iceberg tables stored on regular dataset on staging destination. + The default is to load data to regular dataset on staging destination from where warehouses like Snowflake (that have their + own storage) will copy data. + """ + return False + + @abstractmethod + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: + """If set to True, data in `table` will be truncated on staging destination (regular dataset). This is the default behavior which + can be changed with a config flag. + For Athena + Iceberg this setting is always False - Athena uses regular dataset to store Iceberg tables and we avoid touching it. + For Athena we truncate those tables only on "replace" write disposition. + """ + pass diff --git a/dlt/common/destination/dataset.py b/dlt/common/destination/dataset.py new file mode 100644 index 0000000000..cac95225c9 --- /dev/null +++ b/dlt/common/destination/dataset.py @@ -0,0 +1,146 @@ +from typing import ( + Optional, + Sequence, + Union, + List, + Any, + Generator, + TYPE_CHECKING, + Protocol, + Tuple, + AnyStr, + overload, +) + +from dlt.common.exceptions import MissingDependencyException +from dlt.common.schema.schema import Schema +from dlt.common.schema.typing import TTableSchemaColumns + + +if TYPE_CHECKING: + try: + from dlt.common.libs.pandas import DataFrame + from dlt.common.libs.pyarrow import Table as ArrowTable + from dlt.helpers.ibis import BaseBackend as IbisBackend + except MissingDependencyException: + DataFrame = Any + ArrowTable = Any + IbisBackend = Any +else: + DataFrame = Any + ArrowTable = Any + IbisBackend = Any + + +class SupportsReadableRelation(Protocol): + """A readable relation retrieved from a destination that supports it""" + + columns_schema: TTableSchemaColumns + """Known dlt table columns for this relation""" + + def df(self, chunk_size: int = None) -> Optional[DataFrame]: + """Fetches the results as data frame. For large queries the results may be chunked + + Fetches the results into a data frame. The default implementation uses helpers in `pandas.io.sql` to generate Pandas data frame. + This function will try to use native data frame generation for particular destination. For `BigQuery`: `QueryJob.to_dataframe` is used. + For `duckdb`: `DuckDBPyConnection.df' + + Args: + chunk_size (int, optional): Will chunk the results into several data frames. Defaults to None + **kwargs (Any): Additional parameters which will be passed to native data frame generation function. + + Returns: + Optional[DataFrame]: A data frame with query results. If chunk_size > 0, None will be returned if there is no more data in results + """ + ... + + # accessing data + def arrow(self, chunk_size: int = None) -> Optional[ArrowTable]: + """fetch arrow table of first 'chunk_size' items""" + ... + + def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: + """iterate over data frames tables of 'chunk_size' items""" + ... + + def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: + """iterate over arrow tables of 'chunk_size' items""" + ... + + def fetchall(self) -> List[Tuple[Any, ...]]: + """fetch all items as list of python tuples""" + ... + + def fetchmany(self, chunk_size: int) -> List[Tuple[Any, ...]]: + """fetch first 'chunk_size' items as list of python tuples""" + ... + + def iter_fetch(self, chunk_size: int) -> Generator[List[Tuple[Any, ...]], Any, Any]: + """iterate in lists of python tuples in 'chunk_size' chunks""" + ... + + def fetchone(self) -> Optional[Tuple[Any, ...]]: + """fetch first item as python tuple""" + ... + + # modifying access parameters + def limit(self, limit: int, **kwargs: Any) -> "SupportsReadableRelation": + """limit the result to 'limit' items""" + ... + + def head(self, limit: int = 5) -> "SupportsReadableRelation": + """limit the result to 5 items by default""" + ... + + def select(self, *columns: str) -> "SupportsReadableRelation": + """set which columns will be selected""" + ... + + @overload + def __getitem__(self, column: str) -> "SupportsReadableRelation": ... + + @overload + def __getitem__(self, columns: Sequence[str]) -> "SupportsReadableRelation": ... + + def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation": + """set which columns will be selected""" + ... + + def __getattr__(self, attr: str) -> Any: + """get an attribute of the relation""" + ... + + def __copy__(self) -> "SupportsReadableRelation": + """create a copy of the relation object""" + ... + + +class DBApiCursor(SupportsReadableRelation): + """Protocol for DBAPI cursor""" + + description: Tuple[Any, ...] + + native_cursor: "DBApiCursor" + """Cursor implementation native to current destination""" + + def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: ... + def close(self) -> None: ... + + +class SupportsReadableDataset(Protocol): + """A readable dataset retrieved from a destination, has support for creating readable relations for a query or table""" + + @property + def schema(self) -> Schema: ... + + def __call__(self, query: Any) -> SupportsReadableRelation: ... + + def __getitem__(self, table: str) -> SupportsReadableRelation: ... + + def __getattr__(self, table: str) -> SupportsReadableRelation: ... + + def ibis(self) -> IbisBackend: ... + + def row_counts( + self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: List[str] = None + ) -> SupportsReadableRelation: ... diff --git a/dlt/common/destination/exceptions.py b/dlt/common/destination/exceptions.py index 50796998ad..b392c1c8d2 100644 --- a/dlt/common/destination/exceptions.py +++ b/dlt/common/destination/exceptions.py @@ -18,9 +18,12 @@ def __init__(self, destination_module: str) -> None: class InvalidDestinationReference(DestinationException): - def __init__(self, destination_module: Any) -> None: - self.destination_module = destination_module - msg = f"Destination {destination_module} is not a valid destination module." + def __init__(self, refs: Any) -> None: + self.refs = refs + msg = ( + f"None of supplied destination refs: {refs} can be found in registry or imported as" + " Python type." + ) super().__init__(msg) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 827034ddca..feb00f6a73 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -1,791 +1,45 @@ from abc import ABC, abstractmethod -import dataclasses from importlib import import_module -from types import TracebackType from typing import ( Callable, ClassVar, + List, Optional, - NamedTuple, - Literal, - Sequence, - Iterable, Type, Union, - List, - ContextManager, Dict, Any, TypeVar, Generic, - Generator, - TYPE_CHECKING, - Protocol, - Tuple, - AnyStr, - overload, ) -from typing_extensions import Annotated -import datetime # noqa: 251 +from typing_extensions import TypeAlias import inspect -from dlt.common import logger, pendulum - -from dlt.common.configuration.specs.base_configuration import extract_inner_hint -from dlt.common.destination.typing import PreparedTableSchema -from dlt.common.destination.utils import verify_schema_capabilities, verify_supported_data_types -from dlt.common.exceptions import TerminalException -from dlt.common.metrics import LoadJobMetrics +from dlt.common import logger +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs import PluggableRunContext +from dlt.common.configuration.specs.pluggable_run_context import SupportsRunContext from dlt.common.normalizers.naming import NamingConvention -from dlt.common.schema.typing import TTableSchemaColumns - -from dlt.common.schema import Schema, TSchemaTables, TTableSchema -from dlt.common.schema.typing import ( - C_DLT_LOAD_ID, - TLoaderReplaceStrategy, -) -from dlt.common.schema.utils import fill_hints_from_parent_and_clone_table - -from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved -from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration +from dlt.common.configuration import resolve_configuration, known_sections from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.destination.exceptions import ( InvalidDestinationReference, UnknownDestinationModule, - DestinationSchemaTampered, - DestinationTransientException, ) -from dlt.common.schema.exceptions import UnknownTableException -from dlt.common.storages import FileStorage -from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.storages.load_package import LoadJobInfo, TPipelineStateDoc -from dlt.common.exceptions import MissingDependencyException -from dlt.common.typing import is_optional_type +from dlt.common.destination.client import DestinationClientConfiguration, JobClientBase +from dlt.common.runtime.run_context import RunContext +from dlt.common.schema.schema import Schema TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase") -TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration") -TDatasetType = Literal["auto", "default", "ibis"] - - -DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}" - -if TYPE_CHECKING: - try: - from dlt.common.libs.pandas import DataFrame - from dlt.common.libs.pyarrow import Table as ArrowTable - from dlt.helpers.ibis import BaseBackend as IbisBackend - except MissingDependencyException: - DataFrame = Any - ArrowTable = Any - IbisBackend = Any -else: - DataFrame = Any - ArrowTable = Any - IbisBackend = Any - - -class StorageSchemaInfo(NamedTuple): - version_hash: str - schema_name: str - version: int - engine_version: str - inserted_at: datetime.datetime - schema: str - - @classmethod - def from_normalized_mapping( - cls, normalized_doc: Dict[str, Any], naming_convention: NamingConvention - ) -> "StorageSchemaInfo": - """Instantiate this class from mapping where keys are normalized according to given naming convention - - Args: - normalized_doc: Mapping with normalized keys (e.g. {Version: ..., SchemaName: ...}) - naming_convention: Naming convention that was used to normalize keys - - Returns: - StorageSchemaInfo: Instance of this class - """ - return cls( - version_hash=normalized_doc[naming_convention.normalize_identifier("version_hash")], - schema_name=normalized_doc[naming_convention.normalize_identifier("schema_name")], - version=normalized_doc[naming_convention.normalize_identifier("version")], - engine_version=normalized_doc[naming_convention.normalize_identifier("engine_version")], - inserted_at=normalized_doc[naming_convention.normalize_identifier("inserted_at")], - schema=normalized_doc[naming_convention.normalize_identifier("schema")], - ) - - def to_normalized_mapping(self, naming_convention: NamingConvention) -> Dict[str, Any]: - """Convert this instance to mapping where keys are normalized according to given naming convention - - Args: - naming_convention: Naming convention that should be used to normalize keys - - Returns: - Dict[str, Any]: Mapping with normalized keys (e.g. {Version: ..., SchemaName: ...}) - """ - return { - naming_convention.normalize_identifier(key): value - for key, value in self._asdict().items() - } - - -@dataclasses.dataclass -class StateInfo: - version: int - engine_version: int - pipeline_name: str - state: str - created_at: datetime.datetime - version_hash: Optional[str] = None - _dlt_load_id: Optional[str] = None - - def as_doc(self) -> TPipelineStateDoc: - doc: TPipelineStateDoc = dataclasses.asdict(self) # type: ignore[assignment] - if self._dlt_load_id is None: - doc.pop(C_DLT_LOAD_ID) # type: ignore[misc] - if self.version_hash is None: - doc.pop("version_hash") - return doc - - @classmethod - def from_normalized_mapping( - cls, normalized_doc: Dict[str, Any], naming_convention: NamingConvention - ) -> "StateInfo": - """Instantiate this class from mapping where keys are normalized according to given naming convention - - Args: - normalized_doc: Mapping with normalized keys (e.g. {Version: ..., PipelineName: ...}) - naming_convention: Naming convention that was used to normalize keys - - Returns: - StateInfo: Instance of this class - """ - return cls( - version=normalized_doc[naming_convention.normalize_identifier("version")], - engine_version=normalized_doc[naming_convention.normalize_identifier("engine_version")], - pipeline_name=normalized_doc[naming_convention.normalize_identifier("pipeline_name")], - state=normalized_doc[naming_convention.normalize_identifier("state")], - created_at=normalized_doc[naming_convention.normalize_identifier("created_at")], - version_hash=normalized_doc.get(naming_convention.normalize_identifier("version_hash")), - _dlt_load_id=normalized_doc.get(naming_convention.normalize_identifier(C_DLT_LOAD_ID)), - ) - - -@configspec -class DestinationClientConfiguration(BaseConfiguration): - destination_type: Annotated[str, NotResolved()] = dataclasses.field( - default=None, init=False, repr=False, compare=False - ) # which destination to load data to - credentials: Optional[CredentialsConfiguration] = None - destination_name: Optional[str] = ( - None # name of the destination, if not set, destination_type is used - ) - environment: Optional[str] = None - - def fingerprint(self) -> str: - """Returns a destination fingerprint which is a hash of selected configuration fields. ie. host in case of connection string""" - return "" - - def __str__(self) -> str: - """Return displayable destination location""" - return str(self.credentials) - - def on_resolved(self) -> None: - self.destination_name = self.destination_name or self.destination_type - - @classmethod - def credentials_type( - cls, config: "DestinationClientConfiguration" = None - ) -> Type[CredentialsConfiguration]: - """Figure out credentials type, using hint resolvers for dynamic types - - For correct type resolution of filesystem, config should have bucket_url populated - """ - key = "credentials" - type_ = cls.get_resolvable_fields()[key] - if key in cls.__hint_resolvers__ and config is not None: - try: - # Type hint for this field is created dynamically - type_ = cls.__hint_resolvers__[key](config) - except Exception: - # we suppress failed hint resolutions - pass - return extract_inner_hint(type_) - - -@configspec -class DestinationClientDwhConfiguration(DestinationClientConfiguration): - """Configuration of a destination that supports datasets/schemas""" - - dataset_name: Annotated[str, NotResolved()] = dataclasses.field( - default=None, init=False, repr=False, compare=False - ) # dataset cannot be resolved - """dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix""" - default_schema_name: Annotated[Optional[str], NotResolved()] = dataclasses.field( - default=None, init=False, repr=False, compare=False - ) - """name of default schema to be used to name effective dataset to load data to""" - replace_strategy: TLoaderReplaceStrategy = "truncate-and-insert" - """How to handle replace disposition for this destination, can be classic or staging""" - staging_dataset_name_layout: str = "%s_staging" - """Layout for staging dataset, where %s is replaced with dataset name. placeholder is optional""" - enable_dataset_name_normalization: bool = True - """Whether to normalize the dataset name. Affects staging dataset as well.""" - - def _bind_dataset_name( - self: TDestinationDwhClient, dataset_name: str, default_schema_name: str = None - ) -> TDestinationDwhClient: - """Binds the dataset and default schema name to the configuration - - This method is intended to be used internally. - """ - self.dataset_name = dataset_name - self.default_schema_name = default_schema_name - return self - - def normalize_dataset_name(self, schema: Schema) -> str: - """Builds full db dataset (schema) name out of configured dataset name and schema name: {dataset_name}_{schema.name}. The resulting name is normalized. - - If default schema name is None or equals schema.name, the schema suffix is skipped. - """ - dataset_name = self._make_dataset_name(schema.name) - if not dataset_name: - return dataset_name - else: - return ( - schema.naming.normalize_table_identifier(dataset_name) - if self.enable_dataset_name_normalization - else dataset_name - ) - - def normalize_staging_dataset_name(self, schema: Schema) -> str: - """Builds staging dataset name out of dataset_name and staging_dataset_name_layout.""" - if "%s" in self.staging_dataset_name_layout: - # staging dataset name is never empty, otherwise table names must clash - dataset_name = self._make_dataset_name(schema.name) - # fill the placeholder - dataset_name = self.staging_dataset_name_layout % (dataset_name or "") - else: - # no placeholder, then layout is a full name. so you can have a single staging dataset - dataset_name = self.staging_dataset_name_layout - - return ( - schema.naming.normalize_table_identifier(dataset_name) - if self.enable_dataset_name_normalization - else dataset_name - ) - - @classmethod - def needs_dataset_name(cls) -> bool: - """Checks if configuration requires dataset name to be present. Empty datasets are allowed - ie. for schema-less destinations like weaviate or clickhouse - """ - fields = cls.get_resolvable_fields() - dataset_name_type = fields["dataset_name"] - return not is_optional_type(dataset_name_type) - - def _make_dataset_name(self, schema_name: str) -> str: - if not schema_name: - raise ValueError("schema_name is None or empty") - - # if default schema is None then suffix is not added - if self.default_schema_name is not None and schema_name != self.default_schema_name: - return (self.dataset_name or "") + "_" + schema_name - return self.dataset_name - - -@configspec -class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): - """Configuration of a staging destination, able to store files with desired `layout` at `bucket_url`. - - Also supports datasets and can act as standalone destination. - """ - - as_staging_destination: bool = False - bucket_url: str = None - # layout of the destination files - layout: str = DEFAULT_FILE_LAYOUT - - -@configspec -class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfiguration): - """Configuration of a destination that can take data from staging destination""" - - staging_config: Optional[DestinationClientStagingConfiguration] = None - """configuration of the staging, if present, injected at runtime""" - truncate_tables_on_staging_destination_before_load: bool = True - """If dlt should truncate the tables on staging destination before loading data.""" - - -TLoadJobState = Literal["ready", "running", "failed", "retry", "completed"] - - -class LoadJob(ABC): - """ - A stateful load job, represents one job file - """ - - def __init__(self, file_path: str) -> None: - self._file_path = file_path - self._file_name = FileStorage.get_file_name_from_file_path(file_path) - # NOTE: we only accept a full filepath in the constructor - assert self._file_name != self._file_path - self._parsed_file_name = ParsedLoadJobFileName.parse(self._file_name) - self._started_at: pendulum.DateTime = None - self._finished_at: pendulum.DateTime = None - - def job_id(self) -> str: - """The job id that is derived from the file name and does not changes during job lifecycle""" - return self._parsed_file_name.job_id() - - def file_name(self) -> str: - """A name of the job file""" - return self._file_name - - def job_file_info(self) -> ParsedLoadJobFileName: - return self._parsed_file_name - - @abstractmethod - def state(self) -> TLoadJobState: - """Returns current state. Should poll external resource if necessary.""" - pass - - @abstractmethod - def exception(self) -> str: - """The exception associated with failed or retry states""" - pass - - def metrics(self) -> Optional[LoadJobMetrics]: - """Returns job execution metrics""" - return LoadJobMetrics( - self._parsed_file_name.job_id(), - self._file_path, - self._parsed_file_name.table_name, - self._started_at, - self._finished_at, - self.state(), - None, - ) - - -class RunnableLoadJob(LoadJob, ABC): - """Represents a runnable job that loads a single file - - Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". - Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. - In "running" state, the loader component periodically gets the state via `status()` method. When terminal state is reached, load job is discarded and not called again. - `exception` method is called to get error information in "failed" and "retry" states. - - The `__init__` method is responsible to put the Job in "running" state. It may raise `LoadClientTerminalException` and `LoadClientTransientException` to - immediately transition job into "failed" or "retry" state respectively. - """ - - def __init__(self, file_path: str) -> None: - """ - File name is also a job id (or job id is deterministically derived) so it must be globally unique - """ - # ensure file name - super().__init__(file_path) - self._state: TLoadJobState = "ready" - self._exception: BaseException = None - - # variables needed by most jobs, set by the loader in set_run_vars - self._schema: Schema = None - self._load_table: PreparedTableSchema = None - self._load_id: str = None - self._job_client: "JobClientBase" = None - - def set_run_vars(self, load_id: str, schema: Schema, load_table: PreparedTableSchema) -> None: - """ - called by the loader right before the job is run - """ - self._load_id = load_id - self._schema = schema - self._load_table = load_table - - @property - def load_table_name(self) -> str: - return self._load_table["name"] - - def run_managed( - self, - job_client: "JobClientBase", - ) -> None: - """ - wrapper around the user implemented run method - """ - from dlt.common.runtime import signals - - # only jobs that are not running or have not reached a final state - # may be started - assert self._state in ("ready", "retry") - self._job_client = job_client - - # filepath is now moved to running - try: - self._state = "running" - self._started_at = pendulum.now() - self._job_client.prepare_load_job_execution(self) - self.run() - self._state = "completed" - except (TerminalException, AssertionError) as e: - self._state = "failed" - self._exception = e - logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}") - except (DestinationTransientException, Exception) as e: - self._state = "retry" - self._exception = e - logger.exception( - f"Transient exception in job {self.job_id()} in file {self._file_path}" - ) - finally: - self._finished_at = pendulum.now() - # sanity check - assert self._state in ("completed", "retry", "failed") - if self._state != "retry": - # wake up waiting threads - signals.wake_all() - - @abstractmethod - def run(self) -> None: - """ - run the actual job, this will be executed on a thread and should be implemented by the user - exception will be handled outside of this function - """ - raise NotImplementedError() - - def state(self) -> TLoadJobState: - """Returns current state. Should poll external resource if necessary.""" - return self._state - - def exception(self) -> str: - """The exception associated with failed or retry states""" - return str(self._exception) - - -class FollowupJobRequest: - """Base class for follow up jobs that should be created""" - - @abstractmethod - def new_file_path(self) -> str: - """Path to a newly created temporary job file. If empty, no followup job should be created""" - pass - - -class HasFollowupJobs: - """Adds a trait that allows to create single or table chain followup jobs""" - - def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRequest]: - """Return list of jobs requests for jobs that should be created. `final_state` is state to which this job transits""" - return [] - - -class SupportsReadableRelation(Protocol): - """A readable relation retrieved from a destination that supports it""" - - columns_schema: TTableSchemaColumns - """Known dlt table columns for this relation""" - - def df(self, chunk_size: int = None) -> Optional[DataFrame]: - """Fetches the results as data frame. For large queries the results may be chunked - - Fetches the results into a data frame. The default implementation uses helpers in `pandas.io.sql` to generate Pandas data frame. - This function will try to use native data frame generation for particular destination. For `BigQuery`: `QueryJob.to_dataframe` is used. - For `duckdb`: `DuckDBPyConnection.df' - - Args: - chunk_size (int, optional): Will chunk the results into several data frames. Defaults to None - **kwargs (Any): Additional parameters which will be passed to native data frame generation function. - - Returns: - Optional[DataFrame]: A data frame with query results. If chunk_size > 0, None will be returned if there is no more data in results - """ - ... - - # accessing data - def arrow(self, chunk_size: int = None) -> Optional[ArrowTable]: - """fetch arrow table of first 'chunk_size' items""" - ... - - def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: - """iterate over data frames tables of 'chunk_size' items""" - ... - - def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: - """iterate over arrow tables of 'chunk_size' items""" - ... - - def fetchall(self) -> List[Tuple[Any, ...]]: - """fetch all items as list of python tuples""" - ... - - def fetchmany(self, chunk_size: int) -> List[Tuple[Any, ...]]: - """fetch first 'chunk_size' items as list of python tuples""" - ... - - def iter_fetch(self, chunk_size: int) -> Generator[List[Tuple[Any, ...]], Any, Any]: - """iterate in lists of python tuples in 'chunk_size' chunks""" - ... - - def fetchone(self) -> Optional[Tuple[Any, ...]]: - """fetch first item as python tuple""" - ... - - # modifying access parameters - def limit(self, limit: int, **kwargs: Any) -> "SupportsReadableRelation": - """limit the result to 'limit' items""" - ... - - def head(self, limit: int = 5) -> "SupportsReadableRelation": - """limit the result to 5 items by default""" - ... - - def select(self, *columns: str) -> "SupportsReadableRelation": - """set which columns will be selected""" - ... - - @overload - def __getitem__(self, column: str) -> "SupportsReadableRelation": ... - - @overload - def __getitem__(self, columns: Sequence[str]) -> "SupportsReadableRelation": ... - - def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation": - """set which columns will be selected""" - ... - - def __getattr__(self, attr: str) -> Any: - """get an attribute of the relation""" - ... - - def __copy__(self) -> "SupportsReadableRelation": - """create a copy of the relation object""" - ... - - -class DBApiCursor(SupportsReadableRelation): - """Protocol for DBAPI cursor""" - - description: Tuple[Any, ...] - - native_cursor: "DBApiCursor" - """Cursor implementation native to current destination""" - - def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: ... - def close(self) -> None: ... - - -class SupportsReadableDataset(Protocol): - """A readable dataset retrieved from a destination, has support for creating readable relations for a query or table""" - - @property - def schema(self) -> Schema: ... - - def __call__(self, query: Any) -> SupportsReadableRelation: ... - - def __getitem__(self, table: str) -> SupportsReadableRelation: ... - - def __getattr__(self, table: str) -> SupportsReadableRelation: ... - - def ibis(self) -> IbisBackend: ... - - def row_counts( - self, *, data_tables: bool = True, dlt_tables: bool = False, table_names: List[str] = None - ) -> SupportsReadableRelation: ... - - -class JobClientBase(ABC): - def __init__( - self, - schema: Schema, - config: DestinationClientConfiguration, - capabilities: DestinationCapabilitiesContext, - ) -> None: - self.schema = schema - self.config = config - self.capabilities = capabilities - - @abstractmethod - def initialize_storage(self, truncate_tables: Optional[Iterable[str]] = None) -> None: - """Prepares storage to be used ie. creates database schema or file system folder. Truncates requested tables.""" - pass - - @abstractmethod - def is_storage_initialized(self) -> bool: - """Returns if storage is ready to be read/written.""" - pass - - @abstractmethod - def drop_storage(self) -> None: - """Brings storage back into not initialized state. Typically data in storage is destroyed.""" - pass - - def verify_schema( - self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None - ) -> List[PreparedTableSchema]: - """Verifies schema before loading, returns a list of verified loaded tables.""" - if exceptions := verify_schema_capabilities( - self.schema, - self.capabilities, - self.config.destination_type, - warnings=False, - ): - for exception in exceptions: - logger.error(str(exception)) - raise exceptions[0] - - prepared_tables = [ - self.prepare_load_table(table_name) - for table_name in set( - list(only_tables or []) + self.schema.data_table_names(seen_data_only=True) - ) - ] - if exceptions := verify_supported_data_types( - prepared_tables, - new_jobs, - self.capabilities, - self.config.destination_type, - warnings=False, - ): - for exception in exceptions: - logger.error(str(exception)) - raise exceptions[0] - return prepared_tables - - def update_stored_schema( - self, - only_tables: Iterable[str] = None, - expected_update: TSchemaTables = None, - ) -> Optional[TSchemaTables]: - """Updates storage to the current schema. - - Implementations should not assume that `expected_update` is the exact difference between destination state and the self.schema. This is only the case if - destination has single writer and no other processes modify the schema. - - Args: - only_tables (Sequence[str], optional): Updates only listed tables. Defaults to None. - expected_update (TSchemaTables, optional): Update that is expected to be applied to the destination - Returns: - Optional[TSchemaTables]: Returns an update that was applied at the destination. - """ - # make sure that schema being saved was not modified from the moment it was loaded from storage - version_hash = self.schema.version_hash - if self.schema.is_modified: - raise DestinationSchemaTampered( - self.schema.name, version_hash, self.schema.stored_version_hash - ) - return expected_update - - def prepare_load_table(self, table_name: str) -> PreparedTableSchema: - """Prepares a table schema to be loaded by filling missing hints and doing other modifications requires by given destination.""" - try: - return fill_hints_from_parent_and_clone_table(self.schema.tables, self.schema.tables[table_name]) # type: ignore[return-value] - - except KeyError: - raise UnknownTableException(self.schema.name, table_name) - - @abstractmethod - def create_load_job( - self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False - ) -> LoadJob: - """Creates a load job for a particular `table` with content in `file_path`. Table is already prepared to be loaded.""" - pass - - def prepare_load_job_execution( # noqa: B027, optional override - self, job: RunnableLoadJob - ) -> None: - """Prepare the connected job client for the execution of a load job (used for query tags in sql clients)""" - pass - - def should_truncate_table_before_load(self, table_name: str) -> bool: - return self.prepare_load_table(table_name)["write_disposition"] == "replace" - - def create_table_chain_completed_followup_jobs( - self, - table_chain: Sequence[PreparedTableSchema], - completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[FollowupJobRequest]: - """Creates a list of followup jobs that should be executed after a table chain is completed. Tables are already prepared to be loaded.""" - return [] - - @abstractmethod - def complete_load(self, load_id: str) -> None: - """Marks the load package with `load_id` as completed in the destination. Before such commit is done, the data with `load_id` is invalid.""" - pass - - @abstractmethod - def __enter__(self) -> "JobClientBase": - pass - - @abstractmethod - def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType - ) -> None: - pass - - -class WithStateSync(ABC): - @abstractmethod - def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]: - """ - Retrieves newest schema with given name from destination storage - If no name is provided, the newest schema found is retrieved. - """ - pass - - @abstractmethod - def get_stored_schema_by_hash(self, version_hash: str) -> StorageSchemaInfo: - """retrieves the stored schema by hash""" - pass - - @abstractmethod - def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: - """Loads compressed state from destination storage""" - pass - - -class WithStagingDataset(ABC): - """Adds capability to use staging dataset and request it from the loader""" - - @abstractmethod - def should_load_data_to_staging_dataset(self, table_name: str) -> bool: - return False - - @abstractmethod - def with_staging_dataset(self) -> ContextManager["JobClientBase"]: - """Executes job client methods on staging dataset""" - return self # type: ignore - - -class SupportsStagingDestination(ABC): - """Adds capability to support a staging destination for the load""" - - def should_load_data_to_staging_dataset_on_staging_destination(self, table_name: str) -> bool: - """If set to True, and staging destination is configured, the data will be loaded to staging dataset on staging destination - instead of a regular dataset on staging destination. Currently it is used by Athena Iceberg which uses staging dataset - on staging destination to copy data to iceberg tables stored on regular dataset on staging destination. - The default is to load data to regular dataset on staging destination from where warehouses like Snowflake (that have their - own storage) will copy data. - """ - return False - - @abstractmethod - def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: - """If set to True, data in `table` will be truncated on staging destination (regular dataset). This is the default behavior which - can be changed with a config flag. - For Athena + Iceberg this setting is always False - Athena uses regular dataset to store Iceberg tables and we avoid touching it. - For Athena we truncate those tables only on "replace" write disposition. - """ - pass +AnyDestination: TypeAlias = "Destination[DestinationClientConfiguration, JobClientBase]" +AnyDestination_CO: TypeAlias = "Destination[Any, Any]" # TODO: type Destination properly -TDestinationReferenceArg = Union[ - str, "Destination[Any, Any]", Callable[..., "Destination[Any, Any]"], None -] +TDestinationReferenceArg = Union[str, AnyDestination_CO, Callable[..., AnyDestination_CO], None] class Destination(ABC, Generic[TDestinationConfig, TDestinationClient]): @@ -793,6 +47,10 @@ class Destination(ABC, Generic[TDestinationConfig, TDestinationClient]): with credentials and other config params. """ + DESTINATIONS: ClassVar[Dict[str, Type[AnyDestination]]] = {} + """A registry of all the decorated destinations""" + CONTEXT: ClassVar[SupportsRunContext] = None + config_params: Dict[str, Any] """Explicit config params, overriding any injected or default values.""" caps_params: Dict[str, Any] @@ -802,7 +60,8 @@ def __init__(self, **kwargs: Any) -> None: # Create initial unresolved destination config # Argument defaults are filtered out here because we only want arguments passed explicitly # to supersede config from the environment or pipeline args - sig = inspect.signature(self.__class__.__init__) + # __orig_base__ tells where the __init__ of interest is, in case class is derived + sig = inspect.signature(getattr(self.__class__, "__orig_base__", self.__class__).__init__) params = sig.parameters # get available args @@ -943,7 +202,7 @@ def adjust_capabilities( @staticmethod def to_name(ref: TDestinationReferenceArg) -> str: if ref is None: - raise InvalidDestinationReference(ref) + raise InvalidDestinationReference([]) if isinstance(ref, str): return ref.rsplit(".", 1)[-1] if callable(ref): @@ -958,18 +217,60 @@ def normalize_type(destination_type: str) -> str: # the next two lines shorten the dlt internal destination paths to dlt.destinations. name = Destination.to_name(destination_type) destination_type = destination_type.replace( - f"dlt.destinations.impl.{name}.factory.", "dlt.destinations." + f".destinations.impl.{name}.factory.", ".destinations." ) return destination_type + @classmethod + def register(cls, destination_name: str) -> None: + cls.CONTEXT = Container()[PluggableRunContext].context + ref = f"{cls.CONTEXT.name}.{destination_name}" + if ref in cls.DESTINATIONS: + logger.info( + f"A destination with ref {ref} is already registered and will be overwritten" + ) + cls.DESTINATIONS[ref] = cls # type: ignore[assignment] + @staticmethod + def to_fully_qualified_refs(ref: str) -> List[str]: + """Converts ref into fully qualified form, return one or more alternatives for shorthand notations. + Run context is injected if needed. Following formats are recognized + - context_name.'destinations'.name (fully qualified) + - 'destinations'.name + - name + NOTE: the last component of destination type serves as destination name if not explicitly specified + """ + ref_split = ref.split(".") + ref_parts = len(ref_split) + if ref_parts < 2 or (ref_parts == 2 and ref_split[1] == known_sections.DESTINATIONS): + # context name is needed + refs = [] + run_names = [Container()[PluggableRunContext].context.name] + # always look in default run context + if run_names[0] != RunContext.CONTEXT_NAME: + run_names.append(RunContext.CONTEXT_NAME) + for run_name in run_names: + if ref_parts == 1: + # ref is: name + refs.append(f"{run_name}.{known_sections.DESTINATIONS}.{ref}") + else: + # ref is: `destinations`.name` + refs.append(f"{run_name}.{ref}") + return refs + if len(ref_split) == 3 and ref_split[1] == known_sections.DESTINATIONS: + return [ref] + + return [] + + @classmethod def from_reference( + cls, ref: TDestinationReferenceArg, credentials: Optional[Any] = None, destination_name: Optional[str] = None, environment: Optional[str] = None, **kwargs: Any, - ) -> Optional["Destination[DestinationClientConfiguration, JobClientBase]"]: + ) -> Optional[AnyDestination]: """Instantiate destination from str reference. The ref can be a destination name or import path pointing to a destination class (e.g. `dlt.destinations.postgres`) """ @@ -988,20 +289,36 @@ def from_reference( " Destination instance, these values will be ignored." ) return ref + if not isinstance(ref, str): raise InvalidDestinationReference(ref) - try: - module_path, attr_name = Destination.normalize_type(ref).rsplit(".", 1) - dest_module = import_module(module_path) - except ModuleNotFoundError as e: - raise UnknownDestinationModule(ref) from e - try: - factory: Type[Destination[DestinationClientConfiguration, JobClientBase]] = getattr( - dest_module, attr_name - ) - except AttributeError as e: - raise UnknownDestinationModule(ref) from e + # resolve ref + refs = cls.to_fully_qualified_refs(ref) + factory: Type[AnyDestination] = None + + for ref_ in refs: + if factory := cls.DESTINATIONS.get(ref_): + break + + # no reference found, try to import default module + if not factory: + # try ref, normalized refs and ref without the context name + refs.extend(set([r.split(".", 1)[1] for r in refs])) + if "." in ref and ref not in refs: + refs = [ref] + refs + for possible_type in refs: + try: + module_path, attr_name = possible_type.rsplit(".", 1) + dest_module = import_module(module_path) + except ModuleNotFoundError as e: + raise UnknownDestinationModule(ref) from e + try: + factory = getattr(dest_module, attr_name) + except AttributeError as e: + raise UnknownDestinationModule(ref) from e + break + if credentials: kwargs["credentials"] = credentials if destination_name: @@ -1014,6 +331,3 @@ def from_reference( except Exception as e: raise InvalidDestinationReference(ref) from e return dest - - -AnyDestination = Destination[DestinationClientConfiguration, JobClientBase] diff --git a/dlt/common/destination/typing.py b/dlt/common/destination/typing.py index c79a2b0adc..83d91aa197 100644 --- a/dlt/common/destination/typing.py +++ b/dlt/common/destination/typing.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from dlt.common.schema.typing import ( _TTableSchemaBase, @@ -6,6 +6,8 @@ TTableReferenceParam, ) +TDatasetType = Literal["auto", "default", "ibis"] + class PreparedTableSchema(_TTableSchemaBase, total=False): """Table schema with all hints prepared to be loaded""" diff --git a/dlt/common/reflection/function_visitor.py b/dlt/common/reflection/function_visitor.py deleted file mode 100644 index 6cb6016a7f..0000000000 --- a/dlt/common/reflection/function_visitor.py +++ /dev/null @@ -1,14 +0,0 @@ -import ast -from ast import NodeVisitor -from typing import Any - - -class FunctionVisitor(NodeVisitor): - def __init__(self, source: str): - self.source = source - self.top_func: ast.FunctionDef = None - - def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: - if not self.top_func: - self.top_func = node - super().generic_visit(node) diff --git a/dlt/common/reflection/spec.py b/dlt/common/reflection/spec.py index 00ed6e6727..17915ff1e4 100644 --- a/dlt/common/reflection/spec.py +++ b/dlt/common/reflection/spec.py @@ -20,7 +20,7 @@ _SLEEPING_CAT_SPLIT = re.compile("[^.^_]+") -def _get_spec_name_from_f(f: AnyFun) -> str: +def get_spec_name_from_f(f: AnyFun, kind: str = "Configuration") -> str: func_name = get_callable_name(f, "__qualname__").replace( ".", "" ) # func qual name contains position in the module, separated by dots @@ -28,7 +28,7 @@ def _get_spec_name_from_f(f: AnyFun) -> str: def _first_up(s: str) -> str: return s[0].upper() + s[1:] - return "".join(map(_first_up, _SLEEPING_CAT_SPLIT.findall(func_name))) + "Configuration" + return "".join(map(_first_up, _SLEEPING_CAT_SPLIT.findall(func_name))) + kind def spec_from_signature( @@ -48,7 +48,7 @@ def spec_from_signature( Return value is a tuple of SPEC and SPEC fields created from a `sig`. """ - name = _get_spec_name_from_f(f) + name = get_spec_name_from_f(f) module = inspect.getmodule(f) base_fields = base.get_resolvable_fields() diff --git a/dlt/destinations/dataset/dataset.py b/dlt/destinations/dataset/dataset.py index fc55393a60..fcf8b98a84 100644 --- a/dlt/destinations/dataset/dataset.py +++ b/dlt/destinations/dataset/dataset.py @@ -4,20 +4,15 @@ from dlt.common.exceptions import MissingDependencyException -from dlt.common.destination.reference import ( - SupportsReadableRelation, - SupportsReadableDataset, - TDestinationReferenceArg, - Destination, - JobClientBase, - WithStateSync, -) +from dlt.common.destination.reference import TDestinationReferenceArg, Destination +from dlt.common.destination.client import JobClientBase, WithStateSync +from dlt.common.destination.dataset import SupportsReadableRelation, SupportsReadableDataset +from dlt.common.destination.typing import TDatasetType from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.common.schema import Schema from dlt.destinations.dataset.relation import ReadableDBAPIRelation from dlt.destinations.dataset.utils import get_destination_clients -from dlt.common.destination.reference import TDatasetType if TYPE_CHECKING: try: diff --git a/dlt/destinations/dataset/factory.py b/dlt/destinations/dataset/factory.py index 8ea0ddf7a1..94b3480ce3 100644 --- a/dlt/destinations/dataset/factory.py +++ b/dlt/destinations/dataset/factory.py @@ -1,13 +1,8 @@ from typing import Union - -from dlt.common.destination import AnyDestination -from dlt.common.destination.reference import ( - SupportsReadableDataset, - TDatasetType, - TDestinationReferenceArg, -) - +from dlt.common.destination import TDestinationReferenceArg +from dlt.common.destination.dataset import SupportsReadableDataset +from dlt.common.destination.typing import TDatasetType from dlt.common.schema import Schema from dlt.destinations.dataset.dataset import ReadableDBAPIDataset diff --git a/dlt/destinations/dataset/relation.py b/dlt/destinations/dataset/relation.py index 2cdb7640df..83f655b07f 100644 --- a/dlt/destinations/dataset/relation.py +++ b/dlt/destinations/dataset/relation.py @@ -3,7 +3,7 @@ from contextlib import contextmanager -from dlt.common.destination.reference import ( +from dlt.common.destination.dataset import ( SupportsReadableRelation, ) diff --git a/dlt/destinations/dataset/utils.py b/dlt/destinations/dataset/utils.py index 766fbc13ea..dee754df9a 100644 --- a/dlt/destinations/dataset/utils.py +++ b/dlt/destinations/dataset/utils.py @@ -4,16 +4,14 @@ from dlt.common.exceptions import MissingDependencyException -from dlt.common.destination import AnyDestination -from dlt.common.destination.reference import ( - Destination, +from dlt.common.destination import AnyDestination, Destination +from dlt.common.destination.client import ( JobClientBase, DestinationClientDwhConfiguration, DestinationClientStagingConfiguration, DestinationClientConfiguration, DestinationClientDwhWithStagingConfiguration, ) - from dlt.common.schema import Schema diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index c4110035b9..d1ab8a74a2 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -1,7 +1,9 @@ import functools -from typing import Any, Type, Optional, Callable, Union +import inspect +from typing import Any, Type, Optional, Callable, Union, overload from typing_extensions import Concatenate +from dlt.common.reflection.spec import get_spec_name_from_f from dlt.common.typing import AnyFun from functools import wraps @@ -12,6 +14,7 @@ from dlt.common.schema import TTableSchema from dlt.common.destination.capabilities import TLoaderParallelismStrategy +from dlt.common.utils import get_callable_name, is_inner_callable from dlt.destinations.impl.destination.factory import destination as _destination from dlt.destinations.impl.destination.configuration import ( TDestinationCallableParams, @@ -19,8 +22,27 @@ ) +@overload def destination( - func: Optional[AnyFun] = None, + func: Callable[ + Concatenate[Union[TDataItems, str], TTableSchema, TDestinationCallableParams], Any + ], + /, + loader_file_format: TLoaderFileFormat = None, + batch_size: int = 10, + name: str = None, + naming_convention: str = "direct", + skip_dlt_columns_and_tables: bool = True, + max_table_nesting: int = 0, + spec: Type[CustomDestinationClientConfiguration] = None, + max_parallel_load_jobs: Optional[int] = None, + loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None, +) -> Callable[TDestinationCallableParams, _destination]: ... + + +@overload +def destination( + func: None = ..., /, loader_file_format: TLoaderFileFormat = None, batch_size: int = 10, @@ -34,7 +56,22 @@ def destination( ) -> Callable[ [Callable[Concatenate[Union[TDataItems, str], TTableSchema, TDestinationCallableParams], Any]], Callable[TDestinationCallableParams, _destination], -]: +]: ... + + +def destination( + func: Optional[AnyFun] = None, + /, + loader_file_format: TLoaderFileFormat = None, + batch_size: int = 10, + name: str = None, + naming_convention: str = "direct", + skip_dlt_columns_and_tables: bool = True, + max_table_nesting: int = 0, + spec: Type[CustomDestinationClientConfiguration] = None, + max_parallel_load_jobs: Optional[int] = None, + loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None, +) -> Any: """A decorator that transforms a function that takes two positional arguments "table" and "items" and any number of keyword arguments with defaults into a callable that will create a custom destination. The function does not return anything, the keyword arguments can be configuration and secrets values. @@ -53,7 +90,7 @@ def destination( Args: batch_size: defines how many items per function call are batched together and sent as an array. If you set a batch-size of 0, instead of passing in actual dataitems, you will receive one call per load job with the path of the file as the items argument. You can then open and process that file in any way you like. loader_file_format: defines in which format files are stored in the load package before being sent to the destination function, this can be puae-jsonl or parquet. - name: defines the name of the destination that get's created by the destination decorator, defaults to the name of the function + name: defines the name of the destination that gets created by the destination decorator, defaults to the name of the function naming_convention: defines the name of the destination that gets created by the destination decorator. This controls how table and column names are normalized. The default is direct which will keep all names the same. max_nesting_level: defines how deep the normalizer will go to normalize nested fields on your data to create subtables. This overwrites any settings on your source and is set to zero to not create any nested tables by default. skip_dlt_columns_and_tables: defines wether internal tables and columns will be fed into the custom destination function. This is set to True by default. @@ -69,6 +106,40 @@ def decorator( Concatenate[Union[TDataItems, str], TTableSchema, TDestinationCallableParams], Any ] ) -> Callable[TDestinationCallableParams, _destination]: + # resolve destination name + destination_name = name or get_callable_name(destination_callable) + + # synthesize new Destination factory + class _ConcreteDestinationBase(_destination): + def __init__(self, **kwargs: Any): + super().__init__( + spec=spec, + destination_callable=destination_callable, + loader_file_format=loader_file_format, + batch_size=batch_size, + destination_name=destination_name, + naming_convention=naming_convention, + skip_dlt_columns_and_tables=skip_dlt_columns_and_tables, + max_table_nesting=max_table_nesting, + max_parallel_load_jobs=max_parallel_load_jobs, + loader_parallelism_strategy=loader_parallelism_strategy, + **kwargs, + ) + + cls_name = get_spec_name_from_f(destination_callable, kind="Destination") + module = inspect.getmodule(destination_callable) + # synthesize type + D: Type[_destination] = type( + cls_name, + (_ConcreteDestinationBase,), + {"__module__": module.__name__, "__orig_base__": _destination}, + ) + # add to the module + setattr(module, cls_name, D) + # register only standalone destinations, no inner + if not is_inner_callable(destination_callable): + D.register(destination_name=destination_name) + @wraps(destination_callable) def wrapper( *args: TDestinationCallableParams.args, **kwargs: TDestinationCallableParams.kwargs @@ -78,19 +149,7 @@ def wrapper( "Ignoring positional arguments for destination callable %s", destination_callable, ) - return _destination( - spec=spec, - destination_callable=destination_callable, - loader_file_format=loader_file_format, - batch_size=batch_size, - destination_name=name, - naming_convention=naming_convention, - skip_dlt_columns_and_tables=skip_dlt_columns_and_tables, - max_table_nesting=max_table_nesting, - max_parallel_load_jobs=max_parallel_load_jobs, - loader_parallelism_strategy=loader_parallelism_strategy, - **kwargs, # type: ignore - ) + return D(**kwargs) # type: ignore[arg-type] return wrapper @@ -98,5 +157,5 @@ def wrapper( # we're called with parens. return decorator - # we're called as @source without parens. - return decorator(func) # type: ignore + # we're called as @dlt.destination without parens. + return decorator(func) diff --git a/dlt/destinations/exceptions.py b/dlt/destinations/exceptions.py index 3b3e602b57..a001b3d5ca 100644 --- a/dlt/destinations/exceptions.py +++ b/dlt/destinations/exceptions.py @@ -6,7 +6,7 @@ DestinationUndefinedEntity, DestinationException, ) -from dlt.common.destination.reference import TLoadJobState +from dlt.common.destination.client import TLoadJobState class DatabaseException(DestinationException): diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index c7e30aaf55..979a542862 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -40,7 +40,8 @@ TSortOrder, ) from dlt.common.destination import DestinationCapabilitiesContext, PreparedTableSchema -from dlt.common.destination.reference import FollowupJobRequest, SupportsStagingDestination, LoadJob +from dlt.common.destination.client import FollowupJobRequest, SupportsStagingDestination, LoadJob +from dlt.common.destination.dataset import DBApiCursor from dlt.common.data_writers.escape import escape_hive_identifier from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlMergeFollowupJob @@ -56,7 +57,6 @@ raise_database_error, raise_open_connection_error, ) -from dlt.common.destination.reference import DBApiCursor from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs, FinalizedLoadJob from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration diff --git a/dlt/destinations/impl/athena/configuration.py b/dlt/destinations/impl/athena/configuration.py index 8a0f14b4cc..32353944be 100644 --- a/dlt/destinations/impl/athena/configuration.py +++ b/dlt/destinations/impl/athena/configuration.py @@ -2,9 +2,8 @@ from typing import ClassVar, Final, List, Optional import warnings -from dlt.common import logger from dlt.common.configuration import configspec -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.specs import AwsCredentials from dlt.common.warnings import Dlt100DeprecationWarning diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 10a344f768..cd7c5ef936 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -10,7 +10,7 @@ from dlt.common import logger from dlt.common.destination import DestinationCapabilitiesContext, PreparedTableSchema -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( HasFollowupJobs, FollowupJobRequest, RunnableLoadJob, diff --git a/dlt/destinations/impl/bigquery/configuration.py b/dlt/destinations/impl/bigquery/configuration.py index 3d71b0c8ea..fd3a83a859 100644 --- a/dlt/destinations/impl/bigquery/configuration.py +++ b/dlt/destinations/impl/bigquery/configuration.py @@ -6,7 +6,7 @@ from dlt.common.configuration.specs import GcpServiceAccountCredentials from dlt.common.utils import digest128 -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration @configspec diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index 6911fa5c1c..96d4fc39ad 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -24,7 +24,7 @@ raise_open_connection_error, ) from dlt.destinations.typing import DBApi, DBTransaction, DataFrame, ArrowTable -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor # terminal reasons as returned in BQ gRPC error response diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index a407e56361..38299bb258 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -13,7 +13,7 @@ AwsCredentialsWithoutDefaults, ) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( PreparedTableSchema, SupportsStagingDestination, TLoadJobState, diff --git a/dlt/destinations/impl/clickhouse/configuration.py b/dlt/destinations/impl/clickhouse/configuration.py index adcc2a3e4c..ae6005c60a 100644 --- a/dlt/destinations/impl/clickhouse/configuration.py +++ b/dlt/destinations/impl/clickhouse/configuration.py @@ -5,7 +5,7 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.specs.base_configuration import NotResolved -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( DestinationClientDwhWithStagingConfiguration, ) from dlt.common.utils import digest128 diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 21338bd310..c0cf9d1962 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -3,7 +3,7 @@ from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.exceptions import ConfigurationValueError diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index a83db6ec34..5b006aa62c 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -5,7 +5,7 @@ AzureServicePrincipalCredentialsWithoutDefaults, ) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( HasFollowupJobs, FollowupJobRequest, PreparedTableSchema, diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 16e1e73d93..98320a9de6 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -37,7 +37,7 @@ ) from dlt.destinations.typing import ArrowTable, DBApi, DBTransaction, DataFrame from dlt.destinations.impl.databricks.configuration import DatabricksCredentials -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor class DatabricksCursorImpl(DBApiCursorImpl): diff --git a/dlt/destinations/impl/dremio/configuration.py b/dlt/destinations/impl/dremio/configuration.py index 0a95c2807c..12ec842bba 100644 --- a/dlt/destinations/impl/dremio/configuration.py +++ b/dlt/destinations/impl/dremio/configuration.py @@ -3,7 +3,7 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration from dlt.common.typing import TSecretStrValue from dlt.common.utils import digest128 diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index e3a090c824..34565fa75e 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -2,7 +2,7 @@ from urllib.parse import urlparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( HasFollowupJobs, PreparedTableSchema, TLoadJobState, diff --git a/dlt/destinations/impl/dremio/sql_client.py b/dlt/destinations/impl/dremio/sql_client.py index 030009c74b..37c36cd865 100644 --- a/dlt/destinations/impl/dremio/sql_client.py +++ b/dlt/destinations/impl/dremio/sql_client.py @@ -19,7 +19,7 @@ raise_open_connection_error, ) from dlt.destinations.typing import DBApi, DBTransaction, DataFrame -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor class DremioCursorImpl(DBApiCursorImpl): diff --git a/dlt/destinations/impl/duckdb/configuration.py b/dlt/destinations/impl/duckdb/configuration.py index 692d0eb8e3..649d930214 100644 --- a/dlt/destinations/impl/duckdb/configuration.py +++ b/dlt/destinations/impl/duckdb/configuration.py @@ -10,7 +10,7 @@ from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.specs.base_configuration import NotResolved from dlt.common.configuration.specs.exceptions import InvalidConnectionString -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration from dlt.common.pipeline import SupportsPipeline from dlt.destinations.impl.duckdb.exceptions import InvalidInMemoryDuckdbCredentials diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 2b3370270b..d6ffeb5f34 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -1,9 +1,8 @@ from typing import Dict, Optional from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.exceptions import TerminalValueError -from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.destination.reference import ( +from dlt.common.schema import TColumnHint, Schema +from dlt.common.destination.client import ( PreparedTableSchema, RunnableLoadJob, HasFollowupJobs, diff --git a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py index ee73965df6..3ffb24bbca 100644 --- a/dlt/destinations/impl/duckdb/sql_client.py +++ b/dlt/destinations/impl/duckdb/sql_client.py @@ -20,7 +20,7 @@ ) from dlt.destinations.impl.duckdb.configuration import DuckDbBaseCredentials -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor class DuckDBDBApiCursorImpl(DBApiCursorImpl): diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index a066479294..0484ada1a8 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -3,7 +3,7 @@ from dlt.common.configuration import configspec from dlt.common.destination import TLoaderFileFormat -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( DestinationClientConfiguration, CredentialsConfiguration, ) diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index aec5a80b7d..ce83caefd9 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -23,7 +23,7 @@ DestinationTerminalException, DestinationTransientException, ) -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( HasFollowupJobs, FollowupJobRequest, PreparedTableSchema, diff --git a/dlt/destinations/impl/filesystem/configuration.py b/dlt/destinations/impl/filesystem/configuration.py index 09dc40e9d4..2133feeb62 100644 --- a/dlt/destinations/impl/filesystem/configuration.py +++ b/dlt/destinations/impl/filesystem/configuration.py @@ -4,7 +4,7 @@ from dlt.common import logger from dlt.common.configuration import configspec, resolve_type -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( CredentialsConfiguration, DestinationClientStagingConfiguration, ) diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py index 906bd157e4..fdfb24921d 100644 --- a/dlt/destinations/impl/filesystem/factory.py +++ b/dlt/destinations/impl/filesystem/factory.py @@ -1,7 +1,7 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext, TLoaderFileFormat -from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT +from dlt.common.destination.client import DEFAULT_FILE_LAYOUT from dlt.common.schema.typing import TLoaderMergeStrategy, TTableSchema from dlt.common.storages.configuration import FileSystemCredentials diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index ccf764811b..780abdeb58 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -39,10 +39,9 @@ ) from dlt.destinations.sql_client import WithSqlClient, SqlClientBase from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( FollowupJobRequest, PreparedTableSchema, - SupportsReadableRelation, TLoadJobState, RunnableLoadJob, JobClientBase, @@ -53,6 +52,7 @@ StateInfo, LoadJob, ) +from dlt.common.destination.dataset import SupportsReadableRelation from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.destinations.job_impl import ( diff --git a/dlt/destinations/impl/filesystem/sql_client.py b/dlt/destinations/impl/filesystem/sql_client.py index e6b84343bb..9b9b890e5a 100644 --- a/dlt/destinations/impl/filesystem/sql_client.py +++ b/dlt/destinations/impl/filesystem/sql_client.py @@ -11,7 +11,7 @@ from contextlib import contextmanager -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS from dlt.destinations.sql_client import raise_database_error diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index 33642268c1..67934895f9 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -6,7 +6,7 @@ BaseConfiguration, CredentialsConfiguration, ) -from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.destination.client import DestinationClientDwhConfiguration from dlt.common.typing import TSecretStrValue from dlt.common.utils import digest128 diff --git a/dlt/destinations/impl/lancedb/exceptions.py b/dlt/destinations/impl/lancedb/exceptions.py index 35b86ce76c..9bea82b2be 100644 --- a/dlt/destinations/impl/lancedb/exceptions.py +++ b/dlt/destinations/impl/lancedb/exceptions.py @@ -9,7 +9,7 @@ DestinationUndefinedEntity, DestinationTerminalException, ) -from dlt.common.destination.reference import JobClientBase +from dlt.common.destination.client import JobClientBase from dlt.common.typing import TFun diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index e484435720..0886c03b6f 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -32,7 +32,7 @@ DestinationTransientException, DestinationTerminalException, ) -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( JobClientBase, PreparedTableSchema, WithStateSync, diff --git a/dlt/destinations/impl/motherduck/configuration.py b/dlt/destinations/impl/motherduck/configuration.py index c7aaf4702e..09b7f20cf7 100644 --- a/dlt/destinations/impl/motherduck/configuration.py +++ b/dlt/destinations/impl/motherduck/configuration.py @@ -5,7 +5,7 @@ from dlt.version import __version__ from dlt.common.configuration import configspec -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.typing import TSecretStrValue from dlt.common.utils import digest128 diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index c95f52a566..7a92a61e34 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -7,7 +7,7 @@ from dlt.common.typing import TSecretStrValue from dlt.common.exceptions import SystemConfigurationException -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration @configspec(init=False) diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index 7b48a6b551..ad3589b58e 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,6 +1,6 @@ from typing import Dict, Optional, Sequence, List, Any -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( FollowupJobRequest, PreparedTableSchema, ) diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index 9f05b88bb5..0305a2cb90 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -22,7 +22,7 @@ ) from dlt.destinations.impl.mssql.configuration import MsSqlCredentials -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor def handle_datetimeoffset(dto_value: bytes) -> datetime: diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index 14eb499f89..c88cef813e 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -7,7 +7,7 @@ from dlt.common.utils import digest128 from dlt.common.typing import TSecretStrValue -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration @configspec(init=False) diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 3d54b59f93..d784e4d44f 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -6,7 +6,7 @@ from dlt.common.destination.exceptions import ( DestinationInvalidFileFormat, ) -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( HasFollowupJobs, PreparedTableSchema, RunnableLoadJob, diff --git a/dlt/destinations/impl/postgres/sql_client.py b/dlt/destinations/impl/postgres/sql_client.py index a97c8511f1..bc4345bfa5 100644 --- a/dlt/destinations/impl/postgres/sql_client.py +++ b/dlt/destinations/impl/postgres/sql_client.py @@ -17,7 +17,8 @@ DatabaseTransientException, DatabaseUndefinedRelation, ) -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor + from dlt.destinations.typing import DBApi, DBTransaction from dlt.destinations.sql_client import ( DBApiCursorImpl, diff --git a/dlt/destinations/impl/qdrant/configuration.py b/dlt/destinations/impl/qdrant/configuration.py index baf5e5dc59..b23fdb0a7d 100644 --- a/dlt/destinations/impl/qdrant/configuration.py +++ b/dlt/destinations/impl/qdrant/configuration.py @@ -7,7 +7,7 @@ BaseConfiguration, CredentialsConfiguration, ) -from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.destination.client import DestinationClientDwhConfiguration from dlt.destinations.impl.qdrant.exceptions import InvalidInMemoryQdrantCredentials if TYPE_CHECKING: diff --git a/dlt/destinations/impl/qdrant/factory.py b/dlt/destinations/impl/qdrant/factory.py index 49c4511c8d..fd8a2d47a1 100644 --- a/dlt/destinations/impl/qdrant/factory.py +++ b/dlt/destinations/impl/qdrant/factory.py @@ -1,7 +1,6 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.common.destination.reference import TDestinationConfig from dlt.common.normalizers.naming import NamingConvention from dlt.destinations.impl.qdrant.configuration import QdrantCredentials, QdrantClientConfiguration diff --git a/dlt/destinations/impl/qdrant/qdrant_job_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py index 6c8de52f98..035b9e00a9 100644 --- a/dlt/destinations/impl/qdrant/qdrant_job_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -14,7 +14,7 @@ version_table, ) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( PreparedTableSchema, TLoadJobState, RunnableLoadJob, diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index b1aa37ce6a..00e093a516 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -14,7 +14,7 @@ from typing import Dict, List, Optional, Sequence -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( FollowupJobRequest, CredentialsConfiguration, PreparedTableSchema, diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 2e589ea095..aeec71afd2 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -9,7 +9,7 @@ from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.configuration import configspec -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration from dlt.common.utils import digest128 diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 786cdc0b77..a63f41a28b 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -4,7 +4,7 @@ from dlt.common import logger from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( HasFollowupJobs, LoadJob, PreparedTableSchema, diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index 22e27ea48b..1c35061659 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ b/dlt/destinations/impl/snowflake/sql_client.py @@ -18,7 +18,7 @@ ) from dlt.destinations.typing import DBApi, DBTransaction, DataFrame from dlt.destinations.impl.snowflake.configuration import SnowflakeCredentials -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor class SnowflakeCursorImpl(DBApiCursorImpl): diff --git a/dlt/destinations/impl/sqlalchemy/configuration.py b/dlt/destinations/impl/sqlalchemy/configuration.py index b26c87dfac..9bdeac11ba 100644 --- a/dlt/destinations/impl/sqlalchemy/configuration.py +++ b/dlt/destinations/impl/sqlalchemy/configuration.py @@ -3,7 +3,7 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials -from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.destination.client import DestinationClientDwhConfiguration if TYPE_CHECKING: from sqlalchemy.engine import Engine, Dialect diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 27c4f2f1f9..8372e3518e 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -20,7 +20,6 @@ from sqlalchemy.exc import ResourceClosedError from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import PreparedTableSchema from dlt.destinations.exceptions import ( DatabaseUndefinedRelation, DatabaseTerminalException, @@ -28,7 +27,7 @@ LoadClientNotConnected, DatabaseException, ) -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor from dlt.destinations.typing import DBTransaction from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials diff --git a/dlt/destinations/impl/sqlalchemy/load_jobs.py b/dlt/destinations/impl/sqlalchemy/load_jobs.py index 3cfd6bd910..7028c9d280 100644 --- a/dlt/destinations/impl/sqlalchemy/load_jobs.py +++ b/dlt/destinations/impl/sqlalchemy/load_jobs.py @@ -3,7 +3,7 @@ import sqlalchemy as sa -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( RunnableLoadJob, HasFollowupJobs, PreparedTableSchema, diff --git a/dlt/destinations/impl/sqlalchemy/merge_job.py b/dlt/destinations/impl/sqlalchemy/merge_job.py index 5360939ba0..f16a2c82e6 100644 --- a/dlt/destinations/impl/sqlalchemy/merge_job.py +++ b/dlt/destinations/impl/sqlalchemy/merge_job.py @@ -4,7 +4,7 @@ import sqlalchemy as sa from dlt.destinations.sql_jobs import SqlMergeFollowupJob -from dlt.common.destination.reference import PreparedTableSchema, DestinationCapabilitiesContext +from dlt.common.destination import PreparedTableSchema, DestinationCapabilitiesContext from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient from dlt.common.schema.utils import ( get_columns_names_with_prop, diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index ab73ecf502..956a3d6acf 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -6,7 +6,7 @@ from dlt.common.json import json from dlt.common import logger from dlt.common import pendulum -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( JobClientBase, LoadJob, StorageSchemaInfo, diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 15c979bafa..e912371579 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( PreparedTableSchema, SupportsStagingDestination, FollowupJobRequest, diff --git a/dlt/destinations/impl/weaviate/configuration.py b/dlt/destinations/impl/weaviate/configuration.py index 1a053e41f4..55c9f63810 100644 --- a/dlt/destinations/impl/weaviate/configuration.py +++ b/dlt/destinations/impl/weaviate/configuration.py @@ -5,7 +5,7 @@ from dlt.common.configuration import configspec, NotResolved from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration -from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.destination.client import DestinationClientDwhConfiguration from dlt.common.utils import digest128 TWeaviateBatchConsistency = Literal["ONE", "QUORUM", "ALL"] diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index e9d6a76a17..a152e2d322 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -38,7 +38,7 @@ version_table, ) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( PreparedTableSchema, TLoadJobState, RunnableLoadJob, diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index aa608ca2ad..225662f0df 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -1,6 +1,6 @@ -from typing import Any, Iterator, List +from typing import Iterator, List -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( PreparedTableSchema, RunnableLoadJob, HasFollowupJobs, diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 888c80c006..d050562a94 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -39,7 +39,7 @@ from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables from dlt.common.schema import TColumnHint -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( PreparedTableSchema, StateInfo, StorageSchemaInfo, diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index 3f261bafed..460db05f1f 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod import os import tempfile # noqa: 251 -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List from dlt.common.json import json -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( HasFollowupJobs, TLoadJobState, RunnableLoadJob, diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 7d1728b43d..cf9cdf0c40 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -25,7 +25,7 @@ from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.utils import concat_strings_with_limit -from dlt.common.destination.reference import JobClientBase +from dlt.common.destination.client import JobClientBase from dlt.destinations.exceptions import ( DestinationConnectionError, @@ -38,7 +38,7 @@ DBTransaction, ArrowTable, ) -from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.dataset import DBApiCursor class TJobQueryTags(TypedDict): diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index a389c13170..9e4ee48191 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -2,7 +2,7 @@ import yaml from dlt.common.time import ensure_pendulum_datetime -from dlt.common.destination.reference import PreparedTableSchema +from dlt.common.destination import PreparedTableSchema from dlt.common.destination.utils import resolve_merge_strategy from dlt.common.schema.typing import ( diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py index d615675fa6..2f3b9b1062 100644 --- a/dlt/destinations/type_mapping.py +++ b/dlt/destinations/type_mapping.py @@ -1,7 +1,7 @@ from typing import Tuple, Dict, Optional from dlt.common import logger -from dlt.common.destination.reference import PreparedTableSchema +from dlt.common.destination import PreparedTableSchema from dlt.common.schema.typing import ( TColumnSchema, TDataType, diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 8444f70489..bed2286132 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -173,6 +173,7 @@ def with_args( # also remember original source function ovr._f = self._f + # ovr._deco_f = self._deco_f # try to bind _f ovr.wrap() return ovr @@ -727,6 +728,7 @@ def decorator( # assign spec to "f" set_fun_spec(f, SPEC) + factory = None # register non inner resources as source with single resource in it if not is_inner_resource: # a source function for the source wrapper, args that go to source are forwarded @@ -752,13 +754,17 @@ def _source( ) .bind(_source) ) - # remove name and section overrides from the wrapper so resource is not unnecessarily renamed - factory.name = None - factory.section = None # mod the reference to keep the right spec factory._ref.SPEC = SPEC - return wrap_standalone(resource_name, source_section, f) + deco: Callable[TResourceFunParams, TDltResourceImpl] = wrap_standalone( + resource_name, source_section, f + ) + # associate source factory with the decorated function for the standalone=True resource + # this provides access to standalone resources in the same way as to sources via SourceReference + deco._factory = factory # type: ignore[attr-defined] + + return deco # if data is callable or none use decorator if data is None: diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index e832833428..821ce3fff0 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -415,11 +415,11 @@ def __init__(self, ref: Sequence[str]) -> None: super().__init__(msg) -# class InvalidDestinationReference(DestinationException): -# def __init__(self, destination_module: Any) -> None: -# self.destination_module = destination_module -# msg = f"Destination {destination_module} is not a valid destination module." -# super().__init__(msg) +class InvalidSourceReference(DltSourceException): + def __init__(self, ref: str) -> None: + self.ref = ref + msg = f"Destination reference {ref} has invalid format." + super().__init__(msg) class IncrementalUnboundError(DltResourceException): diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index 366e6e1a88..f08666c75c 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -551,9 +551,7 @@ def __or__(self, transform: Union["DltResource", AnyFun]) -> "DltResource": else: return self.add_map(transform) - def __ror__( - self: TDltResourceImpl, data: Union[Iterable[Any], Iterator[Any]] - ) -> TDltResourceImpl: + def __ror__(self, data: Union[Iterable[Any], Iterator[Any]]) -> Self: """Allows to pipe data from across resources and transform functions with | operator This is the RIGHT side OR so the self may not be a resource and the LEFT must be an object that does not implement | ie. a list @@ -605,9 +603,7 @@ def _eject_config(self) -> bool: return True return False - def _inject_config( - self, incremental_from_hints_override: Optional[bool] = None - ) -> "DltResource": + def _inject_config(self, incremental_from_hints_override: Optional[bool] = None) -> Self: """Wraps the pipe generation step in incremental and config injection wrappers and adds pipe step with Incremental transform. """ @@ -784,6 +780,7 @@ def validate_transformer_generator_function(f: AnyFun) -> int: return 0 +# DltResource = _DltResource[Any] # produce Empty resource singleton DltResource.Empty = DltResource(Pipe(None), None, False) TUnboundDltResource = Callable[..., DltResource] diff --git a/dlt/extract/source.py b/dlt/extract/source.py index 1d984de3e4..ff37fad31c 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -3,7 +3,19 @@ from importlib import import_module import makefun import inspect -from typing import Dict, Iterable, Iterator, List, Sequence, Tuple, Any, Generic +from typing import ( + Callable, + Dict, + Iterable, + Iterator, + List, + Sequence, + Tuple, + Any, + Generic, + cast, + overload, +) from typing_extensions import Self, Protocol, TypeVar from types import ModuleType from typing import Dict, Type, ClassVar @@ -21,7 +33,7 @@ from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnName, TSchemaContract from dlt.common.schema.utils import normalize_table_identifiers -from dlt.common.typing import StrAny, TDataItem, ParamSpec +from dlt.common.typing import StrAny, TDataItem, ParamSpec, AnyFun from dlt.common.configuration.container import Container from dlt.common.pipeline import ( PipelineContext, @@ -40,6 +52,7 @@ from dlt.extract.resource import DltResource from dlt.extract.exceptions import ( DataItemRequiredForDynamicTableHints, + InvalidSourceReference, ResourcesNotFoundError, DeletingResourcesNotSupported, InvalidParallelResourceDataType, @@ -509,7 +522,7 @@ def with_args( schema_contract: TSchemaContract = None, spec: Type[BaseConfiguration] = None, parallelized: bool = None, - _impl_cls: Type[TDltSourceImpl] = None, + _impl_cls: Type[TDltSourceImpl] = DltSource, # type: ignore[assignment] ) -> Self: """Overrides default decorator arguments that will be used to when DltSource instance and returns modified clone.""" @@ -545,75 +558,170 @@ def __init__( self.name = name self.context = Container()[PluggableRunContext].context + # @staticmethod + # def normalize_type(source_type: str) -> str: + # """Normalizes source type string into a canonical form. Assumes that type names without dots or with two dots belong to built-in sources.""" + # parts = source_type.rsplit(".", 1) + # if len(parts) == 1: + # source_type = f"dlt.sources.{source_type}.{source_type}" + # elif len(parts) == 2: + # source_type = f"dlt.sources.{source_type}" + # return source_type + @staticmethod - def to_fully_qualified_ref(ref: str) -> List[str]: + def to_fully_qualified_refs(ref: str) -> List[str]: """Converts ref into fully qualified form, return one or more alternatives for shorthand notations. - Run context is injected in needed. + Run context is injected if needed. Following formats are recognized + - context_name.'sources'.section.name (fully qualified) + - 'sources'.section.name + - section.name + - name """ ref_split = ref.split(".") - if len(ref_split) > 3: - return [] + ref_parts = len(ref_split) + if ref_parts < 3 or (ref_parts == 3 and ref_split[0] == known_sections.SOURCES): + # context name is needed + refs = [] + run_names = [Container()[PluggableRunContext].context.name] + # always look in default run context + if run_names[0] != RunContext.CONTEXT_NAME: + run_names.append(RunContext.CONTEXT_NAME) + for run_name in run_names: + # expand shorthand notation + if ref_parts == 1: + refs.append(f"{run_name}.{known_sections.SOURCES}.{ref}.{ref}") + elif ref_parts == 2: + # for ref with two parts two options are possible + refs.append(f"{run_name}.{known_sections.SOURCES}.{ref}") + elif ref_parts == 3: + refs.append(f"{run_name}.{ref}") + # refs.extend([f"{run_name}.sources.{ref}", f"{ref_split[0]}.sources.{ref_split[1]}.{ref_split[1]}"]) + return refs # fully qualified path - if len(ref_split) == 3: + if len(ref_split) == 4 and ref_split[1] == known_sections.SOURCES: return [ref] - # context name is needed - refs = [] - run_names = [Container()[PluggableRunContext].context.name] - # always look in default run context - if run_names[0] != RunContext.CONTEXT_NAME: - run_names.append(RunContext.CONTEXT_NAME) - for run_name in run_names: - # expand shorthand notation - if len(ref_split) == 1: - refs.append(f"{run_name}.{ref}.{ref}") - else: - # for ref with two parts two options are possible - refs.extend([f"{run_name}.{ref}", f"{ref_split[0]}.{ref_split[1]}.{ref_split[1]}"]) - return refs + return [] @classmethod def register(cls, ref_obj: "SourceReference") -> None: - ref = f"{ref_obj.context.name}.{ref_obj.section}.{ref_obj.name}" + ref = f"{ref_obj.context.name}.{known_sections.SOURCES}.{ref_obj.section}.{ref_obj.name}" if ref in cls.SOURCES: logger.info(f"A source with ref {ref} is already registered and will be overwritten") cls.SOURCES[ref] = ref_obj @classmethod def find(cls, ref: str) -> "SourceReference": - refs = cls.to_fully_qualified_ref(ref) + refs = cls.to_fully_qualified_refs(ref) + + if not refs: + raise InvalidSourceReference(ref) for ref_ in refs: if wrapper := cls.SOURCES.get(ref_): return wrapper raise KeyError(refs) + @overload + @classmethod + def from_reference( + cls, + ref: str, + /, + name: str = None, + section: str = None, + max_table_nesting: int = None, + root_key: bool = False, + schema: Schema = None, + schema_contract: TSchemaContract = None, + spec: Type[BaseConfiguration] = None, + parallelized: bool = None, + _impl_sig: None = ..., + _impl_cls: Type[TDltSourceImpl] = None, + ) -> SourceFactory[Any, TDltSourceImpl]: ... + + @overload + @classmethod + def from_reference( + cls, + ref: str, + /, + name: str = None, + section: str = None, + max_table_nesting: int = None, + root_key: bool = False, + schema: Schema = None, + schema_contract: TSchemaContract = None, + spec: Type[BaseConfiguration] = None, + parallelized: bool = None, + _impl_sig: Callable[TSourceFunParams, Any] = None, + _impl_cls: Type[TDltSourceImpl] = None, + ) -> SourceFactory[TSourceFunParams, TDltSourceImpl]: ... + @classmethod - def from_reference(cls, ref: str) -> AnySourceFactory: + def from_reference( + cls, + ref: str, + /, + name: str = None, + section: str = None, + max_table_nesting: int = None, + root_key: bool = False, + schema: Schema = None, + schema_contract: TSchemaContract = None, + spec: Type[BaseConfiguration] = None, + parallelized: bool = None, + _impl_sig: AnyFun = None, + _impl_cls: Type[TDltSourceImpl] = None, + ) -> Any: """Returns registered source factory or imports source module and returns a function. - Expands shorthand notation into section.name eg. "sql_database" is expanded into "sql_database.sql_database" + Expands shorthand notation into section.name eg. "sql_database" is expanded into "sql_database.sql_database". + Passes additional arguments to `with_args` of source factory """ - refs = cls.to_fully_qualified_ref(ref) + refs = cls.to_fully_qualified_refs(ref) + factory: AnySourceFactory = None for ref_ in refs: if wrapper := cls.SOURCES.get(ref_): - return wrapper.f + factory = wrapper.f + break # try to import module - if "." in ref: - try: - module_path, attr_name = ref.rsplit(".", 1) - dest_module = import_module(module_path) - factory = getattr(dest_module, attr_name) - if hasattr(factory, "with_args"): - return factory # type: ignore[no-any-return] - else: - raise ValueError(f"{attr_name} in {module_path} is of type {type(factory)}") - except MissingDependencyException: - raise - except ModuleNotFoundError: - # raise regular exception later - pass - except Exception as e: - raise UnknownSourceReference([ref]) from e - + if factory is None: + # try ref, normalized refs and ref without the context name + refs.extend(set([r.split(".", 1)[1] for r in refs])) + if "." in ref and ref not in refs: + refs = [ref] + refs + for possible_type in refs: + try: + # will expand type to import built in types + module_path, attr_name = possible_type.rsplit(".", 1) + dest_module = import_module(module_path) + factory = cast(AnySourceFactory, getattr(dest_module, attr_name)) + # standalone resource will be implemented as decorated function with the factory attached + if hasattr(factory, "_factory"): + factory = factory._factory + # make sure it is factory interface (we could check Protocol as well) + if not hasattr(factory, "with_args"): + raise ValueError(f"{attr_name} in {module_path} is of type {type(factory)}") + break + except MissingDependencyException: + raise + except ModuleNotFoundError: + # raise regular exception later + pass + except Exception as e: + raise UnknownSourceReference([ref]) from e + + if factory: + return factory.with_args( + name=name, + section=section, + max_table_nesting=max_table_nesting, + root_key=root_key, + schema=schema, + schema_contract=schema_contract, + spec=spec, + parallelized=parallelized, + _impl_cls=_impl_cls, + ) raise UnknownSourceReference(refs or [ref]) diff --git a/dlt/helpers/dbt/__init__.py b/dlt/helpers/dbt/__init__.py index fc229ed1d0..6672031171 100644 --- a/dlt/helpers/dbt/__init__.py +++ b/dlt/helpers/dbt/__init__.py @@ -4,7 +4,7 @@ import semver from dlt.common.runners import Venv -from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.destination.client import DestinationClientDwhConfiguration from dlt.common.configuration.specs import CredentialsWithDefault from dlt.common.typing import TSecretStrValue, ConfigValue from dlt.version import get_installed_requirement_string diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index 49c165b05d..05d56ba8f3 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -8,7 +8,7 @@ from dlt.common import logger from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.utils import add_config_to_env -from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.destination.client import DestinationClientDwhConfiguration from dlt.common.runners import Venv from dlt.common.runners.stdout import iter_stdout_with_result from dlt.common.typing import StrAny, TSecretStrValue diff --git a/dlt/helpers/ibis.py b/dlt/helpers/ibis.py index e15bb9bc16..1144599dac 100644 --- a/dlt/helpers/ibis.py +++ b/dlt/helpers/ibis.py @@ -1,7 +1,8 @@ from typing import cast, Any from dlt.common.exceptions import MissingDependencyException -from dlt.common.destination.reference import TDestinationReferenceArg, Destination, JobClientBase +from dlt.common.destination import TDestinationReferenceArg, Destination +from dlt.common.destination.client import JobClientBase from dlt.common.schema import Schema from dlt.destinations.sql_client import SqlClientBase @@ -79,7 +80,7 @@ def create_ibis_backend( sf_client = cast(SnowflakeClient, client) credentials = sf_client.config.credentials.to_connector_params() - con = ibis.snowflake.connect(**credentials) + con = ibis.snowflake.connect(**credentials, create_object_udfs=False) elif destination_type in ["dlt.destinations.mssql", "dlt.destinations.synapse"]: from dlt.destinations.impl.mssql.mssql import MsSqlJobClient diff --git a/dlt/helpers/streamlit_app/pages/load_info.py b/dlt/helpers/streamlit_app/pages/load_info.py index 699e786410..57a7f24a51 100644 --- a/dlt/helpers/streamlit_app/pages/load_info.py +++ b/dlt/helpers/streamlit_app/pages/load_info.py @@ -2,7 +2,7 @@ import streamlit as st from dlt.common.configuration.exceptions import ConfigFieldMissingException -from dlt.common.destination.reference import WithStateSync +from dlt.common.destination.client import WithStateSync from dlt.helpers.streamlit_app.blocks.load_info import last_load_info from dlt.helpers.streamlit_app.blocks.menu import menu from dlt.helpers.streamlit_app.widgets import stat diff --git a/dlt/load/load.py b/dlt/load/load.py index ddbc7193ed..b62f7e648c 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -27,19 +27,18 @@ from dlt.common.configuration.container import Container from dlt.common.schema import Schema from dlt.common.storages import LoadStorage -from dlt.common.destination.reference import ( +from dlt.common.destination import Destination, AnyDestination +from dlt.common.destination.client import ( DestinationClientDwhConfiguration, HasFollowupJobs, JobClientBase, WithStagingDataset, - Destination, RunnableLoadJob, LoadJob, FollowupJobRequest, TLoadJobState, DestinationClientConfiguration, SupportsStagingDestination, - AnyDestination, ) from dlt.common.destination.exceptions import ( DestinationTerminalException, diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 7800955cf9..2c93a0740e 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -12,7 +12,7 @@ from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import TTableSchema -from dlt.common.destination.reference import JobClientBase, WithStagingDataset, LoadJob +from dlt.common.destination.client import JobClientBase, WithStagingDataset, LoadJob from dlt.load.configuration import LoaderConfiguration from dlt.common.destination import DestinationCapabilitiesContext diff --git a/tests/cli/test_config_toml_writer.py b/tests/cli/test_config_toml_writer.py index 31c6f524a7..b288bdb6a2 100644 --- a/tests/cli/test_config_toml_writer.py +++ b/tests/cli/test_config_toml_writer.py @@ -4,7 +4,7 @@ from dlt.cli.config_toml_writer import write_value, WritableConfigValue, write_values from dlt.common.configuration.specs import configspec -from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT +from dlt.common.destination.client import DEFAULT_FILE_LAYOUT EXAMPLE_COMMENT = "# please set me up!" diff --git a/tests/common/cases/destinations/null.py b/tests/common/cases/destinations/null.py index 37e87d89cf..fbb9a56c4b 100644 --- a/tests/common/cases/destinations/null.py +++ b/tests/common/cases/destinations/null.py @@ -1,11 +1,8 @@ from typing import Any, Type from dlt.common.destination.capabilities import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( - Destination, - DestinationClientConfiguration, - JobClientBase, -) +from dlt.common.destination import Destination +from dlt.common.destination.client import DestinationClientConfiguration, JobClientBase class null(Destination[DestinationClientConfiguration, "JobClientBase"]): diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index 584052b6c8..092233dac6 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -30,7 +30,7 @@ ) from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContainer from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.reflection.spec import _get_spec_name_from_f +from dlt.common.reflection.spec import get_spec_name_from_f from dlt.common.typing import ( StrAny, TSecretStrValue, @@ -245,12 +245,13 @@ def test_inject_without_spec_kw_only() -> None: pass -def test_inject_with_auto_section(environment: Any) -> None: - environment["PIPE__VALUE"] = "test" +def test_inject_with_pipeline_section(environment: Any) -> None: + expected_value = "test" + environment["PIPE__VALUE"] = expected_value - @with_config(auto_pipeline_section=True) + @with_config(section_arg_name="pipeline_name") def f(pipeline_name=dlt.config.value, value=dlt.secrets.value): - assert value == "test" + assert value == expected_value f("pipe") @@ -258,6 +259,33 @@ def f(pipeline_name=dlt.config.value, value=dlt.secrets.value): assert get_fun_spec(f) is not None assert hasattr(get_fun_spec(f), "pipeline_name") + # also section is extended + del environment["PIPE__VALUE"] + expected_value = "test7" + environment["PIPE__PIPE__VALUE"] = expected_value + f("pipe") + + +def test_extend_sections_from_argument(environment: Any) -> None: + @with_config(sections=("datasets",), section_arg_name="dataset_name") + def f(dataset_name=dlt.config.value, value=dlt.secrets.value): + assert dataset_name == "github" + assert value == expected_value + + expected_value = "test" + environment["VALUE"] = expected_value + f("github") + + del environment["VALUE"] + expected_value = "test2" + environment["DATASETS__VALUE"] = expected_value + f("github") + + del environment["DATASETS__VALUE"] + expected_value = "test3" + environment["DATASETS__GITHUB__VALUE"] = expected_value + f("github") + @pytest.mark.skip("not implemented") def test_inject_with_spec() -> None: @@ -701,7 +729,7 @@ def stuff_test(pos_par, /, kw_par) -> None: # name is composed via __qualname__ of func assert ( - _get_spec_name_from_f(AutoNameTest.__init__) + get_spec_name_from_f(AutoNameTest.__init__) == "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" ) # synthesized spec present in current module diff --git a/tests/common/destination/test_reference.py b/tests/common/destination/test_reference.py index 93eef793d5..347b798dee 100644 --- a/tests/common/destination/test_reference.py +++ b/tests/common/destination/test_reference.py @@ -1,7 +1,8 @@ from typing import Dict import pytest -from dlt.common.destination.reference import DestinationClientDwhConfiguration, Destination +from dlt.common.destination import Destination +from dlt.common.destination.client import DestinationClientDwhConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.exceptions import InvalidDestinationReference, UnknownDestinationModule from dlt.common.schema import Schema diff --git a/tests/common/reflection/test_reflect_spec.py b/tests/common/reflection/test_reflect_spec.py index dd2dcb3fc5..60a5e503ae 100644 --- a/tests/common/reflection/test_reflect_spec.py +++ b/tests/common/reflection/test_reflect_spec.py @@ -11,7 +11,7 @@ RuntimeConfiguration, ConnectionStringCredentials, ) -from dlt.common.reflection.spec import spec_from_signature, _get_spec_name_from_f +from dlt.common.reflection.spec import spec_from_signature, get_spec_name_from_f from dlt.common.reflection.utils import get_func_def_node, get_literal_defaults diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index a14b4a9602..383a4c69bd 100644 --- a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -23,8 +23,8 @@ from dlt.cli.source_detection import detect_source_configs from dlt.common.utils import custom_environ -from dlt.extract.decorators import DltSourceFactoryWrapper -from dlt.extract.source import SourceReference +from dlt.extract.decorators import _DltSingleSource, DltSourceFactoryWrapper +from dlt.extract.source import SourceFactory, SourceReference from dlt.extract import DltResource, DltSource from dlt.extract.exceptions import ( DynamicNameNotStandaloneResource, @@ -50,6 +50,16 @@ from tests.utils import MockableRunContext +@pytest.fixture(autouse=True, scope="module") +def preserve_sources_registry() -> Iterator[None]: + try: + reg_ = SourceReference.SOURCES + SourceReference.SOURCES = {} + yield + finally: + SourceReference.SOURCES = reg_ + + def test_default_resource() -> None: @dlt.resource def resource(): @@ -734,7 +744,7 @@ def _inner_source_2(): # unknown reference with pytest.raises(UnknownSourceReference) as ref_ex: SourceReference.from_reference("$ref") - assert ref_ex.value.ref == ["dlt.$ref.$ref"] + assert ref_ex.value.ref == ["dlt.sources.$ref.$ref", "sources.$ref.$ref"] @dlt.source(section="special") def absolute_config(init: int, mark: str = dlt.config.value, secret: str = dlt.secrets.value): @@ -750,6 +760,61 @@ def absolute_config(init: int, mark: str = dlt.config.value, secret: str = dlt.s assert list(source) == ["resourse", "resourse", "resourse", 100, "ma", "sourse"] +def test_source_reference_with_args() -> None: + ref_t = SourceReference.from_reference( + "shorthand", section="changed", _impl_sig=with_shorthand_registry + ) + assert ref_t.section == "changed" # type: ignore[attr-defined] + # here source has correctly typed signature from with_shorthand_registry + source = ref_t(["A", "B"]) + assert source.section == "changed" + assert list(source) == ["A", "B"] + + # get reference for resource + ref_r_t = SourceReference.from_reference( + "test_decorators.res_reg_with_secret", + _impl_cls=_DltSingleSource, + _impl_sig=res_reg_with_secret, + ) + source_s_r = ref_r_t(secretz="P") + assert list(source_s_r) == ["P", "P", "P"] + assert isinstance(source_s_r, _DltSingleSource) + assert list(source_s_r.single_resource) == ["P", "P", "P"] + + +def test_source_reference_import_core() -> None: + SourceReference.SOURCES.clear() + # auto import by shorthand + ref = SourceReference.from_reference("rest_api") + assert isinstance(ref, DltSourceFactoryWrapper) + assert len(SourceReference.SOURCES) == 1 + + # auto import by extended shorthand + ref = SourceReference.from_reference("sql_database.sql_table") + assert len(SourceReference.SOURCES) == 3 + + # auto import by full reference + ref = SourceReference.from_reference("dlt.sources.filesystem.filesystem") + assert len(SourceReference.SOURCES) == 9 + + +def test_source_reference_auto_import() -> None: + SourceReference.SOURCES.clear() + # make sure to import resource first + ref = SourceReference.from_reference( + "tests.extract.cases.section_source.named_module.resource_f_2" + ) + assert ref.section == "name_overridden" # type: ignore[attr-defined] + assert list(ref("A")) == ["A"] + # TODO: fix double references (with renamed section and without, should be only 2 sections here) + assert len(SourceReference.SOURCES) == 4 + + ref = SourceReference.from_reference( + "tests.extract.cases.section_source.named_module.source_f_1" + ) + assert ref.section == "name_overridden" # type: ignore[attr-defined] + + def test_source_reference_with_context() -> None: ctx = PluggableRunContext() mock = MockableRunContext.from_context(ctx.context) @@ -765,7 +830,11 @@ def test_source_reference_with_context() -> None: # unknown reference with pytest.raises(UnknownSourceReference) as ref_ex: SourceReference.from_reference("$ref") - assert ref_ex.value.ref == ["mock.$ref.$ref", "dlt.$ref.$ref"] + assert ref_ex.value.ref == [ + "mock.sources.$ref.$ref", + "dlt.sources.$ref.$ref", + "sources.$ref.$ref", + ] with pytest.raises(UnknownSourceReference) as ref_ex: SourceReference.from_reference("mock.$ref.$ref") assert ref_ex.value.ref == ["mock.$ref.$ref"] @@ -777,10 +846,10 @@ def with_shorthand_registry(data): ref = SourceReference.from_reference("shorthand") assert list(ref(["C", "x"])) == ["x", "C"] - ref = SourceReference.from_reference("mock.shorthand.shorthand") + ref = SourceReference.from_reference("mock.sources.shorthand.shorthand") assert list(ref(["C", "x"])) == ["x", "C"] # from dlt package - ref = SourceReference.from_reference("dlt.shorthand.shorthand") + ref = SourceReference.from_reference("dlt.sources.shorthand.shorthand") assert list(ref(["C", "x"])) == ["C", "x"] @@ -940,14 +1009,14 @@ def no_args(): return dlt.resource([1, 2], name="data") # there is a spec even if no arguments - SPEC = SourceReference.find("dlt.test_decorators.no_args").SPEC + SPEC = SourceReference.find("dlt.sources.test_decorators.no_args").SPEC assert SPEC # source names are used to index detected sources _, _, checked = detect_source_configs(SourceReference.SOURCES, "", ()) assert "no_args" in checked - SPEC = SourceReference.find("dlt.test_decorators.not_args_r").SPEC + SPEC = SourceReference.find("dlt.sources.test_decorators.not_args_r").SPEC assert SPEC _, _, checked = detect_source_configs(SourceReference.SOURCES, "", ()) assert "not_args_r" in checked @@ -956,7 +1025,7 @@ def no_args(): def not_args_r_i(): yield from [1, 2, 3] - assert "dlt.test_decorators.not_args_r_i" not in SourceReference.SOURCES + assert "dlt.sources.test_decorators.not_args_r_i" not in SourceReference.SOURCES # you can call those assert list(no_args()) == [1, 2] @@ -1088,7 +1157,7 @@ def test_reference_registered_resource(res: DltResource) -> None: assert res_ref.SPEC is res.SPEC else: ref = res.__name__ - # create source with single res. + # create source with single res factory = SourceReference.from_reference(f"test_decorators.{ref}") # pass explicit config source = factory(init=1, secret_end=3) diff --git a/tests/helpers/dbt_tests/local/utils.py b/tests/helpers/dbt_tests/local/utils.py index 8fd3dba44f..2fea422d6d 100644 --- a/tests/helpers/dbt_tests/local/utils.py +++ b/tests/helpers/dbt_tests/local/utils.py @@ -2,7 +2,7 @@ from typing import Iterator, NamedTuple from dlt.common.configuration.utils import add_config_to_env -from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.destination.client import DestinationClientDwhConfiguration from dlt.common.runners import Venv from dlt.common.typing import StrAny diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index 10ee55cc6c..ae6aa45cda 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -19,9 +19,8 @@ from dlt.common.schema.utils import new_table from dlt.common.storages import FileStorage from dlt.common.utils import digest128, uniq_id, custom_environ -from dlt.common.destination.reference import RunnableLoadJob +from dlt.common.destination.client import RunnableLoadJob from dlt.destinations.impl.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration -from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException from dlt.destinations.impl.bigquery.bigquery_adapter import ( AUTODETECT_SCHEMA_HINT, diff --git a/tests/load/duckdb/test_duckdb_client.py b/tests/load/duckdb/test_duckdb_client.py index 652f75772a..23aaff0263 100644 --- a/tests/load/duckdb/test_duckdb_client.py +++ b/tests/load/duckdb/test_duckdb_client.py @@ -5,7 +5,7 @@ import dlt from dlt.common.configuration.resolve import resolve_configuration from dlt.common.configuration.utils import get_resolved_traces -from dlt.common.destination.reference import Destination +from dlt.common.destination import Destination from dlt.common.utils import set_working_dir from dlt.destinations.exceptions import DatabaseUndefinedRelation diff --git a/tests/load/filesystem/utils.py b/tests/load/filesystem/utils.py index bb4153da5c..f23d3f5fdc 100644 --- a/tests/load/filesystem/utils.py +++ b/tests/load/filesystem/utils.py @@ -14,7 +14,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.destination.reference import RunnableLoadJob +from dlt.common.destination.client import RunnableLoadJob from dlt.common.pendulum import timedelta, __utcnow from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient diff --git a/tests/load/pipeline/test_clickhouse.py b/tests/load/pipeline/test_clickhouse.py index 7d7a821445..b66763da0e 100644 --- a/tests/load/pipeline/test_clickhouse.py +++ b/tests/load/pipeline/test_clickhouse.py @@ -3,7 +3,7 @@ import pytest import dlt -from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.destination.client import DestinationClientDwhConfiguration from dlt.common.schema.schema import Schema from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index 330f2606ff..f1f78d4e9a 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -6,7 +6,7 @@ import pytest import dlt -from dlt.common.destination.reference import JobClientBase +from dlt.common.destination.client import JobClientBase from dlt.extract import DltResource from dlt.common.utils import uniq_id from dlt.pipeline import helpers, state_sync, Pipeline diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index b998b78471..42a5385134 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -9,7 +9,7 @@ from dlt.common import json, sleep from dlt.common.pipeline import SupportsPipeline from dlt.common.destination import Destination -from dlt.common.destination.reference import WithStagingDataset +from dlt.common.destination.client import WithStagingDataset from dlt.common.schema.schema import Schema from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.schema.utils import new_table diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index e09582f8a8..7fa051ef52 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -5,8 +5,7 @@ import pytest import dlt -from dlt.common.destination.reference import Destination -from dlt.common.schema.exceptions import CannotCoerceColumnException +from dlt.common.destination import Destination from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index b78306210f..6ff167ba78 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -11,7 +11,7 @@ from dlt.common.schema.utils import normalize_table_identifiers from dlt.common.utils import uniq_id from dlt.common.destination.exceptions import DestinationUndefinedEntity -from dlt.common.destination.reference import WithStateSync +from dlt.common.destination.client import WithStateSync from dlt.load import Load from dlt.pipeline.exceptions import SqlClientNotAvailable diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index e27000c841..3c37e54ac6 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -11,7 +11,8 @@ from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.storages.load_package import TPackageJobState from dlt.common.storages.load_storage import JobFileFormatUnsupported -from dlt.common.destination.reference import RunnableLoadJob, AnyDestination +from dlt.common.destination import AnyDestination +from dlt.common.destination.client import RunnableLoadJob from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, get_nested_tables, diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 6f699436b3..7b7952ce64 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -27,7 +27,7 @@ ) from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( StateInfo, WithStagingDataset, DestinationClientConfiguration, diff --git a/tests/load/test_jobs.py b/tests/load/test_jobs.py index 69f5fb9ddc..3248508436 100644 --- a/tests/load/test_jobs.py +++ b/tests/load/test_jobs.py @@ -1,6 +1,6 @@ import pytest -from dlt.common.destination.reference import RunnableLoadJob +from dlt.common.destination.client import RunnableLoadJob from dlt.common.destination.exceptions import DestinationTerminalException from dlt.destinations.job_impl import FinalizedLoadJob diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index ee48222da9..e0a54cf21f 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -683,6 +683,8 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: "COMMIT;", ] # cannot insert NULL value + with pytest.raises(DatabaseTerminalException): + client.sql_client.execute_many(statements[1:2]) with pytest.raises(DatabaseTerminalException): client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "IntegrityError") diff --git a/tests/load/utils.py b/tests/load/utils.py index 5660202ec3..5e61292825 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -30,18 +30,17 @@ CredentialsConfiguration, GcpOAuthCredentialsWithoutDefaults, ) -from dlt.common.destination.reference import ( +from dlt.common.destination.client import ( DestinationClientDwhConfiguration, JobClientBase, RunnableLoadJob, LoadJob, DestinationClientStagingConfiguration, - TDestinationReferenceArg, WithStagingDataset, DestinationCapabilitiesContext, ) -from dlt.common.destination import TLoaderFileFormat, Destination -from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT +from dlt.common.destination import TLoaderFileFormat, Destination, TDestinationReferenceArg +from dlt.common.destination.client import DEFAULT_FILE_LAYOUT from dlt.common.data_writers import DataWriter from dlt.common.pipeline import PipelineContext from dlt.common.schema import TTableSchemaColumns, Schema diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index a2134dba33..dd399e00cd 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -16,7 +16,8 @@ from dlt.common import pipeline as state_module from dlt.common.storages.load_package import TPipelineStateDoc from dlt.common.utils import uniq_id -from dlt.common.destination.reference import Destination, StateInfo +from dlt.common.destination import Destination +from dlt.common.destination.client import StateInfo from dlt.common.validation import validate_dict from dlt.destinations.utils import get_pipeline_state_query_columns