From e35b9cbeecacc7a251f1701bf2afa71b69083063 Mon Sep 17 00:00:00 2001 From: Kalyan Dutia Date: Tue, 14 Jan 2025 12:28:40 +0300 Subject: [PATCH] implement acronym expansion in search (#171) * implement acronym replacement in search * rewrite unit tests, adding another example * add tests for exact match search * fix failing test * bump version to 1.13.0 * fix failing test * set acronym replacement to off by default until it's tested * address review comments --- src/cpr_sdk/models/search.py | 6 + src/cpr_sdk/version.py | 4 +- src/cpr_sdk/vespa.py | 9 ++ tests/local_vespa/test_app/rules/acronyms.sr | 152 ++++++++++++++++++ .../test_documents/family_document.json | 4 +- tests/test_search_adaptors.py | 77 ++++++++- 6 files changed, 245 insertions(+), 7 deletions(-) create mode 100644 tests/local_vespa/test_app/rules/acronyms.sr diff --git a/src/cpr_sdk/models/search.py b/src/cpr_sdk/models/search.py index e543430e..496a7392 100644 --- a/src/cpr_sdk/models/search.py +++ b/src/cpr_sdk/models/search.py @@ -257,6 +257,12 @@ class SearchParameters(BaseModel): so can also be used to override YQL or ranking profiles. """ + replace_acronyms: bool = False + """ + Whether to perform acronym replacement based on the 'acronyms' ruleset. 
+ See docs: https://docs.vespa.ai/en/query-rewriting.html#rule-bases + """ + @model_validator(mode="after") def validate(self): """Validate against mutually exclusive fields""" diff --git a/src/cpr_sdk/version.py b/src/cpr_sdk/version.py index 916594ff..cd9dba46 100644 --- a/src/cpr_sdk/version.py +++ b/src/cpr_sdk/version.py @@ -1,6 +1,6 @@ _MAJOR = "1" -_MINOR = "12" -_PATCH = "1" +_MINOR = "13" +_PATCH = "0" _SUFFIX = "" VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) diff --git a/src/cpr_sdk/vespa.py b/src/cpr_sdk/vespa.py index d5563e31..7dc9fa28 100644 --- a/src/cpr_sdk/vespa.py +++ b/src/cpr_sdk/vespa.py @@ -115,6 +115,15 @@ def build_vespa_request_body(parameters: SearchParameters) -> dict[str, str]: vespa_request_body = vespa_request_body | parameters.custom_vespa_request_body + if parameters.replace_acronyms: + if parameters.exact_match: + _LOGGER.warning( + "Exact match and replace_acronyms are incompatible. Ignoring replace_acronyms." + ) + else: + vespa_request_body["rules.off"] = False + vespa_request_body["rules.rulebase"] = "acronyms" + # Disabling embedding search for descriptions vespa_request_body["input.query(description_closeness_weight)"] = 0 diff --git a/tests/local_vespa/test_app/rules/acronyms.sr b/tests/local_vespa/test_app/rules/acronyms.sr new file mode 100644 index 00000000..36092de0 --- /dev/null +++ b/tests/local_vespa/test_app/rules/acronyms.sr @@ -0,0 +1,152 @@ +bipoc +> "black indigenous and people of colour"; +black indigenous and people of colour +> bipoc; + +bur +> "biennial update report"; +biennial update report +> bur; + +c2h6 +> "ethane"; +ethane +> c2h6; + +c3h8 +> "propane"; +propane +> c3h8; + +c4h10 +> "butane"; +butane +> c4h10; + +cbam +> "carbon border adjustment mechanism"; +carbon border adjustment mechanism +> cbam; + +cbews +> "community based early warning system"; +community based early warning system +> cbews; + +cfc +> "chlorofluorocarbon"; +chlorofluorocarbon +> cfc; + +cfcs +> "chlorofluorocarbons"; 
+chlorofluorocarbons +> cfcs; + +ch4 +> "methane"; +methane +> ch4; + +co2 +> "carbon dioxide"; +carbon dioxide +> co2; + +csrd +> "corporate sustainability reporting directive"; +corporate sustainability reporting directive +> csrd; + +dfi +> "development finance institutions"; +development finance institutions +> dfi; + +drm +> "disaster risk management"; +disaster risk management +> drm; + +erf +> "emission reduction fund"; +emission reduction fund +> erf; + +ets +> "emission trading system"; +emission trading system +> ets; + +ews +> "early warning systems"; +early warning systems +> ews; + +fgm +> "female genital mutilation"; +female genital mutilation +> fgm; + +gga +> "global goal on adaptation"; +global goal on adaptation +> gga; + +ghg +> "greenhouse gas"; +greenhouse gas +> ghg; + +glof +> "glacial lake outburst flood"; +glacial lake outburst flood +> glof; + +glofs +> "glacial lake outburst floods"; +glacial lake outburst floods +> glofs; + +glp +> "liquefied petroleum gas"; +liquefied petroleum gas +> glp; + +gst +> "global stocktake"; +global stocktake +> gst; + +hcfc +> "hydrochlorofluorocarbon"; +hydrochlorofluorocarbon +> hcfc; + +hcfcs +> "hydrochlorofluorocarbons"; +hydrochlorofluorocarbons +> hcfcs; + +hfc +> "hydrofluorocarbons"; +hydrofluorocarbons +> hfc; + +ifrs +> "international financial reporting standards"; +international financial reporting standards +> ifrs; + +indc +> "intended nationally determined contribution"; +intended nationally determined contribution +> indc; + +ipcc +> "intergovernmental panel on climate change"; +intergovernmental panel on climate change +> ipcc; + +lez +> "low emission zone"; +low emission zone +> lez; + +lng +> "liquefied natural gas"; +liquefied natural gas +> lng; + +lpg +> "liquefied petroleum gas"; +liquefied petroleum gas +> lpg; + +mhews +> "multi hazard early warning systems"; +multi hazard early warning systems +> mhews; + +n2o +> "nitrous oxide"; +nitrous oxide +> n2o; + +ndc +> "nationally 
determined contribution"; +nationally determined contribution +> ndc; + +nf3 +> "nitrogen trifluoride"; +nitrogen trifluoride +> nf3; + +ngo +> "non governmental organisation"; +non governmental organisation +> ngo; + +nh3 +> "ammonia"; +ammonia +> nh3; + +o2 +> "oxygen"; +oxygen +> o2; + +o3 +> "ozone"; +ozone +> o3; + +pfc +> "perfluorocarbon"; +perfluorocarbon +> pfc; + +pfcs +> "perfluorocarbons"; +perfluorocarbons +> pfcs; + +ril +> "reduced impact logging"; +reduced impact logging +> ril; + +sdg +> "sustainable development goal"; +sustainable development goal +> sdg; + +sf6 +> "sulphur hexafluoride"; +sulphur hexafluoride +> sf6; + +slr +> "sea level rise"; +sea level rise +> slr; + +ulez +> "ultra low emission zone"; +ultra low emission zone +> ulez; + +wfp +> "world food programme"; +world food programme +> wfp; + +zev +> "zero emissions vehicle"; +zero emissions vehicle +> zev; + +pv +> "photovoltaic"; +photovoltaic +> pv; \ No newline at end of file diff --git a/tests/local_vespa/test_documents/family_document.json b/tests/local_vespa/test_documents/family_document.json index 644c605c..c5391783 100644 --- a/tests/local_vespa/test_documents/family_document.json +++ b/tests/local_vespa/test_documents/family_document.json @@ -4,7 +4,7 @@ "fields": { "family_source": "CCLW", "search_weights_ref": "id:doc_search:search_weights::default_weights", - "family_name": "Climate Change Adaptation and Low Emissions Growth Strategy by 2035", + "family_name": "Nationally Determined Contribution: Climate Change Adaptation and Low Emissions Growth Strategy by 2035", "document_title": null, "document_content_type": "text/html", "family_slug": "climate-change-adaptation-and-low-emissions-growth-strategy-by-2035_75e3", @@ -12,7 +12,7 @@ "family_geography": "BIH", "family_geographies": ["BIH", "NOR"], "family_category": "Executive", - "family_name_index": "Climate Change Adaptation and Low Emissions Growth Strategy by 2035", + "family_name_index": "Nationally Determined 
Contribution: Climate Change Adaptation and Low Emissions Growth Strategy by 2035", "document_languages": ["English"], "document_slug": "climate-change-adaptation-and-low-emissions-growth-strategy-by-2035_6c4c", "family_description_embedding": { diff --git a/tests/test_search_adaptors.py b/tests/test_search_adaptors.py index d700ef6d..f8740e37 100644 --- a/tests/test_search_adaptors.py +++ b/tests/test_search_adaptors.py @@ -172,7 +172,7 @@ def test_vespa_search_adaptor__bad_query_string_still_works(test_vespa): @pytest.mark.vespa def test_vespa_search_adaptor__hybrid(test_vespa): - family_name = "Climate Change Adaptation and Low Emissions Growth Strategy by 2035" + family_name = "Nationally Determined Contribution: Climate Change Adaptation and Low Emissions Growth Strategy by 2035" request = SearchParameters(query_string=family_name) response = vespa_search(test_vespa, request) @@ -180,8 +180,7 @@ def test_vespa_search_adaptor__hybrid(test_vespa): # Note that this is a fairly loose test got_family_names = [] for fam in response.families: - for doc in fam.hits: - got_family_names.append(doc.family_name) + got_family_names.append(fam.hits[0].family_name) assert family_name in got_family_names @@ -774,3 +773,75 @@ def test_vespa_search_hybrid_no_closeness_profile(test_vespa): ) assert response_no_closeness == response_null_closeness_weights + + +@pytest.mark.vespa +def test_acronym_replacement(test_vespa): + ndc_response = vespa_search( + test_vespa, + SearchParameters( + query_string="ndc", + replace_acronyms=True, + ), + ) + + ndc_response_no_replacement = vespa_search( + test_vespa, + SearchParameters( + query_string="ndc", + replace_acronyms=False, + ), + ) + + assert "Nationally Determined Contribution" in str( + ndc_response.families[0].hits[0].family_name + ) + assert "Nationally Determined Contribution" not in str( + ndc_response_no_replacement.families[0].hits[0].family_name + ) + + methane_ch4_response = vespa_search( + test_vespa, + SearchParameters( + 
query_string="ch4", + replace_acronyms=True, + ), + ) + methane_ch4_response_no_replacement = vespa_search( + test_vespa, + SearchParameters( + query_string="ch4", + replace_acronyms=False, + ), + ) + + assert isinstance(methane_ch4_response.families[0].hits[0], Passage) + assert "methane" in methane_ch4_response.families[0].hits[0].text_block.lower() + + assert ( + not ( + isinstance(methane_ch4_response_no_replacement.families[0].hits[0], Passage) + ) + or "methane" + not in methane_ch4_response_no_replacement.families[0] + .hits[0] + .text_block.lower() + ) + + +@pytest.mark.vespa +def test_acronym_replacement_exact_match_search(test_vespa, caplog): + """Acronym replacement should not run on exact match searches""" + + # There are no exact matches for the query "ndc" in the test data + ndc_response = vespa_search( + test_vespa, + SearchParameters( + query_string="ndc", + exact_match=True, + replace_acronyms=True, + ), + ) + + assert "Exact match and replace_acronyms are incompatible." in caplog.text + assert len(ndc_response.families) == 0