
Commit

Merge pull request #9 from namiyousef/develop
release v0.4.0
namiyousef authored Oct 31, 2023
2 parents 07c9c5f + a141e48 commit 42f2e78
Showing 3 changed files with 74 additions and 26 deletions.
2 changes: 1 addition & 1 deletion in_n_out_clients/__init__.py
@@ -1,3 +1,3 @@
__version__ = "0.3.0"
__version__ = "0.4.0"

from in_n_out_clients.main import InNOutClient
96 changes: 72 additions & 24 deletions in_n_out_clients/postgres_client.py
@@ -1,3 +1,4 @@
import datetime
import logging
from typing import List

@@ -19,6 +20,15 @@ class OnDataConflictFail(Exception):


class PostgresClient:
"""Client for interfacing with postgres databases.
:param username: database username
:param password: database password
:param host: database host
:param port: database port
:param database_name: database name
"""

def __init__(
self,
username: str,
@@ -38,7 +48,7 @@ def __init__(
f":{self.db_port}/{self.db_name}"
)
try:
self.engine, self.con = self.initialise_client()
self.engine = self.initialise_client()
except db.exc.OperationalError as operational_error:
raise ConnectionError(
"Could not connect to postgres client. "
@@ -47,26 +57,62 @@ def __init__(

def initialise_client(self):
self.engine = db.create_engine(self.db_uri)
self.con = self.engine.connect()
return self.engine, self.con
return self.engine

def query(self, query):
query_result = self.con.execute(query)
def query(self, query: str) -> pd.DataFrame:
"""Run a query against the databae.
:param query: query to run
:returns: dataframe of the query result (tested only for delect queries)
"""
with self.engine.connect() as con:
query_result = con.execute(db.text(query))
data = query_result.fetchall()
columns = query_result.keys()
df = pd.DataFrame(data, columns=columns)
columns = list(query_result.keys())

# TODO pandas hotfix: pandas does not infer datetime.date values as datetimes
dtype_mapping = None
for record in data:
indices_of_date_items = [
i
for i, item in enumerate(record)
if isinstance(item, datetime.date)
]
dtype_mapping = {
columns[i]: "datetime64[ns]" for i in indices_of_date_items
}
break

df = pd.DataFrame.from_records(data, columns=columns)

if dtype_mapping:
df = df.astype(dtype_mapping)
return df
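
For context, a minimal usage sketch of the client and the reworked query method above. The constructor keywords follow the class docstring and all connection values are placeholders, not taken from this repository; each query now borrows a short-lived connection from the engine instead of reusing a long-lived self.con.

from in_n_out_clients.postgres_client import PostgresClient

# placeholder credentials for illustration only
client = PostgresClient(
    username="postgres",
    password="postgres",
    host="localhost",
    port=5432,  # assumed to accept an int; adjust if the client expects a string
    database_name="analytics",
)

# returns a DataFrame; date columns are coerced to datetime64[ns] by the hotfix above
df = client.query("select * from currency_history")
print(df.dtypes)
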

def _write(
self,
table_name: str,
data,
data: pd.DataFrame,
on_data_conflict: str = "append",
on_asset_conflict: str = "append",
dataset_name: str | None = None,
data_conflict_properties: List[str] | None = None
# data_conflict_properties,
data_conflict_properties: List[str] | None = None,
):
"""Internal function that is used by `InNOutClient` as a universal
write entry.
:param table_name: name of the table to write to
:param data: dataframe to write
:param on_data_conflict: how to behave if some of the rows to
write already exist, defaults to "append"
:param on_asset_conflict: how to behave if the table already exists,
defaults to "append"
:param dataset_name: name of the dataset (postgres schema) that
table belongs to, defaults to None
:param data_conflict_properties: columns to check for conflicts.
Note: these must match existing constraints on the
table, defaults to None
"""
resp = self.write(
df=data,
table_name=table_name,
@@ -125,7 +171,7 @@ def _get_pg_datatypes(df):
try:
df.to_sql(
table_name,
self.con,
self.engine,
schema=dataset_name,
if_exists=on_asset_conflict,
index=False,
@@ -185,16 +231,6 @@ def _get_pg_datatypes(df):
"No conflicts found... proceeding with normal write process"
)"""

df.to_sql(
table_name,
self.con,
schema=dataset_name,
if_exists=on_asset_conflict,
index=False,
method="multi",
dtype=dtypes,
)

return {"status_code": 200, "msg": "successfully wrote data"}


@@ -231,6 +267,10 @@ def insert_with_conflict_resolution(
if on_data_conflict == ConflictResolutionStrategy.FAIL:
if num_results != len(data):
conn.rollback()
# TODO: maybe do an on_conflict_do_update query instead,
# then return the excluded values? E.g. because at the
# moment the returned data is a misnomer!
raise OnDataConflictFail(
{
"msg": f"Found {len(data) - num_results} conflicting rows",
@@ -263,9 +303,15 @@ def postgres_fail():

df = pd.DataFrame(
{
"currency": ["EUR", "GBP"],
"date": ["2025-01-01", "2024-01-01"],
"value_in_pounds": [-1, 0.9],
"currency": ["EUR", "GBP", "AED", "EUR", "AED"],
"date": [
"2025-01-01",
"2024-01-01",
"2026-01-01",
"2017-01-01",
"2012-01-01",
],
"value_in_pounds": [-1.5, 0.9, 1, 20, -10],
}
).astype({"date": "datetime64[ns]"})
print(
@@ -278,6 +324,8 @@
["currency", "date"],
)
)

client.query("select * from currency_history").info()
raise Exception()
df.to_sql(
"currency_history",
2 changes: 1 addition & 1 deletion requirements/core.txt
@@ -4,4 +4,4 @@ google-auth-oauthlib
psycopg2-binary
PyDrive2
pandas
SQLAlchemy==1.4.17
SQLAlchemy
