Merge pull request #103 from moj-analytical-services/develop
Revert to AP working version
mratford authored May 5, 2023
2 parents ff805fd + 74efb2e commit 20698b8
Showing 8 changed files with 30 additions and 356 deletions.
8 changes: 0 additions & 8 deletions CHANGELOG.md
@@ -5,14 +5,6 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v5.5.5 - 2023-04-26

- Another user_id fix and get region from output bucket

## v5.5.4 - 2023-04-26

- Fix user_id parsing in light of SSO

## v5.5.3 - 2023-03-06

- Fixed issue in create_temp_table
4 changes: 0 additions & 4 deletions README.md
@@ -211,10 +211,6 @@ pydb.create_temp_table()
print(wr.catalog.delete_database(name=temp_db_name))
```
### Setting the region
In order to run queries, Athena needs to output its results into a staging bucket in S3. The aws region passed to awswrangler needs to be the same as the region of that bucket. This is usually the same as that set by the `AWS_DEFAULT_REGION` set within your underlying environment. However, in cases of cross-region working, you can specify the region for Athena to access by setting `AWS_ATHENA_QUERY_REGION` as an environment variable.
# DEPRECATED
## Functions
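The "Setting the region" section removed above describes a cross-region override. A minimal sketch of how it was meant to be used, assuming (as in the pre-revert `utils.py`) that `AWS_ATHENA_QUERY_REGION` is read when `pydbtools` is imported; the region value and query are illustrative only:

```python
import os

# Hypothetical setup: the Athena staging bucket lives in eu-west-2 while
# AWS_DEFAULT_REGION points at eu-west-1, so the override is needed.
os.environ["AWS_ATHENA_QUERY_REGION"] = "eu-west-2"

import pydbtools as pydb  # the override is picked up at import time

df = pydb.read_sql_query("SELECT 1 AS one")  # results staged in the eu-west-2 bucket
```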
280 changes: 19 additions & 261 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pydbtools/__init__.py
@@ -33,4 +33,4 @@

from .utils import s3_path_join # noqa: F401

__version__ = "5.5.5"
__version__ = "5.5.6"
41 changes: 8 additions & 33 deletions pydbtools/utils.py
@@ -6,38 +6,21 @@
import sql_metadata
import inspect
import boto3
from botocore.exceptions import NoCredentialsError
from botocore.credentials import (
InstanceMetadataProvider,
InstanceMetadataFetcher,
)
from functools import reduce
import awswrangler as wr
import warnings


# Set pydbtool params - if you were so inclined to change them
bucket = os.getenv("ATHENA_QUERY_DUMP_BUCKET", "mojap-athena-query-dump")
try:
bucket_region = wr.s3.get_bucket_region(bucket)
except NoCredentialsError:
bucket_region = "eu-west-1"
bucket = "mojap-athena-query-dump"
temp_database_name_prefix = "mojap_de_temp_"
aws_default_region = os.getenv(
"AWS_ATHENA_QUERY_REGION",
os.getenv("AWS_DEFAULT_REGION", os.getenv("AWS_REGION", "eu-west-1")),
"AWS_DEFAULT_REGION", os.getenv("AWS_REGION", "eu-west-1")
)

if aws_default_region != bucket_region:
warnings.warn(
f"""
Your aws region {aws_default_region} is different from the bucket where
the query results are saved: {bucket_region}. You can change this for this session
by setting pydb.utils.aws_default_region = "{bucket_region}".
You should also set the environment variable:
AWS_ATHENA_QUERY_REGION = "{bucket_region}" to ensure the correct region is set.
"""
)


def s3_path_join(base: str, *urls: [str]):
return reduce(_s3_path_join, urls, base)
@@ -155,14 +138,6 @@ def replace_temp_database_name_reference(sql: str, database_name: str) -> str:
return "".join(new_query).strip()


def clean_user_id(user_id: str) -> str:
username = user_id.split(":")[-1]
if "@" in username:
username = username.split("@")[0]
username = username.replace("-", "_")
return username


def get_user_id_and_table_dir(
boto3_session=None, force_ec2: bool = False, region_name: str = None
) -> Tuple[str, str]:
@@ -174,16 +149,16 @@ def get_user_id_and_table_dir(

sts_client = boto3_session.client("sts")
sts_resp = sts_client.get_caller_identity()
user_id = clean_user_id(sts_resp["UserId"])
out_path = s3_path_join("s3://" + bucket, user_id)
out_path = s3_path_join("s3://" + bucket, sts_resp["UserId"])
if out_path[-1] != "/":
out_path += "/"

return (user_id, out_path)
return (sts_resp["UserId"], out_path)


def get_database_name_from_userid(clean_user_id: str) -> str:
unique_db_name = temp_database_name_prefix + clean_user_id
def get_database_name_from_userid(user_id: str) -> str:
unique_db_name = user_id.split(":")[-1].split("-", 1)[-1].replace("-", "_")
unique_db_name = temp_database_name_prefix + unique_db_name
return unique_db_name


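For reference, a self-contained sketch of the temp-database naming restored in `utils.py` above; the STS `UserId` values are invented for illustration:

```python
temp_database_name_prefix = "mojap_de_temp_"


def get_database_name_from_userid(user_id: str) -> str:
    # Keep the part after the role id, drop anything up to and including
    # the first "-", and replace remaining "-" so it is a valid database name.
    unique_db_name = user_id.split(":")[-1].split("-", 1)[-1].replace("-", "_")
    return temp_database_name_prefix + unique_db_name


print(get_database_name_from_userid("AROAEXAMPLE:alpha_user_bob"))
# mojap_de_temp_alpha_user_bob
print(get_database_name_from_userid("AROAEXAMPLE:alpha-user-bob"))
# mojap_de_temp_user_bob  (text before the first "-" is dropped)
```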
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -1,7 +1,7 @@
[tool]
[tool.poetry]
name = "pydbtools"
version = "5.5.5"
version = "5.5.6"
description = "A python package to query data via amazon athena and bring it into a pandas df using aws-wrangler."
license = "MIT"
authors = ["Karik Isichei <[email protected]>"]
@@ -11,7 +11,7 @@ readme = "README.md"
python = ">=3.8,<3.11" # wrangler dependency
boto3 = ">=1.7.4"
sqlparse = "^0.4.4"
awswrangler = "^2.15.0"
awswrangler = "^2.12.0"
pyarrow = ">=5.0.0"
Jinja2 = ">=3.1.0"
sql-metadata = "^2.3.0"
@@ -20,7 +20,6 @@ arrow-pd-parser = "^1.3.7"
[tool.poetry.dev-dependencies]
pytest = ">=6.1"
toml = "^0.10"
moto = "^4.1.8"

[build-system]
requires = ["poetry>=0.12"]
32 changes: 0 additions & 32 deletions tests/conftest.py
@@ -1,7 +1,5 @@
import pytest
import os
from moto import mock_s3
import boto3


# Get all SQL files in a dict
@@ -13,33 +11,3 @@ def sql_dict():
with open(os.path.join("tests/data/", fn)) as f:
sql_dict[fn.split(".")[0]] = "".join(f.readlines())
return sql_dict


@pytest.fixture(scope="function")
def aws_credentials():
"""Mocked AWS Credentials for moto."""
mocked_envs = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SECURITY_TOKEN",
"AWS_SESSION_TOKEN",
]
for menv in mocked_envs:
os.environ[menv] = "testing"

yield # Allows us to close down envs on exit

for menv in mocked_envs:
del os.environ[menv]


@pytest.fixture(scope="function")
def s3(aws_credentials):
with mock_s3():
yield boto3.resource("s3", region_name="eu-west-1")


@pytest.fixture(scope="function")
def s3_client(aws_credentials):
with mock_s3():
yield boto3.client("s3", region_name="eu-west-1")
14 changes: 0 additions & 14 deletions tests/test_utils.py
@@ -1,5 +1,4 @@
import toml
import pytest


def test_set_module_values():
@@ -27,16 +26,3 @@ def test_pyproject_toml_matches_version():
with open("pyproject.toml") as f:
proj = toml.load(f)
assert pydb.__version__ == proj["tool"]["poetry"]["version"]


@pytest.mark.parametrize(
"test_input, expected",
[
("abcde:12345", "12345"),
("abcde:my-name", "my_name"),
("abcde:[email protected]", "my_name"),
],
)
def test_clean_user_id(test_input, expected):
import pydbtools as pydb
assert pydb.utils.clean_user_id(test_input) == expected
