From a7a42d153ebba4de3fd478a4257890c8c861b1ad Mon Sep 17 00:00:00 2001
From: Tibor Casteleijn
Date: Tue, 14 Jan 2025 21:58:26 +0100
Subject: [PATCH 1/2] refactor to functional pattern

---
 README.md                                 |  23 +-
 src/knmi_dataset_downloader/__init__.py   |  12 +-
 src/knmi_dataset_downloader/cli.py        |  18 +-
 src/knmi_dataset_downloader/dataset.py    | 318 ++++++++++++++++++++++
 src/knmi_dataset_downloader/downloader.py | 304 ---------------------
 tests/integration/test_cli.py             |  60 +----
 tests/integration/test_dataset.py         |  79 ++++++
 tests/integration/test_downloader.py      | 123 ---------
 8 files changed, 443 insertions(+), 494 deletions(-)
 create mode 100644 src/knmi_dataset_downloader/dataset.py
 delete mode 100644 src/knmi_dataset_downloader/downloader.py
 create mode 100644 tests/integration/test_dataset.py
 delete mode 100644 tests/integration/test_downloader.py

diff --git a/README.md b/README.md
index d942c0f..f11fba4 100644
--- a/README.md
+++ b/README.md
@@ -70,26 +70,27 @@ Options:
 You can also use the package in your Python code:
 
 ```python
-from knmi_dataset_downloader import Downloader
+from knmi_dataset_downloader import dataset
 import asyncio
 from datetime import datetime
 
 async def main():
-    # Initialize the downloader with your own API key
-    downloader = Downloader(
-        dataset_name="Actuele10mindataKNMIstations",
-        version="2",
-        max_concurrent=10,
-        api_key="YOUR_API_KEY",  # Optional - will use anonymous API key if not provided
-        output_dir="path/to/output"  # Optional - will use default if not provided
-    )
-    # Download files for a specific date range
-    await downloader.download(
+    stats = await dataset.download(
+        api_key="YOUR_API_KEY",  # Optional - will use anonymous API key if not provided
+        dataset_name="Actuele10mindataKNMIstations",  # Optional - uses default if not provided
+        version="2",  # Optional - uses default if not provided
+        max_concurrent=10,  # Optional - uses default if not provided
+        output_dir="path/to/output",  # Optional - uses default if not provided
         start_date=datetime(2024, 1, 1),
         end_date=datetime(2024, 1, 31),
         limit=5  # Optional - limit the number of files to download
     )
+
+    # Access download statistics
+    print(f"Total files found: {stats.total_files}")
+    print(f"Files downloaded: {stats.downloaded_files}")
+    print(f"Files skipped: {stats.skipped_files}")
 
 # Run the download
 if __name__ == "__main__":
diff --git a/src/knmi_dataset_downloader/__init__.py b/src/knmi_dataset_downloader/__init__.py
index 7fb1737..947005f 100644
--- a/src/knmi_dataset_downloader/__init__.py
+++ b/src/knmi_dataset_downloader/__init__.py
@@ -1 +1,11 @@
-from .downloader import Downloader
\ No newline at end of file
+from .dataset import download, DownloadStats
+from .defaults import DEFAULT_DATASET_NAME, DEFAULT_DATASET_VERSION, DEFAULT_MAX_CONCURRENT, DEFAULT_OUTPUT_DIR
+
+__all__ = [
+    'download',
+    'DownloadStats',
+    'DEFAULT_DATASET_NAME',
+    'DEFAULT_DATASET_VERSION',
+    'DEFAULT_MAX_CONCURRENT',
+    'DEFAULT_OUTPUT_DIR'
+]
\ No newline at end of file
diff --git a/src/knmi_dataset_downloader/cli.py b/src/knmi_dataset_downloader/cli.py
index ca65ed2..db81992 100644
--- a/src/knmi_dataset_downloader/cli.py
+++ b/src/knmi_dataset_downloader/cli.py
@@ -4,7 +4,7 @@ import asyncio
 from datetime import datetime
 from pathlib import Path
 
-from .downloader import Downloader
+from . import dataset
 from .defaults import (
     DEFAULT_OUTPUT_DIR,
     DEFAULT_DATASET_NAME,
@@ -53,7 +53,7 @@ async def async_main() -> None:
         '--start-date', '-s',
         default=default_start.isoformat(),
         help='Start date in ISO 8601 format example: 2024-01-01T00:00:00 or 2024-01-01, '
-             'default is 30 minutes ago'
+             'default is 1 hour and 30 minutes ago'
     )
     parser.add_argument(
         '--end-date', '-e',
@@ -94,17 +94,17 @@ async def async_main() -> None:
             print("Please provide an API key using the --api-key argument")
             return
 
-    # Initialize downloader
-    downloader = Downloader(
+    # Download files
+    await dataset.download(
+        api_key=api_key,
         dataset_name=args.dataset,
         version=args.version,
         max_concurrent=args.concurrent,
-        api_key=api_key,
-        output_dir=args.output_dir
+        output_dir=args.output_dir,
+        start_date=start,
+        end_date=end,
+        limit=args.limit
     )
-
-    # Download files
-    await downloader.download(start_date=start, end_date=end, limit=args.limit)
 
 def main() -> None:
     """Synchronous wrapper for async_main."""
diff --git a/src/knmi_dataset_downloader/dataset.py b/src/knmi_dataset_downloader/dataset.py
new file mode 100644
index 0000000..f29ed47
--- /dev/null
+++ b/src/knmi_dataset_downloader/dataset.py
@@ -0,0 +1,318 @@
+from __future__ import annotations
+
+import asyncio
+from typing import List, NamedTuple
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+
+import aiofiles
+import httpx
+from tqdm.asyncio import tqdm
+from kiota_abstractions.authentication.api_key_authentication_provider import (
+    ApiKeyAuthenticationProvider,
+    KeyLocation,
+)
+from kiota_http.httpx_request_adapter import HttpxRequestAdapter
+from kiota_serialization_json.json_serialization_writer_factory import (
+    JsonSerializationWriterFactory,
+)
+from kiota_serialization_json.json_parse_node_factory import JsonParseNodeFactory
+
+from .knmi_dataset_api.api_client import ApiClient
+from .knmi_dataset_api.v1.datasets.item.versions.item.files.files_request_builder import (
+    FilesRequestBuilder,
+)
+from .knmi_dataset_api.v1.datasets.item.versions.item.files.get_order_by_query_parameter_type import (
+    GetOrderByQueryParameterType,
+)
+from .knmi_dataset_api.v1.datasets.item.versions.item.files.get_sorting_query_parameter_type import (
+    GetSortingQueryParameterType,
+)
+from .defaults import (
+    DEFAULT_OUTPUT_DIR,
+    DEFAULT_DATASET_NAME,
+    DEFAULT_DATASET_VERSION,
+    DEFAULT_MAX_CONCURRENT,
+    get_default_date_range,
+)
+from .api_key import get_anonymous_api_key
+
+import logging
+log = logging.getLogger(__name__)
+
+@dataclass
+class DownloadStats:
+    """Statistics for the download process."""
+    total_files: int = 0
+    skipped_files: int = 0
+    downloaded_files: int = 0
+    failed_files: List[str] = field(default_factory=list)
+    total_bytes_downloaded: int = 0
+
+class DownloadContext(NamedTuple):
+    """Context for download operations."""
+    client: ApiClient
+    http_client: httpx.AsyncClient
+    semaphore: asyncio.Semaphore
+    dataset_name: str
+    version: str
+    output_dir: Path
+    stats: DownloadStats
+
+def initialize_client(api_key: str) -> ApiClient:
+    """Initialize the KNMI API client with proper authentication and serialization.
+
+    Args:
+        api_key (str): The API key for authentication
+
+    Returns:
+        ApiClient: Configured API client
+    """
+    auth_provider = ApiKeyAuthenticationProvider(
+        api_key=api_key,
+        parameter_name="Authorization",
+        key_location=KeyLocation.Header,
+    )
+
+    request_adapter = HttpxRequestAdapter(
+        authentication_provider=auth_provider,
+        parse_node_factory=JsonParseNodeFactory(),
+        serialization_writer_factory=JsonSerializationWriterFactory(),
+        base_url="https://api.dataplatform.knmi.nl/open-data",
+    )
+
+    return ApiClient(request_adapter)
+
+def format_size(size_bytes: int) -> str:
+    """Format bytes into a human-readable string."""
+    for unit in ["B", "KB", "MB", "GB"]:
+        if size_bytes < 1024:
+            return f"{size_bytes:.1f} {unit}"
+        size_bytes /= 1024  # Keep the fractional part so e.g. 1.5 MB is not truncated to 1.0 MB
+    return f"{size_bytes:.1f} TB"
+
+async def get_files_list(
+    context: DownloadContext,
+    start_date: datetime | None = None,
+    end_date: datetime | None = None,
+    limit: int | None = None,
+) -> List:
+    """Get list of files for the specified date range.
+
+    Args:
+        context (DownloadContext): Download context containing client and configuration
+        start_date (datetime | None): Start date for the files. Defaults to 1 hour and 30 minutes ago.
+        end_date (datetime | None): End date for the files. Defaults to now.
+        limit (int | None): Maximum number of files to retrieve. Defaults to None.
+
+    Returns:
+        List[FileInfo]: List of file information objects from the KNMI API
+    """
+    # Use default date range if not specified
+    if start_date is None or end_date is None:
+        default_start, default_end = get_default_date_range()
+        start_date = start_date or default_start
+        end_date = end_date or default_end
+
+    if isinstance(start_date, datetime):
+        begin = start_date.strftime("%Y-%m-%dT%H:%M:%S+00:00")
+    else:
+        begin = None
+
+    if isinstance(end_date, datetime):
+        end = end_date.strftime("%Y-%m-%dT%H:%M:%S+00:00")
+    else:
+        end = None
+
+    config = FilesRequestBuilder.FilesRequestBuilderGetQueryParameters(
+        max_keys=limit,
+        order_by=GetOrderByQueryParameterType.LastModified,
+        sorting=GetSortingQueryParameterType.Desc,
+        begin=begin,
+        end=end,
+    )
+
+    request_configuration = FilesRequestBuilder.FilesRequestBuilderGetRequestConfiguration(
+        query_parameters=config
+    )
+
+    response = await (
+        context.client.v1.datasets.by_dataset_name(context.dataset_name)
+        .versions.by_version_id(context.version)
+        .files.get(request_configuration=request_configuration)
+    )
+
+    if response is None:
+        raise ValueError("No response from API")
+
+    all_files = response.files or []
+
+    # Handle pagination if there are more files
+    while response.is_truncated:
+        config.next_page_token = response.next_page_token
+        response = await (
+            context.client.v1.datasets.by_dataset_name(context.dataset_name)
+            .versions.by_version_id(context.version)
+            .files.get(request_configuration=request_configuration)
+        )
+        if response is None:
+            raise ValueError("No response from API")
+        all_files.extend(response.files or [])
+        if limit is not None and len(all_files) >= limit:
+            break
+
+    return all_files[:limit]
+
+async def download_file(context: DownloadContext, filename: str, main_progress: tqdm) -> None:
+    """Download a single file from the dataset.
+
+    Args:
+        context (DownloadContext): Download context containing clients and configuration
+        filename (str): Name of the file to download
+        main_progress (tqdm): Main progress bar for overall progress
+    """
+    async with context.semaphore:  # Limit concurrent downloads
+        output_path = context.output_dir / filename
+        output_path.parent.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
+
+        if output_path.exists():
+            context.stats.skipped_files += 1
+            main_progress.update(1)
+            return
+
+        try:
+            download_url = await (
+                context.client.v1.datasets.by_dataset_name(context.dataset_name)
+                .versions.by_version_id(context.version)
+                .files.by_filename(filename=filename)
+                .url.get()
+            )
+
+            if download_url is None or download_url.temporary_download_url is None:
+                raise ValueError("No download URL found")
+
+            # Get file size with a HEAD request
+            async with context.http_client.stream(
+                "HEAD", download_url.temporary_download_url
+            ) as response:
+                total_size = int(response.headers.get("content-length", 0))
+
+            # Create progress bar for this file
+            file_progress = tqdm(
+                total=total_size,
+                unit="iB",
+                unit_scale=True,
+                desc=f"Downloading {filename}",
+                leave=False,
+            )
+
+            # Stream the download with progress
+            async with context.http_client.stream(
+                "GET", download_url.temporary_download_url
+            ) as response:
+                response.raise_for_status()
+                async with aiofiles.open(output_path, mode="wb") as f:
+                    downloaded_size = 0
+                    async for chunk in response.aiter_bytes(chunk_size=8192):
+                        await f.write(chunk)
+                        chunk_size = len(chunk)
+                        downloaded_size += chunk_size
+                        file_progress.update(chunk_size)
+
+            file_progress.close()
+            main_progress.update(1)
+
+            context.stats.downloaded_files += 1
+            context.stats.total_bytes_downloaded += downloaded_size
+            log.debug(f"Successfully downloaded: {filename} ({downloaded_size / 1024 / 1024:.1f} MB)")
+
+        except Exception as e:
+            log.error(f"Error downloading {filename}: {str(e)}")
+            context.stats.failed_files.append(filename)
+            if output_path.exists():
+                output_path.unlink()  # Remove partially downloaded file
+            raise
+
+async def download(
+    api_key: str | None = None,
+    dataset_name: str = DEFAULT_DATASET_NAME,
+    version: str = DEFAULT_DATASET_VERSION,
+    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
+    output_dir: str | Path = DEFAULT_OUTPUT_DIR,
+    start_date: datetime | None = None,
+    end_date: datetime | None = None,
+    limit: int | None = None,
+) -> DownloadStats:
+    """Download dataset files for the specified date range.
+
+    Args:
+        api_key (str | None): KNMI API key. If None, an anonymous API key is used.
+        dataset_name (str): Name of the dataset.
+        version (str): Version of the dataset.
+        max_concurrent (int): Maximum number of concurrent downloads.
+        output_dir (str | Path): Output directory for downloaded files.
+        start_date (datetime | None): Start date for files to download. Defaults to 1 hour and 30 minutes ago.
+        end_date (datetime | None): End date for files to download. Defaults to now.
+        limit (int | None): Maximum number of files to download. If None, downloads all files.
+
+    Returns:
+        DownloadStats: Statistics about the download process
+    """
+    if not api_key:
+        api_key = await get_anonymous_api_key()
+
+    # Initialize clients and context
+    client = initialize_client(api_key)
+    http_client = httpx.AsyncClient()
+    stats = DownloadStats()
+
+    context = DownloadContext(
+        client=client,
+        http_client=http_client,
+        semaphore=asyncio.Semaphore(max_concurrent),
+        dataset_name=dataset_name,
+        version=version,
+        output_dir=Path(output_dir),
+        stats=stats
+    )
+
+    try:
+        files = await get_files_list(context, start_date, end_date, limit)
+
+        context.stats.total_files = len(files)
+        log.info(f"Found {len(files)} files in date range {start_date} to {end_date}")
+
+        # Main progress bar for overall progress
+        with tqdm(
+            total=len(files), desc="Overall Progress", unit="file"
+        ) as main_progress:
+            # Download files concurrently with semaphore limiting
+            tasks = [
+                download_file(context, file.filename, main_progress) for file in files
+            ]
+            await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Print summary
+        # fmt: off
+        log.info("\nDownload Summary:")
+        log.info(f"Total files found: {context.stats.total_files}")
+        log.info(f"Files already present: {context.stats.skipped_files}")
+        log.info(f"Files downloaded: {context.stats.downloaded_files}")
+        log.info(f"Failed downloads: {len(context.stats.failed_files)}")
+        log.info(f"Total data downloaded: {format_size(context.stats.total_bytes_downloaded)}")
+        # fmt: on
+
+        if context.stats.failed_files:
+            log.warning("\nFailed downloads:")
+            for filename in context.stats.failed_files:
+                log.warning(f"- {filename}")
+
+    except Exception as e:
+        log.error(f"Error during download process: {str(e)}")
+        raise
+
+    finally:
+        await http_client.aclose()  # Ensure HTTP client is properly closed
+
+    return context.stats
diff --git a/src/knmi_dataset_downloader/downloader.py b/src/knmi_dataset_downloader/downloader.py
deleted file mode 100644
index 7d7d522..0000000
--- a/src/knmi_dataset_downloader/downloader.py
+++ /dev/null
@@ -1,304 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-from typing import List
-from dataclasses import dataclass, field
-from datetime import datetime
-from pathlib import Path
-
-import aiofiles
-import httpx
-from tqdm.asyncio import tqdm
-from kiota_abstractions.authentication.api_key_authentication_provider import (
-    ApiKeyAuthenticationProvider,
-    KeyLocation,
-)
-from kiota_http.httpx_request_adapter import HttpxRequestAdapter
-from kiota_serialization_json.json_serialization_writer_factory import (
-    JsonSerializationWriterFactory,
-)
-from kiota_serialization_json.json_parse_node_factory import JsonParseNodeFactory
-
-from .knmi_dataset_api.api_client import ApiClient
-from .knmi_dataset_api.v1.datasets.item.versions.item.files.files_request_builder import (
-    FilesRequestBuilder,
-)
-from .knmi_dataset_api.v1.datasets.item.versions.item.files.get_order_by_query_parameter_type import (
-    GetOrderByQueryParameterType,
-)
-from .knmi_dataset_api.v1.datasets.item.versions.item.files.get_sorting_query_parameter_type import (
-    GetSortingQueryParameterType,
-)
-from .defaults import (
-    DEFAULT_OUTPUT_DIR,
-    DEFAULT_DATASET_NAME,
-    DEFAULT_DATASET_VERSION,
-    DEFAULT_MAX_CONCURRENT,
-    get_default_date_range,
-)
-from .api_key import get_anonymous_api_key
-
-import logging
-log = logging.getLogger(__name__)
-
-@dataclass
-class DownloadStats:
-    """Statistics for the download process."""
-
-    total_files: int = 0
-    skipped_files: int = 0
-    downloaded_files: int = 0
-    failed_files: List[str] = field(default_factory=list)
-    total_bytes_downloaded: int = 0
-
-
-class Downloader:
-    def __init__(
-        self,
-        api_key: str | None = None,
-        dataset_name: str = DEFAULT_DATASET_NAME,
-        version: str = DEFAULT_DATASET_VERSION,
-        max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-        output_dir: str | Path = DEFAULT_OUTPUT_DIR,
-    ):
-        """Initialize the KNMI Dataset client.
-
-        Args:
-            dataset_name (str, optional): Name of the dataset. Defaults to DEFAULT_DATASET_NAME.
-            version (str, optional): Version of the dataset. Defaults to DEFAULT_DATASET_VERSION.
-            max_concurrent (int, optional): Maximum number of concurrent downloads. Defaults to DEFAULT_MAX_CONCURRENT.
-            api_key (str): KNMI API key.
-            output_dir (str | Path, optional): Output directory for downloaded files. Defaults to DEFAULT_OUTPUT_DIR.
-        """
-        if not api_key:
-            raise ValueError("API key is required")
-
-        self.dataset_name = dataset_name
-        self.version = version
-        self.max_concurrent = max_concurrent
-        self.api_key = api_key or asyncio.run(get_anonymous_api_key())
-        self.output_dir = Path(output_dir)
-        self.semaphore = asyncio.Semaphore(max_concurrent)
-        self.client = self._initialize_client()
-        self.http_client = httpx.AsyncClient()
-        self.stats = DownloadStats()
-
-    def _initialize_client(self) -> ApiClient:
-        """Initialize the KNMI API client with proper authentication and serialization.
-
-        Returns:
-            WeatherDataClient: Configured API client
-        """
-        auth_provider = ApiKeyAuthenticationProvider(
-            api_key=self.api_key,
-            parameter_name="Authorization",
-            key_location=KeyLocation.Header,
-        )
-
-        request_adapter = HttpxRequestAdapter(
-            authentication_provider=auth_provider,
-            parse_node_factory=JsonParseNodeFactory(),
-            serialization_writer_factory=JsonSerializationWriterFactory(),
-            base_url="https://api.dataplatform.knmi.nl/open-data",
-        )
-
-        return ApiClient(request_adapter)
-
-    async def _get_files_list(
-        self,
-        start_date: datetime | None = None,
-        end_date: datetime | None = None,
-        max_keys: int = 1000,
-    ) -> List:
-        """Get list of files for the specified date range.
-
-        Args:
-            start_date (datetime): Start date for the files. Defaults to 1 day ago.
-            end_date (datetime): End date for the files. Defaults to now.
-            max_keys (int, optional): Maximum number of files to retrieve per page. Defaults to 1000.
-
-        Returns:
-            List: List of file information
-        """
-        # Use default date range if not specified
-        if start_date is None or end_date is None:
-            default_start, default_end = get_default_date_range()
-            start_date = start_date or default_start
-            end_date = end_date or default_end
-
-        if isinstance(start_date, datetime):
-            begin = start_date.strftime("%Y-%m-%dT%H:%M:%S+00:00")
-        else:
-            begin = None
-
-        if isinstance(end_date, datetime):
-            end = end_date.strftime("%Y-%m-%dT%H:%M:%S+00:00")
-        else:
-            end = None
-
-        config = FilesRequestBuilder.FilesRequestBuilderGetQueryParameters(
-            max_keys=max_keys,
-            order_by=GetOrderByQueryParameterType.LastModified,
-            sorting=GetSortingQueryParameterType.Desc,
-            begin=begin,
-            end=end,
-        )
-
-        request_configuration = (
-            FilesRequestBuilder.FilesRequestBuilderGetRequestConfiguration(
-                query_parameters=config
-            )
-        )
-
-        response = await (
-            self.client.v1.datasets.by_dataset_name(self.dataset_name)
-            .versions.by_version_id(self.version)
-            .files.get(request_configuration=request_configuration)
-        )
-
-        if response is None:
-            raise ValueError("No response from API")
-
-        all_files = response.files or []
-
-        # Handle pagination if there are more files
-        while response.is_truncated:
-            config.next_page_token = response.next_page_token
-            response = await (
-                self.client.v1.datasets.by_dataset_name(self.dataset_name)
-                .versions.by_version_id(self.version)
-                .files.get(request_configuration=request_configuration)
-            )
-            if response is None:
-                raise ValueError("No response from API")
-            all_files.extend(response.files or [])
-
-        return all_files
-
-    async def _download_file(self, filename: str, main_progress: tqdm) -> None:
-        """Download a single file from the dataset.
-
-        Args:
-            filename (str): Name of the file to download
-            main_progress (tqdm): Main progress bar for overall progress
-        """
-        async with self.semaphore:  # Limit concurrent downloads
-            output_path = self.output_dir / filename
-            output_path.parent.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
-
-            if output_path.exists():
-                self.stats.skipped_files += 1
-                main_progress.update(1)
-                return
-
-            try:
-                download_url = await (
-                    self.client.v1.datasets.by_dataset_name(self.dataset_name)
-                    .versions.by_version_id(self.version)
-                    .files.by_filename(filename=filename)
-                    .url.get()
-                )
-
-                if download_url is None or download_url.temporary_download_url is None:
-                    raise ValueError("No download URL found")
-
-                # Get file size with a HEAD request
-                async with self.http_client.stream(
-                    "HEAD", download_url.temporary_download_url
-                ) as response:
-                    total_size = int(response.headers.get("content-length", 0))
-
-                # Create progress bar for this file
-                file_progress = tqdm(
-                    total=total_size,
-                    unit="iB",
-                    unit_scale=True,
-                    desc=f"Downloading {filename}",
-                    leave=False,
-                )
-
-                # Stream the download with progress
-                async with self.http_client.stream(
-                    "GET", download_url.temporary_download_url
-                ) as response:
-                    response.raise_for_status()
-                    async with aiofiles.open(output_path, mode="wb") as f:
-                        downloaded_size = 0
-                        async for chunk in response.aiter_bytes(chunk_size=8192):
-                            await f.write(chunk)
-                            chunk_size = len(chunk)
-                            downloaded_size += chunk_size
-                            file_progress.update(chunk_size)
-
-                file_progress.close()
-                main_progress.update(1)
-
-                self.stats.downloaded_files += 1
-                self.stats.total_bytes_downloaded += downloaded_size
-                log.debug(f"Successfully downloaded: {filename} ({downloaded_size / 1024 / 1024:.1f} MB)")
-
-            except Exception as e:
-                log.error(f"Error downloading {filename}: {str(e)}")
-                self.stats.failed_files.append(filename)
-                if output_path.exists():
-                    output_path.unlink()  # Remove partially downloaded file
-                raise
-
-    def _format_size(self, size_bytes: int) -> str:
-        """Format bytes into human readable string."""
-        for unit in ["B", "KB", "MB", "GB"]:
-            if size_bytes < 1024:
-                return f"{size_bytes:.1f} {unit}"
-            size_bytes = int(size_bytes / 1024)
-        return f"{size_bytes:.1f} TB"
-
-    async def download(
-        self, start_date: datetime | None = None, end_date: datetime | None = None, limit: int | None = None
-    ) -> None:
-        """Download all dataset files for the specified date range.
-
-        Args:
-            start_date (datetime): Start date for the files to download. Defaults to 1 day ago.
-            end_date (datetime): End date for the files to download. Defaults to now.
-            limit (int, optional): Maximum number of files to download. If None, downloads all files.
-        """
-        try:
-            files = await self._get_files_list(start_date, end_date)
-            if limit is not None:
-                files = files[:limit]
-                log.debug(f"Limiting download to {limit} files")
-
-            self.stats.total_files = len(files)
-            log.info(f"Found {len(files)} files in date range {start_date} to {end_date}")
-
-            # Main progress bar for overall progress
-            with tqdm(
-                total=len(files), desc="Overall Progress", unit="file"
-            ) as main_progress:
-                # Download files concurrently with semaphore limiting
-                tasks = [
-                    self._download_file(file.filename, main_progress) for file in files
-                ]
-                await asyncio.gather(*tasks, return_exceptions=True)
-
-            # Print summary
-            # fmt: off
-            log.info("\nDownload Summary:")
-            log.info(f"Total files found: {self.stats.total_files}")
-            log.info(f"Files already present: {self.stats.skipped_files}")
-            log.info(f"Files downloaded: {self.stats.downloaded_files}")
-            log.info(f"Failed downloads: {len(self.stats.failed_files)}")
-            log.info(f"Total data downloaded: {self._format_size(self.stats.total_bytes_downloaded)}")
-            # fmt: on
-
-            if self.stats.failed_files:
-                log.warning("\nFailed downloads:")
-                for filename in self.stats.failed_files:
-                    log.warning(f"- {filename}")
-
-        except Exception as e:
-            log.error(f"Error during download process: {str(e)}")
-            raise
-
-        finally:
-            await self.http_client.aclose()  # Ensure HTTP client is properly closed
diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py
index a4a0646..afd1205 100644
--- a/tests/integration/test_cli.py
+++ b/tests/integration/test_cli.py
@@ -1,60 +1,23 @@
 import unittest
-import asyncio
-from unittest.mock import patch
-from datetime import datetime
-import argparse
+import tempfile
 from pathlib import Path
-from src.knmi_dataset_downloader.cli import async_main, parse_date
+import shutil
+from unittest.mock import patch
+
+from src.knmi_dataset_downloader.cli import async_main
 from src.knmi_dataset_downloader.api_key import get_anonymous_api_key
-from src.knmi_dataset_downloader.defaults import (
-    DEFAULT_DATASET_NAME,
-    DEFAULT_DATASET_VERSION,
-)
 
 class TestCLI(unittest.IsolatedAsyncioTestCase):
-    def setUp(self):
+    async def asyncSetUp(self):
         """Set up test fixtures."""
-        # Anonymous key from https://developer.dataplatform.knmi.nl/open-data-api#token
-        self.api_key = asyncio.run(get_anonymous_api_key())
-        # Create test downloads directory
-        self.test_output_dir = Path('test_downloads')
-        self.test_output_dir.mkdir(exist_ok=True)
-
-    def tearDown(self):
-        """Clean up test files."""
-        # Remove all .nc files in test directory
-        for nc_file in self.test_output_dir.glob('*.nc'):
-            nc_file.unlink()
-        # Try to remove directory (will only work if empty)
-        try:
-            self.test_output_dir.rmdir()
-        except OSError:
-            pass
-
-    async def test_parse_date(self):
-        """Test date parsing functionality."""
-        # Test valid dates in different ISO 8601 formats
-        self.assertEqual(parse_date("2023-01-01T00:00:00"), datetime(2023, 1, 1, 0, 0, 0))
-        self.assertEqual(parse_date("2023-01-01T12:30:45"), datetime(2023, 1, 1, 12, 30, 45))
-        self.assertEqual(parse_date("2023-01-01"), datetime(2023, 1, 1))
-
-        # Test empty date
-        self.assertIsNone(parse_date(""))
-
-        # Test invalid date formats
-        with self.assertRaises(argparse.ArgumentTypeError):
-            parse_date("2023/01/01")
-        with self.assertRaises(argparse.ArgumentTypeError):
-            parse_date("invalid")
-        with self.assertRaises(argparse.ArgumentTypeError):
-            parse_date("2023-13-01")  # Invalid month
+        self.api_key = await get_anonymous_api_key()
+        # Create a temporary directory for test outputs
+        self.test_output_dir = Path(tempfile.mkdtemp())
 
     async def test_cli_with_real_api(self):
         """Test CLI with real API using anonymous key."""
         test_args = [
             '--api-key', self.api_key,
-            '--dataset', DEFAULT_DATASET_NAME,
-            '--version', DEFAULT_DATASET_VERSION,
             '--start-date', '2024-01-01T00:00:00',
             '--end-date', '2024-01-01T00:20:00',  # Just 20 minutes of data
             '--concurrent', '2',
@@ -88,5 +51,10 @@ async def test_cli_with_defaults(self):
         self.assertGreater(len(nc_files), 0,
             "No .nc files were downloaded with default arguments")
 
+    async def asyncTearDown(self):
+        """Clean up after tests."""
+        # Clean up the temp directory
+        shutil.rmtree(self.test_output_dir, ignore_errors=True)
+
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file
diff --git a/tests/integration/test_dataset.py b/tests/integration/test_dataset.py
new file mode 100644
index 0000000..781088f
--- /dev/null
+++ b/tests/integration/test_dataset.py
@@ -0,0 +1,79 @@
+import unittest
+import os
+from datetime import datetime
+import tempfile
+from pathlib import Path
+import shutil
+
+from src.knmi_dataset_downloader import download, DownloadStats
+from src.knmi_dataset_downloader.api_key import get_anonymous_api_key
+
+class TestDownloader(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        """Set up test fixtures."""
+        # Anonymous key from https://developer.dataplatform.knmi.nl/open-data-api#token
+        self.api_key = await get_anonymous_api_key()
+        # Create a temporary directory for test outputs
+        self.temp_dir = Path(tempfile.mkdtemp())
+
+    def test_download_stats(self):
+        """Test DownloadStats initialization and updates."""
+        stats = DownloadStats()
+        self.assertEqual(stats.total_files, 0)
+        self.assertEqual(stats.skipped_files, 0)
+        self.assertEqual(stats.downloaded_files, 0)
+        self.assertEqual(stats.failed_files, [])
+        self.assertEqual(stats.total_bytes_downloaded, 0)
+
+    async def test_download(self):
+        """Test file download functionality."""
+        # Test with a small date range and limit to 1 file
+        start_date = datetime(2024, 1, 1, 0, 0, 0)
+        end_date = datetime(2024, 1, 1, 0, 30, 0)  # Just 30 minutes
+
+        try:
+            stats = await download(
+                api_key=self.api_key,
+                output_dir=self.temp_dir,
+                start_date=start_date,
+                end_date=end_date,
+                limit=1
+            )
+            # Check if stats were updated
+            self.assertEqual(stats.total_files, 1, "Should only download 1 file")
+            self.assertLessEqual(stats.downloaded_files + stats.skipped_files, 1)
+            self.assertEqual(len(stats.failed_files), 0, "No files should fail")
+        except Exception as e:
+            self.fail(f"Download failed with error: {str(e)}")
+
+    async def test_download_with_limit(self):
+        """Test download limit functionality."""
+        start_date = datetime(2024, 1, 1, 0, 0, 0)
+        end_date = datetime(2024, 1, 1, 23, 59, 59)  # Full day
+
+        # Test with different limits
+        for limit in [1, 2]:
+            # The shared temp dir is reused; files from an earlier iteration count as skipped
+            stats = await download(
+                api_key=self.api_key,
+                output_dir=self.temp_dir,
+                start_date=start_date,
+                end_date=end_date,
+                limit=limit
+            )
+            self.assertEqual(stats.total_files, limit,
+                f"Should limit to {limit} files")
+            self.assertLessEqual(
+                stats.downloaded_files + stats.skipped_files,
+                limit,
+                f"Total processed files should not exceed limit of {limit}"
+            )
+
+    async def asyncTearDown(self):
+        """Clean up after tests."""
+        # Clean up the temp directory
+        if hasattr(self, 'temp_dir'):
+            shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/tests/integration/test_downloader.py b/tests/integration/test_downloader.py
deleted file mode 100644
index 863a95e..0000000
--- a/tests/integration/test_downloader.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import unittest
-from datetime import datetime
-import tempfile
-from pathlib import Path
-import shutil
-
-from src.knmi_dataset_downloader import Downloader
-from src.knmi_dataset_downloader.downloader import DownloadStats
-from src.knmi_dataset_downloader.api_key import get_anonymous_api_key
-
-class TestDownloader(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        """Set up test fixtures."""
-        self.api_key = await get_anonymous_api_key()
-        # Create a temporary directory for test outputs
-        self.temp_dir = Path(tempfile.mkdtemp())
-        self.dataset = Downloader(
-            dataset_name="Actuele10mindataKNMIstations",
-            version="2",
-            max_concurrent=10,
-            api_key=self.api_key,
-            output_dir=self.temp_dir
-        )
-
-    def test_init(self):
-        """Test initialization of Downloader."""
-        self.assertEqual(self.dataset.dataset_name, "Actuele10mindataKNMIstations")
-        self.assertEqual(self.dataset.version, "2")
-        self.assertEqual(self.dataset.max_concurrent, 10)
-        self.assertEqual(self.dataset.api_key, self.api_key)
-        self.assertEqual(self.dataset.output_dir, self.temp_dir)
-        self.assertIsNotNone(self.dataset.semaphore)
-        self.assertIsNotNone(self.dataset.client)
-        self.assertIsNotNone(self.dataset.http_client)
-        self.assertIsInstance(self.dataset.stats, DownloadStats)
-
-    def test_format_size(self):
-        """Test size formatting."""
-        test_cases = [
-            (500, "500.0 B"),
-            (1024, "1.0 KB"),
-            (1024 * 1024, "1.0 MB"),
-            (1024 * 1024 * 1024, "1.0 GB"),
-            (1024 * 1024 * 1024 * 1024, "1.0 TB"),
-        ]
-        for size, expected in test_cases:
-            self.assertEqual(self.dataset._format_size(size), expected)
-
-    def test_download_stats(self):
-        """Test DownloadStats initialization and updates."""
-        stats = DownloadStats()
-        self.assertEqual(stats.total_files, 0)
-        self.assertEqual(stats.skipped_files, 0)
-        self.assertEqual(stats.downloaded_files, 0)
-        self.assertEqual(stats.failed_files, [])
-        self.assertEqual(stats.total_bytes_downloaded, 0)
-
-    async def test_get_files_list(self):
-        """Test getting list of files."""
-        start_date = datetime(2024, 1, 1)
-        end_date = datetime(2024, 1, 31)
-        files = await self.dataset._get_files_list(start_date, end_date)
-        self.assertIsInstance(files, list)
-        # Note: We can't assert exact file count as it depends on the API response
-
-    async def test_download(self):
-        """Test file download functionality."""
-        # Test with a small date range and limit to 1 file
-        start_date = datetime(2024, 1, 1, 0, 0, 0)
-        end_date = datetime(2024, 1, 1, 0, 30, 0)  # Just 30 minutes
-
-        try:
-            await self.dataset.download(start_date, end_date, limit=1)
-            # Check if stats were updated
-            self.assertEqual(self.dataset.stats.total_files, 1, "Should only download 1 file")
-            self.assertLessEqual(self.dataset.stats.downloaded_files + self.dataset.stats.skipped_files, 1)
-            self.assertEqual(len(self.dataset.stats.failed_files), 0, "No files should fail")
-        except Exception as e:
-            self.fail(f"Download failed with error: {str(e)}")
-
-    async def test_download_with_limit(self):
-        """Test download limit functionality."""
-        start_date = datetime(2024, 1, 1, 0, 0, 0)
-        end_date = datetime(2024, 1, 1, 23, 59, 59)  # Full day
-
-        # Test with different limits
-        for limit in [1, 2]:
-            with self.subTest(limit=limit):
-                # Create a separate temp dir for each test to avoid conflicts
-                temp_dir = Path(tempfile.mkdtemp())
-                dataset = Downloader(
-                    dataset_name="Actuele10mindataKNMIstations",
-                    version="2",
-                    max_concurrent=10,
-                    api_key=self.api_key,
-                    output_dir=temp_dir
-                )
-                try:
-                    await dataset.download(start_date, end_date, limit=limit)
-                    self.assertEqual(dataset.stats.total_files, limit,
-                        f"Should limit to {limit} files")
-                    self.assertLessEqual(
-                        dataset.stats.downloaded_files + dataset.stats.skipped_files,
-                        limit,
-                        f"Total processed files should not exceed limit of {limit}"
-                    )
-                except Exception as e:
-                    self.fail(f"Download with limit {limit} failed with error: {str(e)}")
-                finally:
-                    await dataset.http_client.aclose()
-                    # Clean up the temp directory
-                    shutil.rmtree(temp_dir, ignore_errors=True)
-
-    async def asyncTearDown(self):
-        """Clean up after tests."""
-        if hasattr(self, 'dataset') and self.dataset.http_client:
-            await self.dataset.http_client.aclose()
-        # Clean up the temp directory
-        if hasattr(self, 'temp_dir'):
-            shutil.rmtree(self.temp_dir, ignore_errors=True)
-
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file

From 3d284adb857986f0ab15650f4ede2a3873687223 Mon Sep 17 00:00:00 2001
From: Tibor Casteleijn
Date: Tue, 14 Jan 2025 21:59:13 +0100
Subject: [PATCH 2/2] bump version to 1.8.0

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7eda549..ada7342 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "knmi-dataset-downloader"
-version = "1.7.0"
+version = "1.8.0"
 description = "A downloader for KNMI weather datasets"
 readme = "README.md"
 authors = [