From 7641bcf458bfed44911021e3e4259a08fb63fbb3 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 22 Jan 2025 14:39:24 +0100 Subject: [PATCH] refactor databricks direct loading --- dlt/destinations/impl/athena/athena.py | 6 ---- dlt/destinations/impl/bigquery/sql_client.py | 6 ---- .../impl/clickhouse/sql_client.py | 6 ---- .../impl/databricks/configuration.py | 29 ++++++++++----- .../impl/databricks/databricks.py | 35 ++++++++++++------- .../impl/databricks/sql_client.py | 13 ------- dlt/destinations/impl/dremio/sql_client.py | 6 ---- dlt/destinations/impl/duckdb/sql_client.py | 6 ---- dlt/destinations/impl/mssql/sql_client.py | 6 ---- dlt/destinations/impl/postgres/sql_client.py | 6 ---- dlt/destinations/impl/snowflake/sql_client.py | 6 ---- .../impl/sqlalchemy/db_api_client.py | 6 ---- dlt/destinations/job_client_impl.py | 1 - dlt/destinations/sql_client.py | 8 ----- dlt/pipeline/pipeline.py | 1 - .../load/pipeline/test_databricks_pipeline.py | 8 +++-- 16 files changed, 47 insertions(+), 102 deletions(-) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index f47bce968d..c7e30aaf55 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -190,12 +190,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @property def native_connection(self) -> Connection: return self._conn diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index 194b1594ea..6911fa5c1c 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -112,12 +112,6 @@ def close_connection(self) -> None: self._client.close() self._client = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager 
@raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py index 7c1847fa3c..a6c4ee0458 100644 --- a/dlt/destinations/impl/clickhouse/sql_client.py +++ b/dlt/destinations/impl/clickhouse/sql_client.py @@ -99,12 +99,6 @@ def open_connection(self) -> clickhouse_driver.dbapi.connection.Connection: self._conn = clickhouse_driver.connect(dsn=self.credentials.to_native_representation()) return self._conn - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @raise_open_connection_error def close_connection(self) -> None: if self._conn: diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index ad1dc397a2..85eaaa4097 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -15,7 +15,6 @@ class DatabricksCredentials(CredentialsConfiguration): catalog: str = None server_hostname: str = None http_path: str = None - direct_load: bool = False access_token: Optional[TSecretStrValue] = None client_id: Optional[TSecretStrValue] = None client_secret: Optional[TSecretStrValue] = None @@ -38,14 +37,19 @@ class DatabricksCredentials(CredentialsConfiguration): def on_resolved(self) -> None: if not ((self.client_id and self.client_secret) or self.access_token): - # databricks authentication: get context config - from databricks.sdk import WorkspaceClient + try: + # attempt notebook context authentication + from databricks.sdk import WorkspaceClient - w = WorkspaceClient() - notebook_context = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext() - self.access_token = notebook_context.apiToken().getOrElse(None) + w = WorkspaceClient() + notebook_context = ( + w.dbutils.notebook.entry_point.getDbutils().notebook().getContext() + ) + self.access_token = 
notebook_context.apiToken().getOrElse(None) - self.server_hostname = notebook_context.browserHostName().getOrElse(None) + self.server_hostname = notebook_context.browserHostName().getOrElse(None) + except Exception: + pass if not self.access_token or not self.server_hostname: raise ConfigurationValueError( @@ -54,8 +58,6 @@ def on_resolved(self) -> None: " and the server_hostname." ) - self.direct_load = True - def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( catalog=self.catalog, @@ -83,6 +85,15 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration "If set, credentials with given name will be used in copy command" is_staging_external_location: bool = False """If true, the temporary credentials are not propagated to the COPY command""" + staging_volume_name: Optional[str] = None + """Name of the Databricks managed volume for temporary storage, e.g., <catalog>.<database>.<volume>. Defaults to '_dlt_staging_load_volume' if not set.""" + + def on_resolved(self): + if self.staging_volume_name and self.staging_volume_name.count(".") != 2: + raise ConfigurationValueError( + f"Invalid staging_volume_name format: {self.staging_volume_name}. Expected format" + " is '<catalog>.<database>.<volume>'."
+ ) def __str__(self) -> str: """Return displayable destination location""" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index be1697cee0..9e6e445a3f 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -59,7 +59,7 @@ def run(self) -> None: # decide if this is a local file or a staged file is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path) - if is_local_file and self._job_client.config.credentials.direct_load: + if is_local_file: # local file by uploading to a temporary volume on Databricks from_clause, file_name = self._handle_local_file_upload(self._file_path) credentials_clause = "" @@ -112,18 +112,27 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: ) file_name = FileStorage.get_file_name_from_file_path(local_file_path) - file_format = "" - if file_name.endswith(".parquet"): - file_format = "parquet" - elif file_name.endswith(".jsonl"): - file_format = "jsonl" - else: - return "", file_name - - volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{self._sql_client.volume_name}/{time.time_ns()}" - volume_file_name = ( # replace file_name for random hex code - databricks loading fails when file_name starts with - or . - f"{uniq_id()}.{file_format}" - ) + volume_file_name = file_name + if file_name.startswith(("_", ".")): + volume_file_name = ( + "valid" + file_name + ) # databricks loading fails when file_name starts with - or . 
+ + volume_catalog = self._sql_client.database_name + volume_database = self._sql_client.dataset_name + volume_name = "_dlt_staging_load_volume" + + # create staging volume name + fully_qualified_volume_name = f"{volume_catalog}.{volume_database}.{volume_name}" + if self._job_client.config.staging_volume_name: + fully_qualified_volume_name = self._job_client.config.staging_volume_name + volume_catalog, volume_database, volume_name = fully_qualified_volume_name.split(".") + + self._sql_client.execute_sql(f""" + CREATE VOLUME IF NOT EXISTS {fully_qualified_volume_name} + """) + + volume_path = f"/Volumes/{volume_catalog}/{volume_database}/{volume_name}/{time.time_ns()}" volume_file_path = f"{volume_path}/{volume_file_name}" with open(local_file_path, "rb") as f: diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index a9e880a56e..9f695b9d6e 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -63,7 +63,6 @@ def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): dbapi: ClassVar[DBApi] = databricks_lib - volume_name: str = "_dlt_temp_load_volume" def __init__( self, @@ -103,18 +102,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - self.execute_sql(f""" - CREATE VOLUME IF NOT EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name} - """) - - def drop_volume(self) -> None: - if not self._conn: - self.open_connection() - self.execute_sql(f""" - DROP VOLUME IF EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name} - """) - @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: # Databricks does not support transactions diff --git a/dlt/destinations/impl/dremio/sql_client.py b/dlt/destinations/impl/dremio/sql_client.py index d8c509bf18..030009c74b 100644 --- 
a/dlt/destinations/impl/dremio/sql_client.py +++ b/dlt/destinations/impl/dremio/sql_client.py @@ -64,12 +64,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py index ba8572ede8..ee73965df6 100644 --- a/dlt/destinations/impl/duckdb/sql_client.py +++ b/dlt/destinations/impl/duckdb/sql_client.py @@ -95,12 +95,6 @@ def close_connection(self) -> None: self.credentials.return_conn(self._conn) self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index 467f0c2b6c..9f05b88bb5 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -70,12 +70,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/postgres/sql_client.py b/dlt/destinations/impl/postgres/sql_client.py index b76ca92353..a97c8511f1 100644 --- a/dlt/destinations/impl/postgres/sql_client.py +++ b/dlt/destinations/impl/postgres/sql_client.py @@ -58,12 +58,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index 
56e939e456..22e27ea48b 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ b/dlt/destinations/impl/snowflake/sql_client.py @@ -63,12 +63,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 915aee7eae..27c4f2f1f9 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -171,12 +171,6 @@ def close_connection(self) -> None: self._current_connection = None self._current_transaction = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @property def native_connection(self) -> Connection: if not self._current_connection: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 234493104d..888c80c006 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -176,7 +176,6 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: self.sql_client.create_dataset() elif truncate_tables: self.sql_client.truncate_tables(*truncate_tables) - self.sql_client.create_volume() def is_storage_initialized(self) -> bool: return self.sql_client.has_dataset() diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 56d11e143c..345afff18e 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -87,14 +87,6 @@ def close_connection(self) -> None: def begin_transaction(self) -> ContextManager[DBTransaction]: pass - @abstractmethod - def create_volume(self) -> None: - pass - - @abstractmethod - def drop_volume(self) -> None: - pass - def __getattr__(self, name: str) -> Any: # pass unresolved attrs to native connections if not 
self.native_connection: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 5c66a60498..74466a09e4 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -608,7 +608,6 @@ def load( runner.run_pool(load_step.config, load_step) info: LoadInfo = self._get_step_info(load_step) - self.sql_client().drop_volume() self.first_run = False return info except Exception as l_ex: diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 71c901fb16..5ff1cc2ca2 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -220,6 +220,11 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) assert len(rows) == 3 +# TODO: test config staging_volume_name on_resolve +# TODO: modify the DestinationTestConfiguration +# TODO: add test databricks credentials default auth error +# TODO: test on notebook +# TODO: check that volume doesn't block schema drop @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True, subset=("databricks",)), @@ -229,9 +234,6 @@ def test_databricks_direct_load(destination_config: DestinationTestConfiguration os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" - # direct_load - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__DIRECT_LOAD"] = "True" - bricks = databricks() config = bricks.configuration(None, accept_partial=True) assert config.credentials.access_token