Skip to content

Commit

Permalink
refactor databricks direct loading
Browse files Browse the repository at this point in the history
  • Loading branch information
donotpush committed Jan 22, 2025
1 parent 2bd0be0 commit 7641bcf
Show file tree
Hide file tree
Showing 16 changed files with 47 additions and 102 deletions.
6 changes: 0 additions & 6 deletions dlt/destinations/impl/athena/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,6 @@ def close_connection(self) -> None:
self._conn.close()
self._conn = None

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@property
def native_connection(self) -> Connection:
return self._conn
Expand Down
6 changes: 0 additions & 6 deletions dlt/destinations/impl/bigquery/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,6 @@ def close_connection(self) -> None:
self._client.close()
self._client = None

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@contextmanager
@raise_database_error
def begin_transaction(self) -> Iterator[DBTransaction]:
Expand Down
6 changes: 0 additions & 6 deletions dlt/destinations/impl/clickhouse/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,6 @@ def open_connection(self) -> clickhouse_driver.dbapi.connection.Connection:
self._conn = clickhouse_driver.connect(dsn=self.credentials.to_native_representation())
return self._conn

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@raise_open_connection_error
def close_connection(self) -> None:
if self._conn:
Expand Down
29 changes: 20 additions & 9 deletions dlt/destinations/impl/databricks/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class DatabricksCredentials(CredentialsConfiguration):
catalog: str = None
server_hostname: str = None
http_path: str = None
direct_load: bool = False
access_token: Optional[TSecretStrValue] = None
client_id: Optional[TSecretStrValue] = None
client_secret: Optional[TSecretStrValue] = None
Expand All @@ -38,14 +37,19 @@ class DatabricksCredentials(CredentialsConfiguration):

def on_resolved(self) -> None:
if not ((self.client_id and self.client_secret) or self.access_token):
# databricks authentication: get context config
from databricks.sdk import WorkspaceClient
try:
# attempt notebook context authentication
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
notebook_context = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext()
self.access_token = notebook_context.apiToken().getOrElse(None)
w = WorkspaceClient()
notebook_context = (
w.dbutils.notebook.entry_point.getDbutils().notebook().getContext()
)
self.access_token = notebook_context.apiToken().getOrElse(None)

self.server_hostname = notebook_context.browserHostName().getOrElse(None)
self.server_hostname = notebook_context.browserHostName().getOrElse(None)
except Exception:
pass

if not self.access_token or not self.server_hostname:
raise ConfigurationValueError(
Expand All @@ -54,8 +58,6 @@ def on_resolved(self) -> None:
" and the server_hostname."
)

self.direct_load = True

def to_connector_params(self) -> Dict[str, Any]:
conn_params = dict(
catalog=self.catalog,
Expand Down Expand Up @@ -83,6 +85,15 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration
"If set, credentials with given name will be used in copy command"
is_staging_external_location: bool = False
"""If true, the temporary credentials are not propagated to the COPY command"""
staging_volume_name: Optional[str] = None
"""Name of the Databricks managed volume for temporary storage, e.g., <catalog_name>.<database_name>.<volume_name>. Defaults to '_dlt_temp_load_volume' if not set."""

def on_resolved(self) -> None:
    """Validate the resolved configuration.

    Ensures ``staging_volume_name``, when set, is a fully qualified volume
    name of the form ``<catalog_name>.<database_name>.<volume_name>``.

    Raises:
        ConfigurationValueError: if the name does not contain exactly two dots.

    NOTE(review): only the dot count is checked — a value with empty segments
    (e.g. ``a..b``) would pass; confirm whether stricter validation is wanted.
    """
    if self.staging_volume_name and self.staging_volume_name.count(".") != 2:
        raise ConfigurationValueError(
            f"Invalid staging_volume_name format: {self.staging_volume_name}. Expected format"
            " is '<catalog_name>.<database_name>.<volume_name>'."
        )

def __str__(self) -> str:
"""Return displayable destination location"""
Expand Down
35 changes: 22 additions & 13 deletions dlt/destinations/impl/databricks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def run(self) -> None:

# decide if this is a local file or a staged file
is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path)
if is_local_file and self._job_client.config.credentials.direct_load:
if is_local_file:
# local file by uploading to a temporary volume on Databricks
from_clause, file_name = self._handle_local_file_upload(self._file_path)
credentials_clause = ""
Expand Down Expand Up @@ -112,18 +112,27 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]:
)

file_name = FileStorage.get_file_name_from_file_path(local_file_path)
file_format = ""
if file_name.endswith(".parquet"):
file_format = "parquet"
elif file_name.endswith(".jsonl"):
file_format = "jsonl"
else:
return "", file_name

volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{self._sql_client.volume_name}/{time.time_ns()}"
volume_file_name = ( # replace file_name for random hex code - databricks loading fails when file_name starts with - or .
f"{uniq_id()}.{file_format}"
)
volume_file_name = file_name
if file_name.startswith(("_", ".")):
volume_file_name = (
"valid" + file_name
) # databricks loading fails when file_name starts with - or .

volume_catalog = self._sql_client.database_name
volume_database = self._sql_client.dataset_name
volume_name = "_dlt_staging_load_volume"

# create staging volume name
fully_qualified_volume_name = f"{volume_catalog}.{volume_database}.{volume_name}"
if self._job_client.config.staging_volume_name:
fully_qualified_volume_name = self._job_client.config.staging_volume_name
volume_catalog, volume_database, volume_name = fully_qualified_volume_name.split(".")

self._sql_client.execute_sql(f"""
CREATE VOLUME IF NOT EXISTS {fully_qualified_volume_name}
""")

volume_path = f"/Volumes/{volume_catalog}/{volume_database}/{volume_name}/{time.time_ns()}"
volume_file_path = f"{volume_path}/{volume_file_name}"

with open(local_file_path, "rb") as f:
Expand Down
13 changes: 0 additions & 13 deletions dlt/destinations/impl/databricks/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]:

class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction):
dbapi: ClassVar[DBApi] = databricks_lib
volume_name: str = "_dlt_temp_load_volume"

def __init__(
self,
Expand Down Expand Up @@ -103,18 +102,6 @@ def close_connection(self) -> None:
self._conn.close()
self._conn = None

def create_volume(self) -> None:
    """Create the staging volume used for direct (local file) loads.

    Idempotent: uses CREATE VOLUME IF NOT EXISTS, so repeated calls are safe.
    The volume is created inside the client's dataset under the class-level
    ``volume_name``.
    """
    self.execute_sql(f"""
    CREATE VOLUME IF NOT EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name}
    """)

def drop_volume(self) -> None:
    """Drop the staging volume if it exists.

    Re-opens the connection first if it was already closed — presumably this
    runs late in the pipeline, after the load finished (see pipeline.load).
    """
    # connection may have been closed by an earlier step; re-open for the DROP
    if not self._conn:
        self.open_connection()
    self.execute_sql(f"""
    DROP VOLUME IF EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name}
    """)

@contextmanager
def begin_transaction(self) -> Iterator[DBTransaction]:
# Databricks does not support transactions
Expand Down
6 changes: 0 additions & 6 deletions dlt/destinations/impl/dremio/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ def close_connection(self) -> None:
self._conn.close()
self._conn = None

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@contextmanager
@raise_database_error
def begin_transaction(self) -> Iterator[DBTransaction]:
Expand Down
6 changes: 0 additions & 6 deletions dlt/destinations/impl/duckdb/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,6 @@ def close_connection(self) -> None:
self.credentials.return_conn(self._conn)
self._conn = None

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@contextmanager
@raise_database_error
def begin_transaction(self) -> Iterator[DBTransaction]:
Expand Down
6 changes: 0 additions & 6 deletions dlt/destinations/impl/mssql/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ def close_connection(self) -> None:
self._conn.close()
self._conn = None

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@contextmanager
def begin_transaction(self) -> Iterator[DBTransaction]:
try:
Expand Down
6 changes: 0 additions & 6 deletions dlt/destinations/impl/postgres/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,6 @@ def close_connection(self) -> None:
self._conn.close()
self._conn = None

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@contextmanager
def begin_transaction(self) -> Iterator[DBTransaction]:
try:
Expand Down
6 changes: 0 additions & 6 deletions dlt/destinations/impl/snowflake/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,6 @@ def close_connection(self) -> None:
self._conn.close()
self._conn = None

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@contextmanager
def begin_transaction(self) -> Iterator[DBTransaction]:
try:
Expand Down
6 changes: 0 additions & 6 deletions dlt/destinations/impl/sqlalchemy/db_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,6 @@ def close_connection(self) -> None:
self._current_connection = None
self._current_transaction = None

def create_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

def drop_volume(self) -> None:
    """No-op: this destination manages no staging volume (interface hook only)."""

@property
def native_connection(self) -> Connection:
if not self._current_connection:
Expand Down
1 change: 0 additions & 1 deletion dlt/destinations/job_client_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
self.sql_client.create_dataset()
elif truncate_tables:
self.sql_client.truncate_tables(*truncate_tables)
self.sql_client.create_volume()

def is_storage_initialized(self) -> bool:
    """Report whether the destination dataset already exists."""
    client = self.sql_client
    return client.has_dataset()
Expand Down
8 changes: 0 additions & 8 deletions dlt/destinations/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,6 @@ def close_connection(self) -> None:
def begin_transaction(self) -> ContextManager[DBTransaction]:
pass

@abstractmethod
def create_volume(self) -> None:
    """Create the destination's staging volume, if the destination uses one."""

@abstractmethod
def drop_volume(self) -> None:
    """Drop the destination's staging volume, if the destination uses one."""

def __getattr__(self, name: str) -> Any:
# pass unresolved attrs to native connections
if not self.native_connection:
Expand Down
1 change: 0 additions & 1 deletion dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,6 @@ def load(
runner.run_pool(load_step.config, load_step)
info: LoadInfo = self._get_step_info(load_step)

self.sql_client().drop_volume()
self.first_run = False
return info
except Exception as l_ex:
Expand Down
8 changes: 5 additions & 3 deletions tests/load/pipeline/test_databricks_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration)
assert len(rows) == 3


# TODO: test config staging_volume_name on_resolved
# TODO: modify the DestinationTestConfiguration
# TODO: add test databricks credentials default auth error
# TODO: test on notebook
# TODO: check that volume doesn't block schema drop
@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=("databricks",)),
Expand All @@ -229,9 +234,6 @@ def test_databricks_direct_load(destination_config: DestinationTestConfiguration
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = ""
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = ""

# direct_load
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__DIRECT_LOAD"] = "True"

bricks = databricks()
config = bricks.configuration(None, accept_partial=True)
assert config.credentials.access_token
Expand Down

0 comments on commit 7641bcf

Please sign in to comment.