Skip to content

Commit

Permalink
refac db-connection (#194)
Browse files (browse the repository at this point in the history)
  • Loading branch information
vitorbellini authored Jul 13, 2024
1 parent 543f271 commit 8015c4e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 50 deletions.
20 changes: 12 additions & 8 deletions fastetl/custom_functions/fast_etl.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def copy_db_to_db(
"""

# validate connections
source = SourceConnection(**source)
destination = DestinationConnection(**destination)
source = SourceConnection(source)
destination = DestinationConnection(destination)

# create table if not exists in destination db
if not source.query:
Expand Down Expand Up @@ -621,13 +621,17 @@ def _divide_chunks(l, n):
if copy_table_comments:
_copy_table_comments(
source=SourceConnection(
conn_id=source_conn_id,
schema=source_table_name.split(".")[0],
table=source_table_name.split(".")[1],
{
"conn_id": source_conn_id,
"schema": source_table_name.split(".")[0],
"table": source_table_name.split(".")[1],
}
),
destination=DestinationConnection(
conn_id=destination_conn_id,
schema=dest_table_name.split(".")[0],
table=dest_table_name.split(".")[1],
{
"conn_id": destination_conn_id,
"schema": dest_table_name.split(".")[0],
"table": dest_table_name.split(".")[1],
}
),
)
39 changes: 18 additions & 21 deletions fastetl/custom_functions/utils/db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,21 @@ class SourceConnection:
conn_type (str): Connection type/provider.
"""

def __init__(
self,
conn_id: str,
schema: str = None,
table: str = None,
query: str = None,
):
if not conn_id:
def __init__(self, params: dict):
self.conn_id = params.get("conn_id", None)
self.schema = params.get("schema", None)
self.table = params.get("table", None)
self.query = params.get("query", None)

if not self.conn_id:
raise ValueError("conn_id argument cannot be empty")
if not query and not (schema or table):
if not self.query and not (
self.schema or self.table
):
raise ValueError("must provide either schema and table or query")

self.conn_id = conn_id
self.schema = schema
self.table = table
self.query = query
self.conn_type = get_conn_type(conn_id)
conn_values = BaseHook.get_connection(conn_id)
self.conn_type = get_conn_type(self.conn_id)
conn_values = BaseHook.get_connection(self.conn_id)
self.conn_database = conn_values.schema


Expand All @@ -124,12 +121,12 @@ class DestinationConnection:
conn_type (str): Connection type/provider.
"""

def __init__(self, conn_id: str, schema: str, table: str):
self.conn_id = conn_id
self.schema = schema
self.table = table
self.conn_type = get_conn_type(conn_id)
conn_values = BaseHook.get_connection(conn_id)
def __init__(self, params: dict):
self.conn_id = params.get("conn_id", None)
self.schema = params.get("schema", None)
self.table = params.get("table", None)
self.conn_type = get_conn_type(self.conn_id)
conn_values = BaseHook.get_connection(self.conn_id)
self.conn_database = conn_values.schema


Expand Down
8 changes: 4 additions & 4 deletions fastetl/hooks/db_to_db_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def incremental_copy(
copy_table_comments: bool = False,
):
sync_db_2_db(
source_conn_id=self.source.conn_id,
destination_conn_id=self.destination.conn_id,
source_schema=self.source.schema,
source_conn_id=self.source["conn_id"],
destination_conn_id=self.destination["conn_id"],
source_schema=self.source["schema"],
source_exc_schema=self.source.get("source_exc_schema", None),
source_exc_table=self.source.get("source_exc_table", None),
source_exc_column=self.source.get("source_exc_column", None),
select_sql=self.source.get("query", None),
destination_schema=self.destination.schema,
destination_schema=self.destination["schema"],
increment_schema=self.destination.get("increment_schema", None),
table=table,
date_column=date_column,
Expand Down
26 changes: 9 additions & 17 deletions fastetl/operators/db_to_db_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
class DbToDbOperator(BaseOperator):
template_fields = ["source"]

@apply_defaults
def __init__(
self,
source: Dict[str, str],
Expand Down Expand Up @@ -127,6 +126,15 @@ def __init__(
self.key_column = key_column
self.since_datetime = since_datetime
self.sync_exclusions = sync_exclusions

# rename if schema_name is present
if source.get("schema_name", None):
source["schema"] = source.pop("schema_name")
if destination.get("schema_name", None):
destination["schema"] = destination.pop("schema_name")
self.source = source
self.destination = destination

# any value that needs to be the same for inlets and outlets
key = str(random.randint(10000000, 99999999))
if source.get("om_service", None):
Expand All @@ -135,22 +143,6 @@ def __init__(
self.outlets = [
OMEntity(entity=Table, fqn=self._get_fqn(destination), key=key)
]
# rename if schema_name is present
if source.get("schema_name", None):
source["schema"] = source.pop("schema_name")
if destination.get("schema_name", None):
destination["schema"] = destination.pop("schema_name")
# filter to keys accepted by DbToDbHook
self.source = {
key: source[key]
for key in ["conn_id", "schema", "table", "query"]
if key in source
}
self.destination = {
key: destination[key]
for key in ["conn_id", "schema", "table"]
if key in destination
}

def _get_fqn(self, data):
data["database"] = BaseHook.get_connection(data["conn_id"]).schema
Expand Down

0 comments on commit 8015c4e

Please sign in to comment.