diff --git a/ppa/archive/management/commands/hathi_images.py b/ppa/archive/management/commands/hathi_images.py index 990dbea0..322df71d 100644 --- a/ppa/archive/management/commands/hathi_images.py +++ b/ppa/archive/management/commands/hathi_images.py @@ -108,6 +108,11 @@ def add_arguments(self, parser): nargs="+", help="Optional list of HathiTrust ids (by default, downloads images for all public HathiTrust volumes)", ) + parser.add_argument( + "--collection", + type=str, + help="Filter volumes by provided PPA collection", + ) parser.add_argument( "--image-width", type=int, @@ -150,7 +155,7 @@ def download_image(self, page_url: str, out_file: Path) -> bool: Attempts to download and save an image from the specified URL. Returns a boolean corresponding to whether the download was successful """ - response = requests.get(page_url) + response = self.session.get(page_url) success = False if response.status_code == requests.codes.ok: with out_file.open(mode="wb") as writer: @@ -223,13 +228,18 @@ def handle(self, *args, **kwargs): raise CommandError("Thumbnail width cannot be more than 250 pixels") # use ids specified via command line when present - htids = kwargs.get("htids", []) + htids = kwargs["htids"] # by default, download images for all non-suppressed hathi source ids digworks = DigitizedWork.objects.filter( status=DigitizedWork.PUBLIC, source=DigitizedWork.HATHI ) + # if collection is specified via parameter, use it to filter the queryset + collection = kwargs.get("collection") + if collection: + digworks = digworks.filter(collections__name=collection) + # if htids are specified via parameter, use them to filter # the queryset, to ensure we only sync records that are # in the database and not suppressed @@ -250,6 +260,9 @@ def handle(self, *args, **kwargs): f"Downloading images for {n_vols} record{pluralize(digworks)}", ) + # Create requests session + self.session = requests.Session() + # Initialize progress bar if self.show_progress: self.progress_bar = tqdm() diff 
--git a/ppa/archive/tests/test_hathi_images.py b/ppa/archive/tests/test_hathi_images.py index 3e79cfe8..0ba2af77 100644 --- a/ppa/archive/tests/test_hathi_images.py +++ b/ppa/archive/tests/test_hathi_images.py @@ -62,13 +62,19 @@ def test_log_action(self): def test_log_download(self, mock_log_action): stats = hathi_images.DownloadStats() stats.log_download("image_type") - mock_log_action.called_once_with("image_type", "fetch") + mock_log_action.assert_called_once_with("image_type", "fetch") @patch.object(hathi_images.DownloadStats, "_log_action") def test_log_skip(self, mock_log_action): stats = hathi_images.DownloadStats() - stats.log_download("image_type") - mock_log_action.called_once_with("image_type", "skip") + stats.log_skip("image_type") + mock_log_action.assert_called_once_with("image_type", "skip") + + @patch.object(hathi_images.DownloadStats, "_log_action") + def test_log_error(self, mock_log_action): + stats = hathi_images.DownloadStats() + stats.log_error("image_type") + mock_log_action.assert_called_once_with("image_type", "error") def test_update(self): stats_a = hathi_images.DownloadStats() @@ -120,21 +126,21 @@ def test_interrupt_handler(self, mock_signal): "Ctrl-C / Interrupt to quit immediately\n" ) - @patch("requests.get") - def test_download_image(self, mock_get, tmp_path): + def test_download_image(self, tmp_path): cmd = hathi_images.Command() + cmd.session = Mock() # Not ok status - mock_get.return_value = Mock(status_code=503) + cmd.session.get.return_value = Mock(status_code=503) result = cmd.download_image("page_url", "out_file") - assert mock_get.called_once_with("page_url") + cmd.session.get.assert_called_once_with("page_url") assert result is False # Ok status out_file = tmp_path / "test.jpg" - mock_get.reset_mock() - mock_get.return_value = Mock(status_code=200, content=b"image content") + cmd.session.reset_mock() + cmd.session.get.return_value = Mock(status_code=200, content=b"image content") result = cmd.download_image("page_url", 
out_file) - assert mock_get.called_once_with("page_url") + cmd.session.get.assert_called_once_with("page_url") assert result is True assert out_file.read_text() == "image content"