Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set credentials key as variables #1682

Merged
merged 12 commits into from
Nov 27, 2023
Merged
19 changes: 14 additions & 5 deletions sdv/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import logging
import os
import urllib.request
from collections import defaultdict
from pathlib import Path
from zipfile import ZipFile
Expand All @@ -14,13 +13,15 @@
import pandas as pd
from botocore import UNSIGNED
from botocore.client import Config
from botocore.exceptions import ClientError

from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Name of the S3 bucket the demo datasets are read from.
BUCKET = 'sdv-demo-datasets'
# HTTPS address of that bucket; used to build the dataset URL shown in log messages.
BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com'
# Signature version passed to the botocore ``Config`` when an S3 client is created.
SIGNATURE_VERSION = UNSIGNED
# NOTE(review): presumably the name of the metadata file shipped with each
# dataset -- its usage is not visible in this section; confirm.
METADATA_FILENAME = 'metadata.json'


Expand All @@ -38,19 +39,27 @@ def _validate_output_folder(output_folder_name):
)


def _get_data_from_bucket(object_key):
    """Read the full contents of an object stored in the demo bucket.

    Args:
        object_key (str):
            Key of the object inside ``BUCKET``.

    Returns:
        The value returned by ``read()`` on the ``Body`` of the S3 response.
    """
    client_config = Config(signature_version=SIGNATURE_VERSION)
    s3_client = boto3.Session().client('s3', config=client_config)
    s3_object = s3_client.get_object(Bucket=BUCKET, Key=object_key)
    body = s3_object['Body']
    return body.read()


def _download(modality, dataset_name):
    """Download the zip archive of a demo dataset into memory.

    Args:
        modality (str):
            Modality of the dataset. It is upper-cased to form the key prefix.
        dataset_name (str):
            Name of the dataset. ``.zip`` is appended to form the object name.

    Returns:
        io.BytesIO:
            In-memory buffer wrapping the downloaded archive contents.

    Raises:
        ValueError:
            If fetching the object from the bucket raises ``ClientError``.
    """
    object_key = f'{modality.upper()}/{dataset_name}.zip'
    # The URL is only used for the log message; the download itself goes
    # through the S3 client using ``object_key``.
    dataset_url = f'{BUCKET_URL}/{object_key}'
    LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}')
    try:
        file_content = _get_data_from_bucket(object_key)
    except ClientError as error:
        # Chain explicitly so the underlying S3 error stays attached as the cause.
        raise ValueError(
            f"Invalid dataset name '{dataset_name}'. "
            'Make sure you have the correct modality for the dataset name or '
            "use 'get_available_demos' to get a list of demo datasets."
        ) from error

    return io.BytesIO(file_content)


def _extract_data(bytes_io, output_folder_name):
Expand Down Expand Up @@ -162,7 +171,7 @@ def get_available_demos(modality):
* If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
"""
_validate_modalities(modality)
client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
client = boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
tables_info = defaultdict(list)
for item in client.list_objects(Bucket=BUCKET)['Contents']:
dataset_modality, dataset = item['Key'].split('/', 1)
Expand Down
35 changes: 34 additions & 1 deletion tests/unit/datasets/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
import pytest

from sdv.datasets.demo import download_demo, get_available_demos
from sdv.datasets.demo import _download, _get_data_from_bucket, download_demo, get_available_demos


def test_download_demo_invalid_modality():
Expand Down Expand Up @@ -64,6 +64,39 @@ def test_download_demo_single_table(tmpdir):
assert metadata.to_dict() == expected_metadata_dict


@patch('boto3.Session')
@patch('sdv.datasets.demo.BUCKET', 'bucket')
def test__get_data_from_bucket(session_mock):
    """Check that ``_get_data_from_bucket`` returns the bytes read from the S3 object body."""
    # Setup
    body_mock = Mock()
    body_mock.read.return_value = b'data'
    s3_client_mock = Mock()
    s3_client_mock.get_object.return_value = {'Body': body_mock}
    session_mock.return_value.client.return_value = s3_client_mock

    # Run
    data = _get_data_from_bucket('object_key')

    # Assert
    assert data == b'data'
    session_mock.assert_called_once()
    s3_client_mock.get_object.assert_called_once_with(Bucket='bucket', Key='object_key')


@patch('sdv.datasets.demo._get_data_from_bucket')
def test__download(mock_get_data_from_bucket):
    """Check that ``_download`` requests the object key built from modality and dataset name."""
    # Setup
    expected_key = 'SINGLE_TABLE/ring.zip'
    mock_get_data_from_bucket.return_value = b''

    # Run
    _download('single_table', 'ring')

    # Assert
    mock_get_data_from_bucket.assert_called_once_with(expected_key)


def test_download_demo_single_table_no_output_folder():
"""Test it can download a single table dataset when no output folder is passed."""
# Run
Expand Down
Loading