diff --git a/Dockerfile b/Dockerfile index eddf4e13..4a5e73c0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ # syntax=docker/dockerfile:1.3 ARG BASE_IMG=nvcr.io/nvidia/cuda -ARG BASE_IMG_TAG=12.4.1-base-ubuntu22.04 +ARG BASE_IMG_TAG=12.5.1-base-ubuntu22.04 # Use NVIDIA Morpheus as the base image FROM $BASE_IMG:$BASE_IMG_TAG AS base @@ -21,13 +21,20 @@ LABEL git_commit=$GIT_COMMIT # Install necessary dependencies using apt-get RUN apt-get update && apt-get install -y \ - wget \ bzip2 \ ca-certificates \ curl \ libgl1-mesa-glx \ + software-properties-common \ + wget \ && apt-get clean +# A workaround for the error (mrc-core): /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.32' not found +# Issue: https://github.com/NVIDIA/nv-ingest/issues/474 +RUN add-apt-repository -y ppa:ubuntu-toolchain-r/test \ + && apt-get update \ + && apt-get install -y --only-upgrade libstdc++6 + RUN wget -O Miniforge3.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" -O /tmp/miniforge.sh \ && bash /tmp/miniforge.sh -b -p /opt/conda \ && rm /tmp/miniforge.sh diff --git a/src/nv_ingest/extraction_workflows/pdf/nemoretriever_parse_helper.py b/src/nv_ingest/extraction_workflows/pdf/nemoretriever_parse_helper.py index d4068bb4..1e607169 100644 --- a/src/nv_ingest/extraction_workflows/pdf/nemoretriever_parse_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/nemoretriever_parse_helper.py @@ -96,9 +96,6 @@ def nemoretriever_parse( """ logger.debug("Extracting PDF with nemoretriever_parse backend.") - nemoretriever_parse_config = kwargs.get("nemoretriever_parse_config", {}) - nemoretriever_parse_config = nemoretriever_parse_config if nemoretriever_parse_config is not None else {} - row_data = kwargs.get("row_data") # get source_id source_id = row_data["source_id"] @@ -111,9 +108,10 @@ def nemoretriever_parse( paddle_output_format = kwargs.get("paddle_output_format", "pseudo_markdown") paddle_output_format = 
TableFormatEnum[paddle_output_format.upper()] - pdfium_config = kwargs.get("pdfium_config", {}) - if isinstance(pdfium_config, dict): - pdfium_config = PDFiumConfigSchema(**pdfium_config) + if (extract_tables_method == "yolox") and (extract_tables or extract_charts): + pdfium_config = kwargs.get("pdfium_config", {}) + if isinstance(pdfium_config, dict): + pdfium_config = PDFiumConfigSchema(**pdfium_config) nemoretriever_parse_config = kwargs.get("nemoretriever_parse_config", {}) if isinstance(nemoretriever_parse_config, dict): nemoretriever_parse_config = NemoRetrieverParseConfigSchema(**nemoretriever_parse_config) diff --git a/src/nv_ingest/schemas/pdf_extractor_schema.py b/src/nv_ingest/schemas/pdf_extractor_schema.py index dceb9330..1e03a708 100644 --- a/src/nv_ingest/schemas/pdf_extractor_schema.py +++ b/src/nv_ingest/schemas/pdf_extractor_schema.py @@ -73,7 +73,7 @@ def validate_endpoints(cls, values): for model_name in ["yolox"]: endpoint_name = f"{model_name}_endpoints" - grpc_service, http_service = values.get(endpoint_name) + grpc_service, http_service = values.get(endpoint_name, ("", "")) grpc_service = _clean_service(grpc_service) http_service = _clean_service(http_service) @@ -156,7 +156,7 @@ def validate_endpoints(cls, values): for model_name in ["nemoretriever_parse"]: endpoint_name = f"{model_name}_endpoints" - grpc_service, http_service = values.get(endpoint_name) + grpc_service, http_service = values.get(endpoint_name, ("", "")) grpc_service = _clean_service(grpc_service) http_service = _clean_service(http_service) diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index 10771714..24217e6b 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -66,9 +66,6 @@ def _update_metadata( # Image is too small; mark as skipped. results[i] = (img, None) - if not valid_images: - return results - # Prepare data payloads for both clients. 
data_yolox = {"images": valid_arrays} data_paddle = {"base64_images": valid_images} diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index c914ab5a..f699a81d 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -54,7 +54,7 @@ def _update_metadata( logger.debug(f"Running table extraction using protocol {paddle_client.protocol}") # Initialize the results list in the same order as base64_images. - results: List[Optional[Tuple[str, Tuple[Any, Any, Any]]]] = ["", (None, None, None)] * len(base64_images) + results: List[Optional[Tuple[str, Tuple[Any, Any, Any]]]] = [("", None, None, None)] * len(base64_images) valid_images: List[str] = [] valid_indices: List[int] = [] @@ -70,7 +70,7 @@ def _update_metadata( valid_indices.append(i) else: # Image is too small; mark as skipped. - results[i] = ("", None, None, None) + results[i] = (img, None, None, None) if not valid_images: return results diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_nemoretriever_parse_helper.py b/tests/nv_ingest/extraction_workflows/pdf/test_nemoretriever_parse_helper.py index 7853d505..c59f1e0d 100644 --- a/tests/nv_ingest/extraction_workflows/pdf/test_nemoretriever_parse_helper.py +++ b/tests/nv_ingest/extraction_workflows/pdf/test_nemoretriever_parse_helper.py @@ -34,16 +34,25 @@ def sample_pdf_stream(): return pdf_stream +@pytest.fixture +def mock_parser_config(): + return { + "nemoretriever_parse_endpoints": ("parser:8001", "http://parser:8000"), + } + + @patch(f"{_MODULE_UNDER_TEST}.create_inference_client") -def test_nemoretriever_parse_text_extraction(mock_client, sample_pdf_stream, document_df): +def test_nemoretriever_parse_text_extraction(mock_client, sample_pdf_stream, document_df, mock_parser_config): mock_client_instance = MagicMock() mock_client.return_value = mock_client_instance mock_client_instance.infer.return_value = [ - { - "bbox": {"xmin": 
0.16633729456384325, "ymin": 0.0969, "xmax": 0.3097820480404551, "ymax": 0.1102}, - "text": "testing", - "type": "Text", - } + [ + { + "bbox": {"xmin": 0.16633729456384325, "ymin": 0.0969, "xmax": 0.3097820480404551, "ymax": 0.1102}, + "text": "testing", + "type": "Text", + } + ] ] result = nemoretriever_parse( @@ -51,9 +60,11 @@ def test_nemoretriever_parse_text_extraction(mock_client, sample_pdf_stream, doc extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], text_depth="page", - nemoretriever_parse_config=MagicMock(), + extract_tables_method="nemoretriever_parse", + nemoretriever_parse_config=mock_parser_config, ) assert len(result) == 1 @@ -63,15 +74,17 @@ def test_nemoretriever_parse_text_extraction(mock_client, sample_pdf_stream, doc @patch(f"{_MODULE_UNDER_TEST}.create_inference_client") -def test_nemoretriever_parse_table_extraction(mock_client, sample_pdf_stream, document_df): +def test_nemoretriever_parse_table_extraction(mock_client, sample_pdf_stream, document_df, mock_parser_config): mock_client_instance = MagicMock() mock_client.return_value = mock_client_instance mock_client_instance.infer.return_value = [ - { - "bbox": {"xmin": 1 / 1024, "ymin": 2 / 1280, "xmax": 101 / 1024, "ymax": 102 / 1280}, - "text": "table text", - "type": "Table", - } + [ + { + "bbox": {"xmin": 1 / 1024, "ymin": 2 / 1280, "xmax": 101 / 1024, "ymax": 102 / 1280}, + "text": "table text", + "type": "Table", + } + ] ] result = nemoretriever_parse( @@ -79,9 +92,11 @@ def test_nemoretriever_parse_table_extraction(mock_client, sample_pdf_stream, do extract_text=True, extract_images=False, extract_tables=True, + extract_charts=False, row_data=document_df.iloc[0], text_depth="page", - nemoretriever_parse_config=MagicMock(), + extract_tables_method="nemoretriever_parse", + nemoretriever_parse_config=mock_parser_config, ) assert len(result) == 2 @@ -93,15 +108,17 @@ def test_nemoretriever_parse_table_extraction(mock_client, 
sample_pdf_stream, do @patch(f"{_MODULE_UNDER_TEST}.create_inference_client") -def test_nemoretriever_parse_image_extraction(mock_client, sample_pdf_stream, document_df): +def test_nemoretriever_parse_image_extraction(mock_client, sample_pdf_stream, document_df, mock_parser_config): mock_client_instance = MagicMock() mock_client.return_value = mock_client_instance mock_client_instance.infer.return_value = [ - { - "bbox": {"xmin": 1 / 1024, "ymin": 2 / 1280, "xmax": 101 / 1024, "ymax": 102 / 1280}, - "text": "", - "type": "Picture", - } + [ + { + "bbox": {"xmin": 1 / 1024, "ymin": 2 / 1280, "xmax": 101 / 1024, "ymax": 102 / 1280}, + "text": "", + "type": "Picture", + } + ] ] result = nemoretriever_parse( @@ -109,9 +126,11 @@ def test_nemoretriever_parse_image_extraction(mock_client, sample_pdf_stream, do extract_text=True, extract_images=True, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], text_depth="page", - nemoretriever_parse_config=MagicMock(), + extract_tables_method="nemoretriever_parse", + nemoretriever_parse_config=mock_parser_config, ) assert len(result) == 2 @@ -123,20 +142,22 @@ def test_nemoretriever_parse_image_extraction(mock_client, sample_pdf_stream, do @patch(f"{_MODULE_UNDER_TEST}.create_inference_client") -def test_nemoretriever_parse_text_extraction_bboxes(mock_client, sample_pdf_stream, document_df): +def test_nemoretriever_parse_text_extraction_bboxes(mock_client, sample_pdf_stream, document_df, mock_parser_config): mock_client_instance = MagicMock() mock_client.return_value = mock_client_instance mock_client_instance.infer.return_value = [ - { - "bbox": {"xmin": 0.16633729456384325, "ymin": 0.0969, "xmax": 0.3097820480404551, "ymax": 0.1102}, - "text": "testing0", - "type": "Title", - }, - { - "bbox": {"xmin": 0.16633729456384325, "ymin": 0.0969, "xmax": 0.3097820480404551, "ymax": 0.1102}, - "text": "testing1", - "type": "Text", - }, + [ + { + "bbox": {"xmin": 0.16633729456384325, "ymin": 0.0969, "xmax": 
0.3097820480404551, "ymax": 0.1102}, + "text": "testing0", + "type": "Title", + }, + { + "bbox": {"xmin": 0.16633729456384325, "ymin": 0.0969, "xmax": 0.3097820480404551, "ymax": 0.1102}, + "text": "testing1", + "type": "Text", + }, + ] ] result = nemoretriever_parse( @@ -144,9 +165,11 @@ def test_nemoretriever_parse_text_extraction_bboxes(mock_client, sample_pdf_stre extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], text_depth="page", - nemoretriever_parse_config=MagicMock(), + extract_tables_method="nemoretriever_parse", + nemoretriever_parse_config=mock_parser_config, ) assert len(result) == 1 diff --git a/tests/nv_ingest/schemas/test_table_extractor_schema.py b/tests/nv_ingest/schemas/test_table_extractor_schema.py index bb94a742..c7415f6f 100644 --- a/tests/nv_ingest/schemas/test_table_extractor_schema.py +++ b/tests/nv_ingest/schemas/test_table_extractor_schema.py @@ -7,49 +7,79 @@ # Test cases for TableExtractorConfigSchema def test_valid_config_with_grpc_only(): - config = TableExtractorConfigSchema(auth_token="valid_token", paddle_endpoints=("grpc://paddle_service", None)) + config = TableExtractorConfigSchema( + auth_token="valid_token", + yolox_endpoints=("grpc://yolox_service", None), + paddle_endpoints=("grpc://paddle_service", None), + ) assert config.auth_token == "valid_token" + assert config.yolox_endpoints == ("grpc://yolox_service", None) assert config.paddle_endpoints == ("grpc://paddle_service", None) def test_valid_config_with_http_only(): - config = TableExtractorConfigSchema(auth_token="valid_token", paddle_endpoints=(None, "http://paddle_service")) + config = TableExtractorConfigSchema( + auth_token="valid_token", + yolox_endpoints=(None, "http://yolox_service"), + paddle_endpoints=(None, "http://paddle_service"), + ) assert config.auth_token == "valid_token" + assert config.yolox_endpoints == (None, "http://yolox_service") assert config.paddle_endpoints == (None, 
"http://paddle_service") def test_valid_config_with_both_services(): config = TableExtractorConfigSchema( - auth_token="valid_token", paddle_endpoints=("grpc://paddle_service", "http://paddle_service") + auth_token="valid_token", + yolox_endpoints=("grpc://yolox_service", "http://yolox_service"), + paddle_endpoints=("grpc://paddle_service", "http://paddle_service"), ) assert config.auth_token == "valid_token" + assert config.yolox_endpoints == ("grpc://yolox_service", "http://yolox_service") assert config.paddle_endpoints == ("grpc://paddle_service", "http://paddle_service") def test_invalid_config_empty_endpoints(): with pytest.raises(ValidationError) as exc_info: - TableExtractorConfigSchema(paddle_endpoints=(None, None)) + TableExtractorConfigSchema( + yolox_endpoints=("grpc://yolox_service", "http://yolox_service"), + paddle_endpoints=(None, None), + ) assert "Both gRPC and HTTP services cannot be empty for paddle_endpoints" in str(exc_info.value) def test_invalid_extra_fields(): with pytest.raises(ValidationError) as exc_info: TableExtractorConfigSchema( - auth_token="valid_token", paddle_endpoints=("grpc://paddle_service", None), extra_field="invalid" + auth_token="valid_token", + yolox_endpoints=("grpc://yolox_service", None), + paddle_endpoints=("grpc://paddle_service", None), + extra_field="invalid", ) assert "Extra inputs are not permitted" in str(exc_info.value) def test_cleaning_empty_strings_in_endpoints(): - config = TableExtractorConfigSchema(paddle_endpoints=(" ", "http://paddle_service")) + config = TableExtractorConfigSchema( + yolox_endpoints=("grpc://yolox_service", " "), + paddle_endpoints=(" ", "http://paddle_service"), + ) + assert config.yolox_endpoints == ("grpc://yolox_service", None) assert config.paddle_endpoints == (None, "http://paddle_service") - config = TableExtractorConfigSchema(paddle_endpoints=("grpc://paddle_service", "")) + config = TableExtractorConfigSchema( + yolox_endpoints=("", "http://yolox_service"), + 
paddle_endpoints=("grpc://paddle_service", ""), + ) + assert config.yolox_endpoints == (None, "http://yolox_service") assert config.paddle_endpoints == ("grpc://paddle_service", None) def test_auth_token_is_none_by_default(): - config = TableExtractorConfigSchema(paddle_endpoints=("grpc://paddle_service", "http://paddle_service")) + config = TableExtractorConfigSchema( + yolox_endpoints=("grpc://yolox_service", "http://yolox_service"), + paddle_endpoints=("grpc://paddle_service", "http://paddle_service"), + ) assert config.auth_token is None @@ -63,7 +93,10 @@ def test_table_extractor_schema_defaults(): def test_table_extractor_schema_with_custom_values(): - stage_config = TableExtractorConfigSchema(paddle_endpoints=("grpc://paddle_service", "http://paddle_service")) + stage_config = TableExtractorConfigSchema( + yolox_endpoints=("grpc://yolox_service", "http://yolox_service"), + paddle_endpoints=("grpc://paddle_service", "http://paddle_service"), + ) config = TableExtractorSchema(max_queue_size=15, n_workers=12, raise_on_failure=True, stage_config=stage_config) assert config.max_queue_size == 15 assert config.n_workers == 12 diff --git a/tests/nv_ingest/stages/nims/test_chart_extraction.py b/tests/nv_ingest/stages/nims/test_chart_extraction.py index 93872a8b..2c1fe32c 100644 --- a/tests/nv_ingest/stages/nims/test_chart_extraction.py +++ b/tests/nv_ingest/stages/nims/test_chart_extraction.py @@ -120,7 +120,7 @@ def test_update_metadata_single_batch_single_worker(mocker, base64_image): # Patch join_yolox_and_paddle_output so that it returns a dict per image. mock_join = mocker.patch( - f"{MODULE_UNDER_TEST}.join_yolox_and_paddle_output", + f"{MODULE_UNDER_TEST}.join_yolox_graphic_elements_and_paddle_output", side_effect=[{"chart_title": "joined_1"}, {"chart_title": "joined_2"}], ) # Patch process_yolox_graphic_elements to extract the chart title. 
@@ -176,7 +176,7 @@ def test_update_metadata_multiple_batches_multi_worker(mocker, base64_image): # Patch join_yolox_and_paddle_output so it returns the expected joined dict per image. mock_join = mocker.patch( - f"{MODULE_UNDER_TEST}.join_yolox_and_paddle_output", + f"{MODULE_UNDER_TEST}.join_yolox_graphic_elements_and_paddle_output", side_effect=[{"chart_title": "joined_1"}, {"chart_title": "joined_2"}, {"chart_title": "joined_3"}], ) # Patch process_yolox_graphic_elements to extract the chart title. diff --git a/tests/nv_ingest/stages/nims/test_table_extraction.py b/tests/nv_ingest/stages/nims/test_table_extraction.py index d033f249..e9649c4e 100644 --- a/tests/nv_ingest/stages/nims/test_table_extraction.py +++ b/tests/nv_ingest/stages/nims/test_table_extraction.py @@ -29,6 +29,15 @@ def paddle_mock(): return MagicMock() +@pytest.fixture +def yolox_mock(): + """ + Fixture that returns a MagicMock for the yolox_client, + which we'll pass to _update_metadata. + """ + return MagicMock() + + @pytest.fixture def validated_config(): """ @@ -40,7 +49,8 @@ class FakeStageConfig: # Values that _extract_table_data expects workers_per_progress_engine = 5 auth_token = "fake-token" - # For _create_paddle_client + yolox_endpoints = ("grpc_url", "http_url") + yolox_infer_protocol = "grpc" paddle_endpoints = ("grpc_url", "http_url") paddle_infer_protocol = "grpc" @@ -54,7 +64,7 @@ def test_extract_table_data_empty_df(mocker, validated_config): """ If df is empty, return the df + an empty trace_info without creating a client or calling _update_metadata. 
""" - mock_create_client = mocker.patch(f"{MODULE_UNDER_TEST}._create_paddle_client") + mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients") mock_update_metadata = mocker.patch(f"{MODULE_UNDER_TEST}._update_metadata") df_in = pd.DataFrame() @@ -62,7 +72,7 @@ def test_extract_table_data_empty_df(mocker, validated_config): df_out, trace_info = _extract_table_data(df_in, {}, validated_config) assert df_out.empty assert trace_info == {} - mock_create_client.assert_not_called() + mock_create_clients.assert_not_called() mock_update_metadata.assert_not_called() @@ -72,8 +82,8 @@ def test_extract_table_data_no_valid_rows(mocker, validated_config): skip _update_metadata, still create/close the client, and return the DataFrame unmodified with a trace_info. """ - mock_client = MagicMock() - mock_create_client = mocker.patch(f"{MODULE_UNDER_TEST}._create_paddle_client", return_value=mock_client) + mock_clients = (MagicMock(), MagicMock()) + mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=mock_clients) mock_update_metadata = mocker.patch(f"{MODULE_UNDER_TEST}._update_metadata") df_in = pd.DataFrame( @@ -92,9 +102,10 @@ def test_extract_table_data_no_valid_rows(mocker, validated_config): df_out, trace_info = _extract_table_data(df_in, {}, validated_config) assert df_out.equals(df_in) assert "trace_info" in trace_info - mock_create_client.assert_called_once() # We do create a client + mock_create_clients.assert_called_once() # We do create a client mock_update_metadata.assert_not_called() # But never call _update_metadata - mock_client.close.assert_called_once() # Must close client + mock_clients[0].close.assert_called_once() # Must close client + mock_clients[1].close.assert_called_once() # Must close client def test_extract_table_data_all_valid(mocker, validated_config): @@ -102,13 +113,13 @@ def test_extract_table_data_all_valid(mocker, validated_config): All rows are valid => we pass all base64 images to 
_update_metadata once, then write the returned content/format back into each row. """ - mock_client = MagicMock() - mock_create_client = mocker.patch(f"{MODULE_UNDER_TEST}._create_paddle_client", return_value=mock_client) + mock_clients = (MagicMock(), MagicMock()) + mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=mock_clients) mock_update_metadata = mocker.patch( f"{MODULE_UNDER_TEST}._update_metadata", return_value=[ - ("imgA", [[], ["tableA"]]), - ("imgB", [[], ["tableB"]]), + ("imgA", [], [], ["tableA"]), + ("imgB", [], [], ["tableB"]), ], ) @@ -140,25 +151,28 @@ def test_extract_table_data_all_valid(mocker, validated_config): assert df_out.at[1, "metadata"]["table_metadata"]["table_content_format"] == "simple" # Check calls - mock_create_client.assert_called_once() + mock_create_clients.assert_called_once() mock_update_metadata.assert_called_once_with( base64_images=["imgA", "imgB"], - paddle_client=mock_client, + yolox_client=mock_clients[0], + paddle_client=mock_clients[1], worker_pool_size=validated_config.stage_config.workers_per_progress_engine, + enable_yolox=False, trace_info=trace_info.get("trace_info"), ) - mock_client.close.assert_called_once() + mock_clients[0].close.assert_called_once() + mock_clients[1].close.assert_called_once() def test_extract_table_data_mixed_rows(mocker, validated_config): """ Some rows valid, some invalid => only valid rows get updated. 
""" - mock_client = MagicMock() - mock_create_client = mocker.patch(f"{MODULE_UNDER_TEST}._create_paddle_client", return_value=mock_client) + mock_clients = (MagicMock(), MagicMock()) + mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=mock_clients) mock_update_metadata = mocker.patch( f"{MODULE_UNDER_TEST}._update_metadata", - return_value=[("good1", [[], ["table1"]]), ("good2", [[], ["table2"]])], + return_value=[("good1", [], [], ["table1"]), ("good2", [], [], ["table2"])], ) df_in = pd.DataFrame( @@ -201,11 +215,14 @@ def test_extract_table_data_mixed_rows(mocker, validated_config): mock_update_metadata.assert_called_once_with( base64_images=["good1", "good2"], - paddle_client=mock_client, + yolox_client=mock_clients[0], + paddle_client=mock_clients[1], worker_pool_size=validated_config.stage_config.workers_per_progress_engine, + enable_yolox=False, trace_info=trace_info.get("trace_info"), ) - mock_client.close.assert_called_once() + mock_clients[0].close.assert_called_once() + mock_clients[1].close.assert_called_once() def test_extract_table_data_update_error(mocker, validated_config): @@ -213,9 +230,9 @@ def test_extract_table_data_update_error(mocker, validated_config): If _update_metadata raises an exception, we should re-raise but still close the paddle_client. """ - # Mock the paddle client so we don't make real calls or wait. - mock_client = MagicMock() - mock_create_client = mocker.patch(f"{MODULE_UNDER_TEST}._create_paddle_client", return_value=mock_client) + # Mock the yolox and paddle clients so we don't make real calls or wait. 
+ mock_clients = (MagicMock(), MagicMock()) + mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=mock_clients) # Mock _update_metadata to raise an error mock_update_metadata = mocker.patch(f"{MODULE_UNDER_TEST}._update_metadata", side_effect=RuntimeError("paddle_err")) @@ -237,26 +254,27 @@ def test_extract_table_data_update_error(mocker, validated_config): _extract_table_data(df_in, {}, validated_config) # Confirm we created a client - mock_create_client.assert_called_once() + mock_create_clients.assert_called_once() # Ensure the paddle_client was closed in the finally block - mock_client.close.assert_called_once() + mock_clients[0].close.assert_called_once() + mock_clients[1].close.assert_called_once() # Confirm _update_metadata was called once with our single row mock_update_metadata.assert_called_once() -def test_update_metadata_empty_list(paddle_mock): +def test_update_metadata_empty_list(yolox_mock, paddle_mock): """ If base64_images is empty, we should return an empty list and never call paddle_mock.infer. """ with patch(f"{MODULE_UNDER_TEST}.base64_to_numpy") as mock_b64: - result = _update_metadata([], paddle_mock) + result = _update_metadata([], yolox_mock, paddle_mock) assert result == [] mock_b64.assert_not_called() paddle_mock.infer.assert_not_called() -def test_update_metadata_all_valid(mocker, paddle_mock): +def test_update_metadata_all_valid(mocker, yolox_mock, paddle_mock): imgs = ["b64imgA", "b64imgB"] # Patch base64_to_numpy so that both images are valid. 
mock_dim = mocker.patch(f"{MODULE_UNDER_TEST}.base64_to_numpy") @@ -275,10 +293,10 @@ def test_update_metadata_all_valid(mocker, paddle_mock): ("tableB", "fmtB"), ] - res = _update_metadata(imgs, paddle_mock, worker_pool_size=1) + res = _update_metadata(imgs, yolox_mock, paddle_mock, worker_pool_size=1) assert len(res) == 2 - assert res[0] == ("b64imgA", ("tableA", "fmtA")) - assert res[1] == ("b64imgB", ("tableB", "fmtB")) + assert res[0] == ("b64imgA", None, "tableA", "fmtA") + assert res[1] == ("b64imgB", None, "tableB", "fmtB") # Expect one call to infer with all valid images. paddle_mock.infer.assert_called_once_with( @@ -290,7 +308,7 @@ def test_update_metadata_all_valid(mocker, paddle_mock): ) -def test_update_metadata_skip_small(mocker, paddle_mock): +def test_update_metadata_skip_small(mocker, yolox_mock, paddle_mock): """ Some images are below the min dimension => they skip inference and get ("", "") as results. @@ -307,12 +325,12 @@ def test_update_metadata_skip_small(mocker, paddle_mock): paddle_mock.infer.return_value = [("valid_table", "valid_fmt")] - res = _update_metadata(imgs, paddle_mock) + res = _update_metadata(imgs, yolox_mock, paddle_mock) assert len(res) == 2 # The first image is too small and is skipped. - assert res[0] == ("imgSmall", (None, None)) + assert res[0] == ("imgSmall", None, None, None) # The second image is valid and processed. - assert res[1] == ("imgBig", ("valid_table", "valid_fmt")) + assert res[1] == ("imgBig", None, "valid_table", "valid_fmt") paddle_mock.infer.assert_called_once_with( data={"base64_images": ["imgBig"]}, @@ -323,7 +341,7 @@ def test_update_metadata_skip_small(mocker, paddle_mock): ) -def test_update_metadata_multiple_batches(mocker, paddle_mock): +def test_update_metadata_multiple_batches(mocker, yolox_mock, paddle_mock): imgs = ["img1", "img2", "img3"] # Patch base64_to_numpy so that all images are valid. 
mock_dim = mocker.patch(f"{MODULE_UNDER_TEST}.base64_to_numpy") @@ -343,11 +361,11 @@ def test_update_metadata_multiple_batches(mocker, paddle_mock): ("table3", "fmt3"), ] - res = _update_metadata(imgs, paddle_mock, worker_pool_size=2) + res = _update_metadata(imgs, yolox_mock, paddle_mock, worker_pool_size=2) assert len(res) == 3 - assert res[0] == ("img1", ("table1", "fmt1")) - assert res[1] == ("img2", ("table2", "fmt2")) - assert res[2] == ("img3", ("table3", "fmt3")) + assert res[0] == ("img1", None, "table1", "fmt1") + assert res[1] == ("img2", None, "table2", "fmt2") + assert res[2] == ("img3", None, "table3", "fmt3") # Verify that infer is called only once with all valid images. paddle_mock.infer.assert_called_once_with( @@ -359,7 +377,7 @@ def test_update_metadata_multiple_batches(mocker, paddle_mock): ) -def test_update_metadata_inference_error(mocker, paddle_mock): +def test_update_metadata_inference_error(mocker, yolox_mock, paddle_mock): """ If paddle.infer fails for a batch, all valid images in that batch get ("",""), then we re-raise the exception. @@ -373,13 +391,13 @@ def test_update_metadata_inference_error(mocker, paddle_mock): paddle_mock.infer.side_effect = RuntimeError("paddle error") with pytest.raises(RuntimeError, match="paddle error"): - _update_metadata(imgs, paddle_mock) + _update_metadata(imgs, yolox_mock, paddle_mock) # The code sets them to ("", "") before re-raising # We can’t see final 'res', but that’s the logic.
-def test_update_metadata_mismatch_length(mocker, paddle_mock): +def test_update_metadata_mismatch_length(mocker, yolox_mock, paddle_mock): """ If paddle.infer returns fewer or more results than the valid_images => ValueError """ @@ -391,13 +409,13 @@ def test_update_metadata_mismatch_length(mocker, paddle_mock): # We expect 2 results, but get only 1 paddle_mock.infer.return_value = [("tableOnly", "fmtOnly")] - with pytest.raises(ValueError, match="Expected 2 results"): - _update_metadata(imgs, paddle_mock) + with pytest.raises(ValueError, match="Expected 2 paddle results"): + _update_metadata(imgs, yolox_mock, paddle_mock) -def test_update_metadata_non_list_return(mocker, paddle_mock): +def test_update_metadata_non_list_return(mocker, yolox_mock, paddle_mock): """ - If inference returns something that's not a list => ValueError + If inference returns something that's not a list, each gets ("", None, ...). """ imgs = ["imgX"] mock_dim = mocker.patch(f"{MODULE_UNDER_TEST}.base64_to_numpy", return_value=np.zeros((70, 70, 3), dtype=np.uint8)) @@ -406,13 +424,14 @@ def test_update_metadata_non_list_return(mocker, paddle_mock): paddle_mock.infer.return_value = "some_string" - with pytest.raises(ValueError, match="Expected a list of tuples"): - _update_metadata(imgs, paddle_mock) + res = _update_metadata(imgs, yolox_mock, paddle_mock) + assert len(res) == 1 + assert res[0] == ("imgX", None, None, None) -def test_update_metadata_all_small(mocker, paddle_mock): +def test_update_metadata_all_small(mocker, yolox_mock, paddle_mock): """ - If all images are too small, we skip inference entirely and each gets ("",""). + If all images are too small, we skip inference entirely and each gets ("", None, ... 
""" imgs = ["imgA", "imgB"] mock_dim = mocker.patch(f"{MODULE_UNDER_TEST}.base64_to_numpy") @@ -420,9 +439,10 @@ def test_update_metadata_all_small(mocker, paddle_mock): mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_WIDTH", 30) mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_HEIGHT", 30) - res = _update_metadata(imgs, paddle_mock) - assert res[0] == ("imgA", (None, None)) - assert res[1] == ("imgB", (None, None)) + res = _update_metadata(imgs, yolox_mock, paddle_mock) + assert len(res) == 2 + assert res[0] == ("imgA", None, None, None) + assert res[1] == ("imgB", None, None, None) # No calls to infer paddle_mock.infer.assert_not_called()