From c1b83102451757c936b2c5d4697c31319e6bb32e Mon Sep 17 00:00:00 2001 From: Tyler Barrus Date: Sat, 13 Jan 2024 16:51:53 -0500 Subject: [PATCH] Quotient Filter: merge (#116) --- CHANGELOG.md | 8 +++ probables/exceptions.py | 11 ++++ probables/quotientfilter/quotientfilter.py | 48 ++++++++++++----- tests/quotientfilter_test.py | 61 ++++++++++++++++++---- 4 files changed, 103 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0dd8b03..6cbec4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # PyProbables Changelog +### Version 0.6.1 + +* Quotient Filter: + * Add ability to get hashes from the filter either as a list, or as a generator + * Add quotient filter expand capability, auto and on request + * Add QuotientFilterError exception + * Add merge functionality + ### Version 0.6.0 * Add `QuotientFilter` implementation; [see issue #37](https://github.com/barrust/pyprobables/issues/37) diff --git a/probables/exceptions.py b/probables/exceptions.py index ff516e4..b76eb37 100644 --- a/probables/exceptions.py +++ b/probables/exceptions.py @@ -68,3 +68,14 @@ class CountMinSketchError(ProbablesBaseException): def __init__(self, message: str) -> None: self.message = message super().__init__(self.message) + + +class QuotientFilterError(ProbablesBaseException): + """Quotient Filter Exception + + Args: + message (str): The error message to be reported""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) diff --git a/probables/quotientfilter/quotientfilter.py b/probables/quotientfilter/quotientfilter.py index 7f5fce9..5f70056 100644 --- a/probables/quotientfilter/quotientfilter.py +++ b/probables/quotientfilter/quotientfilter.py @@ -6,6 +6,7 @@ from array import array from typing import Iterator, List, Optional +from probables.exceptions import QuotientFilterError from probables.hashes import KeyT, SimpleHashT, fnv_1a_32 from probables.utilities import Bitarray @@ -20,7 +21,7 @@ class QuotientFilter: Returns: QuotientFilter: The initialized filter Raises: - ValueError: + QuotientFilterError: Raised when unable to initialize Note: The size of the QuotientFilter will be 2**q""" @@ -44,8 +45,8 @@ def __init__( self, quotient: int = 20, auto_expand: bool = True, hash_function: Optional[SimpleHashT] = None ): # needs to be parameterized if quotient < 3 or quotient > 31: - raise ValueError( - f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided" + raise QuotientFilterError( + f"Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided" ) self.__set_params(quotient, auto_expand, hash_function) @@ -140,7 +141,9 @@ def add(self, key: KeyT) -> None: """Add key to the quotient filter Args: - key (str|bytes): The element to add""" + key (str|bytes): The element to add + Raises: + QuotientFilterError: Raised when no locations are available in which to insert""" _hash = self._hash_func(key, 0) self.add_alt(_hash) @@ -148,12 +151,14 @@ def add_alt(self, _hash: int) -> None: """Add the pre-hashed value to the quotient filter Args: - _hash (int): The element to add""" + _hash (int): The element to add + Raises: + QuotientFilterError: Raised when no locations are available in which to insert""" + if self._auto_resize and self.load_factor >= self._max_load_factor: + self.resize() key_quotient = _hash >> self._r key_remainder = _hash & ((1 << self._r) - 1) if self._contained_at_loc(key_quotient, key_remainder) == -1: - if self._auto_resize and self.load_factor >= self._max_load_factor: - self.resize() self._add(key_quotient, key_remainder) def check(self, key: KeyT) -> bool: @@ -177,7 +182,7 @@ def check_alt(self, _hash: int) -> bool: key_remainder = _hash & ((1 << self._r) - 1) return not self._contained_at_loc(key_quotient, key_remainder) == -1 - def iter_hashes(self) -> Iterator[int]: + def hashes(self) -> Iterator[int]: """A generator over the hashes in the quotient filter Yields: @@ -220,25 +225,25 @@ def get_hashes(self) -> List[int]: Returns: list(int): The hash values stored in the quotient filter""" - return list(self.iter_hashes()) + return list(self.hashes()) def resize(self, quotient: Optional[int] = None) -> None: """Resize the quotient filter to use the new quotient size Args: - int: The new quotient to use + quotient (int): The new quotient to use Note: If `None` is provided, the quotient filter will double in size (quotient + 1) Raises: - ValueError: When the new quotient will not accommodate the elements already added""" + QuotientFilterError: When the new quotient will not accommodate the elements already added""" if quotient is None: quotient = self._q + 1 if self.elements_added >= (1 << quotient): - raise ValueError("Unable to shrink since there will be too many elements in the quotient filter") + raise QuotientFilterError("Unable to shrink since there will be too many elements in the quotient filter") if quotient < 3 or quotient > 31: - raise ValueError( - f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided" + raise QuotientFilterError( + f"Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided" ) hashes = self.get_hashes() @@ -251,6 +256,19 @@ def resize(self, quotient: Optional[int] = None) -> None: for _h in hashes: self.add_alt(_h) + def merge(self, second: "QuotientFilter") -> None: + """Merge the `second` quotient filter into the first + + Args: + second (QuotientFilter): The quotient filter to merge + Note: + The hashing function between the two filters should match + Note: + Errors can occur if the quotient filter being inserted into does not expand (i.e., auto_expand=False)""" + + for _h in second.hashes(): + self.add_alt(_h) + def _shift_insert(self, k, v, start, j, flag): if self._is_occupied[j] == 0 and self._is_continuation[j] == 0 and self._is_shifted[j] == 0: self._filter[j] = v @@ -311,6 +329,8 @@ def _get_start_index(self, k): return j def _add(self, q: int, r: int): + if self._size == self._elements_added: + raise QuotientFilterError("Unable to insert the element due to insufficient space") if self._is_occupied[q] == 0 and self._is_continuation[q] == 0 and self._is_shifted[q] == 0: self._filter[q] = r self._is_occupied[q] = 1 diff --git a/tests/quotientfilter_test.py b/tests/quotientfilter_test.py index 292c5ba..602ad8d 100644 --- a/tests/quotientfilter_test.py +++ b/tests/quotientfilter_test.py @@ -9,6 +9,8 @@ from pathlib import Path from tempfile import NamedTemporaryFile +from probables.exceptions import QuotientFilterError + this_dir = Path(__file__).parent sys.path.insert(0, str(this_dir)) sys.path.insert(0, str(this_dir.parent)) @@ -49,6 +51,10 @@ def test_qf_init(self): self.assertEqual(qf.num_elements, 16777216) # 2**qf.quotient self.assertFalse(qf.auto_expand) + # reset auto_expand + qf.auto_expand = True + self.assertTrue(qf.auto_expand) + def test_qf_add_check(self): "test that the qf is able to add and check elements" qf = QuotientFilter(quotient=8) @@ -91,10 +97,10 @@ def test_qf_add_check_in(self): def test_qf_init_errors(self): """test quotient filter initialization errors""" - self.assertRaises(ValueError, lambda: QuotientFilter(quotient=2)) - self.assertRaises(ValueError, lambda: QuotientFilter(quotient=32)) + self.assertRaises(QuotientFilterError, lambda: QuotientFilter(quotient=2)) + self.assertRaises(QuotientFilterError, lambda: QuotientFilter(quotient=32)) - def test_retrieve_hashes(self): + def test_qf_retrieve_hashes(self): """test retrieving hashes back from the quotient filter""" qf = QuotientFilter(quotient=8, auto_expand=False) hashes = [] @@ -107,7 +113,7 @@ def test_retrieve_hashes(self): self.assertEqual(qf.elements_added, len(out_hashes)) self.assertEqual(set(hashes), set(out_hashes)) - def test_resize(self): + def test_qf_resize(self): """test resizing the quotient filter""" qf = QuotientFilter(quotient=8, auto_expand=False) for i in range(200): @@ -120,7 +126,7 @@ def test_resize(self): self.assertEqual(qf.bits_per_elm, 32) self.assertFalse(qf.auto_expand) - self.assertRaises(ValueError, lambda: qf.resize(7)) # should be too small to fit + self.assertRaises(QuotientFilterError, lambda: qf.resize(7)) # should be too small to fit qf.resize(17) self.assertEqual(qf.elements_added, 200) @@ -132,7 +138,7 @@ def test_resize(self): for i in range(200): self.assertTrue(qf.check(str(i))) - def test_auto_resize(self): + def test_qf_auto_resize(self): """test resizing the quotient filter automatically""" qf = QuotientFilter(quotient=8, auto_expand=True) self.assertEqual(qf.max_load_factor, 0.85) @@ -153,7 +159,7 @@ def test_auto_resize(self): self.assertEqual(qf.remainder, 23) self.assertEqual(qf.bits_per_elm, 32) - def test_auto_resize_changed_max_load_factor(self): + def test_qf_auto_resize_changed_max_load_factor(self): """test resizing the quotient filter with a different load factor""" qf = QuotientFilter(quotient=8, auto_expand=True) self.assertEqual(qf.max_load_factor, 0.85) @@ -178,13 +184,46 @@ def test_auto_resize_changed_max_load_factor(self): self.assertEqual(qf.remainder, 23) self.assertEqual(qf.bits_per_elm, 32) - def test_resize_errors(self): + def test_qf_resize_errors(self): """test resizing errors""" qf = QuotientFilter(quotient=8, auto_expand=True) for i in range(200): qf.add(str(i)) - self.assertRaises(ValueError, lambda: qf.resize(quotient=2)) - self.assertRaises(ValueError, lambda: qf.resize(quotient=32)) - self.assertRaises(ValueError, lambda: qf.resize(quotient=6)) + self.assertRaises(QuotientFilterError, lambda: qf.resize(quotient=2)) + self.assertRaises(QuotientFilterError, lambda: qf.resize(quotient=32)) + self.assertRaises(QuotientFilterError, lambda: qf.resize(quotient=6)) + + def test_qf_merge(self): + """test merging two quotient filters together""" + qf = QuotientFilter(quotient=8, auto_expand=True) + for i in range(200): + qf.add(str(i)) + + fq = QuotientFilter(quotient=8) + for i in range(300, 500): + fq.add(str(i)) + + qf.merge(fq) + + for i in range(200): + self.assertTrue(qf.check(str(i))) + for i in range(200, 300): + self.assertFalse(qf.check(str(i))) + for i in range(300, 500): + self.assertTrue(qf.check(str(i))) + + self.assertEqual(qf.elements_added, 400) + + def test_qf_merge_error(self): + """test unable to merge due to inability to grow""" + qf = QuotientFilter(quotient=8, auto_expand=False) + for i in range(200): + qf.add(str(i)) + + fq = QuotientFilter(quotient=8) + for i in range(300, 400): + fq.add(str(i)) + + self.assertRaises(QuotientFilterError, lambda: qf.merge(fq))