Skip to content

Commit

Permalink
#11 refactors code
Browse files Browse the repository at this point in the history
  • Loading branch information
namiyousef committed Nov 30, 2023
1 parent 7571fbd commit 48af4a0
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions in_n_out_clients/postgres_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,26 +244,31 @@ def _get_pg_datatypes(df):
return {"status_code": 200, "msg": "successfully wrote data"}


def insert_with_conflict_resolution(
table, conn, keys, data_iter, on_data_conflict, data_conflict_properties
def _generate_default_cols_when_partial_data(
conn, table_name: str, columns_in_data: List[str]
):
"""Internal function to add default values to non-nullable columns without
defaults from the target table in case where write data has partial
columns.
:param conn: SQLAlchemy connection
:param table_name: name of the target table
:param columns_in_data: these are the columns present in the `write`
data, the function will ignore these columns
:return: a list containing the metadata (to be directly used by
`sqlalchemy.Column`) of all columns from the target table that
by definition have no defaults and are not nullable with data-
type inferred defaults (excludes columns present in data)
"""
from sqlalchemy import Column, inspect
from sqlalchemy.dialects.postgresql import insert

data = [dict(zip(keys, row, strict=True)) for row in data_iter]

sqlalchemy_table = table.table
table_name = sqlalchemy_table.name

insert_statement = insert(sqlalchemy_table).values(data)

all_columns = inspect(conn).get_columns(table_name)
non_nullable_cols_with_no_default = []
for column_metadata in all_columns:
column_name = column_metadata["name"]
is_nullable = column_metadata["nullable"]
has_default = column_metadata["default"]
if column_name in keys:
if column_name in columns_in_data:
continue
if not is_nullable and has_default is None:
logger.debug(
Expand Down Expand Up @@ -293,7 +298,28 @@ def insert_with_conflict_resolution(

sqlalchemy_column = Column(**column_metadata)
non_nullable_cols_with_no_default.append(sqlalchemy_column)
return non_nullable_cols_with_no_default


def insert_with_conflict_resolution(
table, conn, keys, data_iter, on_data_conflict, data_conflict_properties
):
from sqlalchemy.dialects.postgresql import insert

data = [dict(zip(keys, row, strict=True)) for row in data_iter]

sqlalchemy_table = table.table
table_name = sqlalchemy_table.name

insert_statement = insert(sqlalchemy_table).values(data)

# -- find columns that need to be added to insert query in case of
# partial data
non_nullable_cols_with_no_default = (
_generate_default_cols_when_partial_data(
conn=conn, table_name=table_name, columns_in_data=keys
)
)
for sqlalchemy_column in non_nullable_cols_with_no_default:
sqlalchemy_table.append_column(sqlalchemy_column)

Expand Down

0 comments on commit 48af4a0

Please sign in to comment.