diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py
index f92639863..c202d91c9 100644
--- a/sdv/datasets/demo.py
+++ b/sdv/datasets/demo.py
@@ -4,7 +4,6 @@
 import json
 import logging
 import os
-import urllib.request
 from collections import defaultdict
 from pathlib import Path
 from zipfile import ZipFile
@@ -14,6 +13,7 @@
 import pandas as pd
 from botocore import UNSIGNED
 from botocore.client import Config
+from botocore.exceptions import ClientError
 
 from sdv.metadata.multi_table import MultiTableMetadata
 from sdv.metadata.single_table import SingleTableMetadata
@@ -21,6 +21,7 @@
 LOGGER = logging.getLogger(__name__)
 BUCKET = 'sdv-demo-datasets'
 BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com'
+SIGNATURE_VERSION = UNSIGNED
 METADATA_FILENAME = 'metadata.json'
 
 
@@ -38,19 +39,27 @@ def _validate_output_folder(output_folder_name):
         )
 
 
+def _get_data_from_bucket(object_key):
+    session = boto3.Session()
+    s3 = session.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
+    response = s3.get_object(Bucket=BUCKET, Key=object_key)
+    return response['Body'].read()
+
+
 def _download(modality, dataset_name):
     dataset_url = f'{BUCKET_URL}/{modality.upper()}/{dataset_name}.zip'
+    object_key = f'{modality.upper()}/{dataset_name}.zip'
     LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}')
     try:
-        response = urllib.request.urlopen(dataset_url)
-    except urllib.error.HTTPError:
+        file_content = _get_data_from_bucket(object_key)
+    except ClientError:
         raise ValueError(
             f"Invalid dataset name '{dataset_name}'. "
             'Make sure you have the correct modality for the dataset name or '
             "use 'get_available_demos' to get a list of demo datasets."
         )
 
-    return io.BytesIO(response.read())
+    return io.BytesIO(file_content)
 
 
 def _extract_data(bytes_io, output_folder_name):
@@ -162,7 +171,7 @@ def get_available_demos(modality):
             * If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
     """
     _validate_modalities(modality)
-    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
+    client = boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
     tables_info = defaultdict(list)
     for item in client.list_objects(Bucket=BUCKET)['Contents']:
         dataset_modality, dataset = item['Key'].split('/', 1)
diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py
index f14672ffc..fb0a5446f 100644
--- a/tests/unit/datasets/test_demo.py
+++ b/tests/unit/datasets/test_demo.py
@@ -5,7 +5,7 @@
 import pandas as pd
 import pytest
 
-from sdv.datasets.demo import download_demo, get_available_demos
+from sdv.datasets.demo import _download, _get_data_from_bucket, download_demo, get_available_demos
 
 
 def test_download_demo_invalid_modality():
@@ -64,6 +64,39 @@ def test_download_demo_single_table(tmpdir):
     assert metadata.to_dict() == expected_metadata_dict
 
 
+@patch('boto3.Session')
+@patch('sdv.datasets.demo.BUCKET', 'bucket')
+def test__get_data_from_bucket(session_mock):
+    """Test the ``_get_data_from_bucket`` method."""
+    # Setup
+    mock_s3_client = Mock()
+    session_mock.return_value.client.return_value = mock_s3_client
+    mock_s3_client.get_object.return_value = {'Body': Mock(read=lambda: b'data')}
+
+    # Run
+    result = _get_data_from_bucket('object_key')
+
+    # Assert
+    assert result == b'data'
+    session_mock.assert_called_once()
+    mock_s3_client.get_object.assert_called_once_with(
+        Bucket='bucket', Key='object_key'
+    )
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+def test__download(mock_get_data_from_bucket):
+    """Test the ``_download`` method."""
+    # Setup
+    mock_get_data_from_bucket.return_value = b''
+
+    # Run
+    _download('single_table', 'ring')
+
+    # Assert
+    mock_get_data_from_bucket.assert_called_once_with('SINGLE_TABLE/ring.zip')
+
+
 def test_download_demo_single_table_no_output_folder():
     """Test it can download a single table dataset when no output folder is passed."""
     # Run