From 9cbb8e76d5109a0193a39d4114908a2096b841fc Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 17 Nov 2023 10:45:01 -0600 Subject: [PATCH 01/12] first change --- sdv/datasets/demo.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index f92639863..954f93d43 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -21,6 +21,8 @@ LOGGER = logging.getLogger(__name__) BUCKET = 'sdv-demo-datasets' BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com' +ACCESS_KEY = None +SECRET_ACCESS_KEY = None METADATA_FILENAME = 'metadata.json' @@ -162,7 +164,12 @@ def get_available_demos(modality): * If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``. """ _validate_modalities(modality) - client = boto3.client('s3', config=Config(signature_version=UNSIGNED)) + client = boto3.client( + 's3', + aws_access_key_id=ACCESS_KEY, + aws_secret_access_key=SECRET_ACCESS_KEY, + config=Config(signature_version=UNSIGNED) + ) tables_info = defaultdict(list) for item in client.list_objects(Bucket=BUCKET)['Contents']: dataset_modality, dataset = item['Key'].split('/', 1) From c55d26f7dc2519c4c15d206717128ddefdd93782 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 17 Nov 2023 11:31:42 -0600 Subject: [PATCH 02/12] update config --- sdv/datasets/demo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 954f93d43..f10d49efe 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -23,6 +23,7 @@ BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com' ACCESS_KEY = None SECRET_ACCESS_KEY = None +CONFIG = Config(signature_version=UNSIGNED) METADATA_FILENAME = 'metadata.json' @@ -168,7 +169,7 @@ def get_available_demos(modality): 's3', aws_access_key_id=ACCESS_KEY, aws_secret_access_key=SECRET_ACCESS_KEY, - config=Config(signature_version=UNSIGNED) + config=CONFIG ) tables_info = defaultdict(list) for item in client.list_objects(Bucket=BUCKET)['Contents']: From 40645d1090f7966ae40d352d70d1f74a6681d0d0 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 17 Nov 2023 15:35:42 -0600 Subject: [PATCH 03/12] add new methods --- sdv/datasets/demo.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index f10d49efe..9621bb60c 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -19,11 +19,13 @@ from sdv.metadata.single_table import SingleTableMetadata LOGGER = logging.getLogger(__name__) +IS_PRIVATE_BUCKET = False BUCKET = 'sdv-demo-datasets' BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com' ACCESS_KEY = None SECRET_ACCESS_KEY = None CONFIG = Config(signature_version=UNSIGNED) +REGION_NAME = None METADATA_FILENAME = 'metadata.json' @@ -41,11 +43,37 @@ def _validate_output_folder(output_folder_name): ) +def _get_data_from_private_bucket(object_key): + session = boto3.Session( + aws_access_key_id=ACCESS_KEY, + aws_secret_access_key=SECRET_ACCESS_KEY, + region_name=REGION_NAME, + ) + s3 = session.client('s3') + response = s3.get_object(Bucket=BUCKET, Key=object_key) + file_content = response['Body'].read() + + return file_content + + +def _get_data_from_public_bucket(object_key): + public_url = f'https://{BUCKET}.s3.amazonaws.com/{object_key}' + response = urllib.request.urlopen(public_url) + file_content = response.read() + + return file_content + + def _download(modality, dataset_name): dataset_url = f'{BUCKET_URL}/{modality.upper()}/{dataset_name}.zip' + object_key = f'{modality.upper()}/{dataset_name}.zip' LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}') try: - response = urllib.request.urlopen(dataset_url) + if IS_PRIVATE_BUCKET: + file_content = _get_data_from_private_bucket(object_key) + else: + file_content = _get_data_from_public_bucket(object_key) + except urllib.error.HTTPError: raise ValueError( f"Invalid dataset name '{dataset_name}'. " @@ -53,7 +81,7 @@ def _download(modality, dataset_name): "use 'get_available_demos' to get a list of demo datasets." ) - return io.BytesIO(response.read()) + return io.BytesIO(file_content) def _extract_data(bytes_io, output_folder_name): From 6fa88cd7abed30622866eaa47a833427aee8e18e Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 17 Nov 2023 15:35:51 -0600 Subject: [PATCH 04/12] tests --- tests/unit/datasets/test_demo.py | 72 +++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py index f14672ffc..d8a3625d7 100644 --- a/tests/unit/datasets/test_demo.py +++ b/tests/unit/datasets/test_demo.py @@ -5,7 +5,9 @@ import pandas as pd import pytest -from sdv.datasets.demo import download_demo, get_available_demos +from sdv.datasets.demo import ( + _download, _get_data_from_private_bucket, _get_data_from_public_bucket, download_demo, + get_available_demos) def test_download_demo_invalid_modality(): @@ -64,6 +66,74 @@ def test_download_demo_single_table(tmpdir): assert metadata.to_dict() == expected_metadata_dict +@patch('boto3.Session') +@patch('sdv.datasets.demo.ACCESS_KEY', 'access_key') +@patch('sdv.datasets.demo.SECRET_ACCESS_KEY', 'secret_access_key') +@patch('sdv.datasets.demo.REGION_NAME', 'region_name') +@patch('sdv.datasets.demo.BUCKET', 'bucket') +def test__get_data_from_private_bucket(session_mock): + """Test the ``_get_data_from_private_bucket`` method.""" + # Setup + session_mock.return_value.client.return_value.get_object.return_value = { + 'Body': MagicMock(read=MagicMock(return_value=b'')) + } + + # Run + _get_data_from_private_bucket('object_key') + + # Assert + session_mock.assert_called_once_with( + aws_access_key_id='access_key', aws_secret_access_key='secret_access_key', + region_name='region_name' + ) + session_mock.return_value.client.assert_called_once_with('s3') + session_mock.return_value.client.return_value.get_object.assert_called_once_with( + Bucket='bucket', Key='object_key' + ) + + +@patch('urllib.request.urlopen') +@patch('sdv.datasets.demo.BUCKET', 'bucket') +def test__get_data_from_public_bucket(url_open_mock): + """Test the ``_get_data_from_public_bucket`` method.""" + # Setup + url_open_mock.return_value.read.return_value = b'' + + # Run + _get_data_from_public_bucket('object_key') + + # Assert + url_open_mock.assert_called_once_with('https://bucket.s3.amazonaws.com/object_key') + url_open_mock.return_value.read.assert_called_once_with() + + +@patch('sdv.datasets.demo._get_data_from_public_bucket') +def test__download_public_bucket(mock_get_data_from_public_bucket): + """Test the ``_download`` method when the bucket is public.""" + # Setup + mock_get_data_from_public_bucket.return_value = b'' + + # Run + _download('single_table', 'ring') + + # Assert + mock_get_data_from_public_bucket.assert_called_once_with('SINGLE_TABLE/ring.zip') + + +@patch('sdv.datasets.demo.IS_PRIVATE_BUCKET', True) +@patch('sdv.datasets.demo._get_data_from_private_bucket') +def test__download_private_bucket(mock_get_data_from_private_bucket): + """Test the ``_download`` method when the bucket is private.""" + # Setup + mock_get_data_from_private_bucket.return_value = b'' + + # Run + _download('single_table', 'ring') + + # Assert + mock_get_data_from_private_bucket.assert_called_once_with('SINGLE_TABLE/ring.zip') + + def test_download_demo_single_table_no_output_folder(): """Test it can download a single table dataset when no output folder is passed.""" # Run From f62448a44a96669c1ad57dc7ced9d3d293155fed Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 20 Nov 2023 17:00:21 -0600 Subject: [PATCH 05/12] update def --- sdv/datasets/demo.py | 49 +++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 9621bb60c..1456f0e98 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -4,7 +4,6 @@ import json import logging import os -import urllib.request from collections import defaultdict from pathlib import Path from zipfile import ZipFile @@ -22,10 +21,6 @@ IS_PRIVATE_BUCKET = False BUCKET = 'sdv-demo-datasets' BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com' -ACCESS_KEY = None -SECRET_ACCESS_KEY = None -CONFIG = Config(signature_version=UNSIGNED) -REGION_NAME = None METADATA_FILENAME = 'metadata.json' @@ -43,25 +38,20 @@ def _validate_output_folder(output_folder_name): ) -def _get_data_from_private_bucket(object_key): +def _get_data_from_bucket(object_key): + access_key = os.environ.get('AWS_ACCESS_KEY_ID') + secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') + region = os.environ.get('AWS_REGION') session = boto3.Session( - aws_access_key_id=ACCESS_KEY, - aws_secret_access_key=SECRET_ACCESS_KEY, - region_name=REGION_NAME, + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=region, ) - s3 = session.client('s3') - response = s3.get_object(Bucket=BUCKET, Key=object_key) - file_content = response['Body'].read() - - return file_content - - -def _get_data_from_public_bucket(object_key): - public_url = f'https://{BUCKET}.s3.amazonaws.com/{object_key}' - response = urllib.request.urlopen(public_url) - file_content = response.read() + signature_version = 's3v4' if access_key else UNSIGNED + s3 = session.client('s3', config=Config(signature_version=signature_version)) - return file_content + response = s3.get_object(Bucket=BUCKET, Key=object_key) + return response['Body'].read() def _download(modality, dataset_name): @@ -69,12 +59,8 @@ def _download(modality, dataset_name): object_key = f'{modality.upper()}/{dataset_name}.zip' LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}') try: - if IS_PRIVATE_BUCKET: - file_content = _get_data_from_private_bucket(object_key) - else: - file_content = _get_data_from_public_bucket(object_key) - - except urllib.error.HTTPError: + file_content = _get_data_from_bucket(object_key) + except Exception: raise ValueError( f"Invalid dataset name '{dataset_name}'. " 'Make sure you have the correct modality for the dataset name or ' @@ -193,11 +179,14 @@ def get_available_demos(modality): * If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``. """ _validate_modalities(modality) + access_key = os.environ.get('AWS_ACCESS_KEY_ID') + secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') + signature_version = 's3v4' if access_key else UNSIGNED client = boto3.client( 's3', - aws_access_key_id=ACCESS_KEY, - aws_secret_access_key=SECRET_ACCESS_KEY, - config=CONFIG + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + config=Config(signature_version=signature_version), ) tables_info = defaultdict(list) for item in client.list_objects(Bucket=BUCKET)['Contents']: From 06a4b26c34278455cd961a4e8e95d79f9156367f Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 20 Nov 2023 17:00:33 -0600 Subject: [PATCH 06/12] update test --- tests/unit/datasets/test_demo.py | 69 ++++++++++---------------------- 1 file changed, 21 insertions(+), 48 deletions(-) diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py index d8a3625d7..2d48171c3 100644 --- a/tests/unit/datasets/test_demo.py +++ b/tests/unit/datasets/test_demo.py @@ -1,3 +1,4 @@ +import os import re from unittest.mock import MagicMock, Mock, patch @@ -5,9 +6,7 @@ import pandas as pd import pytest -from sdv.datasets.demo import ( - _download, _get_data_from_private_bucket, _get_data_from_public_bucket, download_demo, - get_available_demos) +from sdv.datasets.demo import _download, _get_data_from_bucket, download_demo, get_available_demos def test_download_demo_invalid_modality(): @@ -67,71 +66,45 @@ def test_download_demo_single_table(tmpdir): @patch('boto3.Session') -@patch('sdv.datasets.demo.ACCESS_KEY', 'access_key') -@patch('sdv.datasets.demo.SECRET_ACCESS_KEY', 'secret_access_key') -@patch('sdv.datasets.demo.REGION_NAME', 'region_name') +@patch.dict(os.environ, { + 'AWS_ACCESS_KEY_ID': 'access_key', + 'AWS_SECRET_ACCESS_KEY': 'secret_access_key', + 'AWS_REGION': 'region_name', +}) @patch('sdv.datasets.demo.BUCKET', 'bucket') -def test__get_data_from_private_bucket(session_mock): - """Test the ``_get_data_from_private_bucket`` method.""" +def test__get_data_from_bucket(session_mock): + """Test the ``_get_data_from_bucket`` method.""" # Setup - session_mock.return_value.client.return_value.get_object.return_value = { - 'Body': MagicMock(read=MagicMock(return_value=b'')) - } + mock_s3_client = Mock() + session_mock.return_value.client.return_value = mock_s3_client + mock_s3_client.get_object.return_value = {'Body': Mock(read=lambda: b'data')} # Run - _get_data_from_private_bucket('object_key') + result = _get_data_from_bucket('object_key') # Assert + assert result == b'data' session_mock.assert_called_once_with( - aws_access_key_id='access_key', aws_secret_access_key='secret_access_key', + aws_access_key_id='access_key', + aws_secret_access_key='secret_access_key', region_name='region_name' ) - session_mock.return_value.client.assert_called_once_with('s3') - session_mock.return_value.client.return_value.get_object.assert_called_once_with( + mock_s3_client.get_object.assert_called_once_with( Bucket='bucket', Key='object_key' ) -@patch('urllib.request.urlopen') -@patch('sdv.datasets.demo.BUCKET', 'bucket') -def test__get_data_from_public_bucket(url_open_mock): - """Test the ``_get_data_from_public_bucket`` method.""" - # Setup - url_open_mock.return_value.read.return_value = b'' - - # Run - _get_data_from_public_bucket('object_key') - - # Assert - url_open_mock.assert_called_once_with('https://bucket.s3.amazonaws.com/object_key') - url_open_mock.return_value.read.assert_called_once_with() - - -@patch('sdv.datasets.demo._get_data_from_public_bucket') -def test__download_public_bucket(mock_get_data_from_public_bucket): +@patch('sdv.datasets.demo._get_data_from_bucket') +def test__download(mock_get_data_from_bucket): """Test the ``_download`` method when the bucket is public.""" # Setup - mock_get_data_from_public_bucket.return_value = b'' - - # Run - _download('single_table', 'ring') - - # Assert - mock_get_data_from_public_bucket.assert_called_once_with('SINGLE_TABLE/ring.zip') - - -@patch('sdv.datasets.demo.IS_PRIVATE_BUCKET', True) -@patch('sdv.datasets.demo._get_data_from_private_bucket') -def test__download_private_bucket(mock_get_data_from_private_bucket): - """Test the ``_download`` method when the bucket is private.""" - # Setup - mock_get_data_from_private_bucket.return_value = b'' + mock_get_data_from_bucket.return_value = b'' # Run _download('single_table', 'ring') # Assert - mock_get_data_from_private_bucket.assert_called_once_with('SINGLE_TABLE/ring.zip') + mock_get_data_from_bucket.assert_called_once_with('SINGLE_TABLE/ring.zip') def test_download_demo_single_table_no_output_folder(): From 6c892271437b4dc586cd87c3fb24f118d5589319 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 20 Nov 2023 17:20:33 -0600 Subject: [PATCH 07/12] remove is_private variable --- sdv/datasets/demo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 1456f0e98..e13564437 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -18,7 +18,6 @@ from sdv.metadata.single_table import SingleTableMetadata LOGGER = logging.getLogger(__name__) -IS_PRIVATE_BUCKET = False BUCKET = 'sdv-demo-datasets' BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com' METADATA_FILENAME = 'metadata.json' From a6a8aa4dd31484d0f0988f1f34957bcc90306c13 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 20 Nov 2023 17:45:12 -0600 Subject: [PATCH 08/12] docstring --- tests/unit/datasets/test_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py index 2d48171c3..4bb30b9e1 100644 --- a/tests/unit/datasets/test_demo.py +++ b/tests/unit/datasets/test_demo.py @@ -96,7 +96,7 @@ def test__get_data_from_bucket(session_mock): @patch('sdv.datasets.demo._get_data_from_bucket') def test__download(mock_get_data_from_bucket): - """Test the ``_download`` method when the bucket is public.""" + """Test the ``_download`` method.""" # Setup mock_get_data_from_bucket.return_value = b'' From a29794473a2c3c1255a97c0cdcf4ba3cbeb42302 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 21 Nov 2023 10:18:51 -0600 Subject: [PATCH 09/12] signature version update --- sdv/datasets/demo.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index e13564437..0725b190d 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -20,6 +20,7 @@ LOGGER = logging.getLogger(__name__) BUCKET = 'sdv-demo-datasets' BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com' +SIGNATURE_VERSION = UNSIGNED METADATA_FILENAME = 'metadata.json' @@ -46,8 +47,7 @@ def _get_data_from_bucket(object_key): aws_secret_access_key=secret_key, region_name=region, ) - signature_version = 's3v4' if access_key else UNSIGNED - s3 = session.client('s3', config=Config(signature_version=signature_version)) + s3 = session.client('s3', config=Config(signature_version=SIGNATURE_VERSION)) response = s3.get_object(Bucket=BUCKET, Key=object_key) return response['Body'].read() @@ -180,12 +180,11 @@ def get_available_demos(modality): _validate_modalities(modality) access_key = os.environ.get('AWS_ACCESS_KEY_ID') secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') - signature_version = 's3v4' if access_key else UNSIGNED client = boto3.client( 's3', aws_access_key_id=access_key, aws_secret_access_key=secret_key, - config=Config(signature_version=signature_version), + config=Config(signature_version=SIGNATURE_VERSION), ) tables_info = defaultdict(list) for item in client.list_objects(Bucket=BUCKET)['Contents']: From 8cad6e2e016fc408fe008f1ce5062d8b317d702c Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 21 Nov 2023 10:21:25 -0600 Subject: [PATCH 10/12] AWS_REGION to AWS_DEFAULT_REGION --- sdv/datasets/demo.py | 2 +- tests/unit/datasets/test_demo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 0725b190d..bbb308cd4 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -41,7 +41,7 @@ def _validate_output_folder(output_folder_name): def _get_data_from_bucket(object_key): access_key = os.environ.get('AWS_ACCESS_KEY_ID') secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') - region = os.environ.get('AWS_REGION') + region = os.environ.get('AWS_DEFAULT_REGION') session = boto3.Session( aws_access_key_id=access_key, aws_secret_access_key=secret_key, diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py index 4bb30b9e1..67fc6ecbb 100644 --- a/tests/unit/datasets/test_demo.py +++ b/tests/unit/datasets/test_demo.py @@ -69,7 +69,7 @@ def test_download_demo_single_table(tmpdir): @patch.dict(os.environ, { 'AWS_ACCESS_KEY_ID': 'access_key', 'AWS_SECRET_ACCESS_KEY': 'secret_access_key', - 'AWS_REGION': 'region_name', + 'AWS_DEFAULT_REGION': 'region_name', }) @patch('sdv.datasets.demo.BUCKET', 'bucket') def test__get_data_from_bucket(session_mock): From 4615f7ca1a7703976f595851e4563703aaea7064 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 21 Nov 2023 15:53:37 -0600 Subject: [PATCH 11/12] update parameters setting --- sdv/datasets/demo.py | 19 ++----------------- tests/unit/datasets/test_demo.py | 6 +----- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index bbb308cd4..24b53f0df 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -39,16 +39,8 @@ def _validate_output_folder(output_folder_name): def _get_data_from_bucket(object_key): - access_key = os.environ.get('AWS_ACCESS_KEY_ID') - secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') - region = os.environ.get('AWS_DEFAULT_REGION') - session = boto3.Session( - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=region, - ) + session = boto3.Session() s3 = session.client('s3', config=Config(signature_version=SIGNATURE_VERSION)) - response = s3.get_object(Bucket=BUCKET, Key=object_key) return response['Body'].read() @@ -178,14 +170,7 @@ def get_available_demos(modality): * If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``. """ _validate_modalities(modality) - access_key = os.environ.get('AWS_ACCESS_KEY_ID') - secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') - client = boto3.client( - 's3', - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - config=Config(signature_version=SIGNATURE_VERSION), - ) + client = boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION)) tables_info = defaultdict(list) for item in client.list_objects(Bucket=BUCKET)['Contents']: dataset_modality, dataset = item['Key'].split('/', 1) diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py index 67fc6ecbb..dc1288f3d 100644 --- a/tests/unit/datasets/test_demo.py +++ b/tests/unit/datasets/test_demo.py @@ -84,11 +84,7 @@ def test__get_data_from_bucket(session_mock): # Assert assert result == b'data' - session_mock.assert_called_once_with( - aws_access_key_id='access_key', - aws_secret_access_key='secret_access_key', - region_name='region_name' - ) + session_mock.assert_called_once() mock_s3_client.get_object.assert_called_once_with( Bucket='bucket', Key='object_key' ) From 53a441e1090c8eb9e0280acf41a47ff4d4d1b64d Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 21 Nov 2023 16:57:24 -0600 Subject: [PATCH 12/12] address comments --- sdv/datasets/demo.py | 3 ++- tests/unit/datasets/test_demo.py | 6 ------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 24b53f0df..c202d91c9 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -13,6 +13,7 @@ import pandas as pd from botocore import UNSIGNED from botocore.client import Config +from botocore.exceptions import ClientError from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata @@ -51,7 +52,7 @@ def _download(modality, dataset_name): LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}') try: file_content = _get_data_from_bucket(object_key) - except Exception: + except ClientError: raise ValueError( f"Invalid dataset name '{dataset_name}'. " 'Make sure you have the correct modality for the dataset name or ' diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py index dc1288f3d..fb0a5446f 100644 --- a/tests/unit/datasets/test_demo.py +++ b/tests/unit/datasets/test_demo.py @@ -1,4 +1,3 @@ -import os import re from unittest.mock import MagicMock, Mock, patch @@ -66,11 +65,6 @@ def test_download_demo_single_table(tmpdir): @patch('boto3.Session') -@patch.dict(os.environ, { - 'AWS_ACCESS_KEY_ID': 'access_key', - 'AWS_SECRET_ACCESS_KEY': 'secret_access_key', - 'AWS_DEFAULT_REGION': 'region_name', -}) @patch('sdv.datasets.demo.BUCKET', 'bucket') def test__get_data_from_bucket(session_mock): """Test the ``_get_data_from_bucket`` method."""