diff --git a/probables/quotientfilter/quotientfilter.py b/probables/quotientfilter/quotientfilter.py index 673ccd0..7f5fce9 100644 --- a/probables/quotientfilter/quotientfilter.py +++ b/probables/quotientfilter/quotientfilter.py @@ -15,6 +15,7 @@ class QuotientFilter: Args: quotient (int): The size of the quotient to use + auto_expand (bool): Automatically expand or not hash_function (function): Hashing strategy function to use `hf(key, number)` Returns: QuotientFilter: The initialized filter @@ -35,21 +36,27 @@ class QuotientFilter: "_is_continuation", "_is_shifted", "_filter", + "_max_load_factor", + "_auto_resize", ) - def __init__(self, quotient: int = 20, hash_function: Optional[SimpleHashT] = None): # needs to be parameterized + def __init__( + self, quotient: int = 20, auto_expand: bool = True, hash_function: Optional[SimpleHashT] = None + ): # needs to be parameterized if quotient < 3 or quotient > 31: raise ValueError( f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided" ) - self.__set_params(quotient, hash_function) - - def __set_params(self, quotient, hash_function): - self._q = quotient - self._r = 32 - quotient - self._size = 1 << self._q # same as 2**q - self._elements_added = 0 + self.__set_params(quotient, auto_expand, hash_function) + + def __set_params(self, quotient: int, auto_expand: bool, hash_function: Optional[SimpleHashT]): + self._q: int = quotient + self._r: int = 32 - quotient + self._size: int = 1 << self._q # same as 2**q + self._elements_added: int = 0 + self._auto_resize: bool = auto_expand self._hash_func: SimpleHashT = fnv_1a_32 if hash_function is None else hash_function # type: ignore + self._max_load_factor: float = 0.85 # ensure we use the smallest type possible to reduce memory wastage if self._r <= 8: @@ -92,7 +99,7 @@ def elements_added(self) -> int: return self._elements_added @property - def bits_per_elm(self): + def bits_per_elm(self) -> int: """int: The number of bits used per element""" return self._bits_per_elm @@ -109,6 +116,26 @@ def load_factor(self) -> float: """float: The load factor of the filter""" return self._elements_added / self._size + @property + def auto_expand(self) -> bool: + """bool: Will the quotient filter automatically expand""" + return self._auto_resize + + @auto_expand.setter + def auto_expand(self, val: bool): + """change the auto expand property""" + self._auto_resize = bool(val) + + @property + def max_load_factor(self) -> float: + """float: The maximum allowed load factor after which auto expanding should occur""" + return self._max_load_factor + + @max_load_factor.setter + def max_load_factor(self, val: float): + """set the maximum load factor""" + self._max_load_factor = float(val) + def add(self, key: KeyT) -> None: """Add key to the quotient filter @@ -125,6 +152,8 @@ def add_alt(self, _hash: int) -> None: key_quotient = _hash >> self._r key_remainder = _hash & ((1 << self._r) - 1) if self._contained_at_loc(key_quotient, key_remainder) == -1: + if self._auto_resize and self.load_factor >= self._max_load_factor: + self.resize() self._add(key_quotient, key_remainder) def check(self, key: KeyT) -> bool: @@ -193,21 +222,31 @@ def get_hashes(self) -> List[int]: list(int): The hash values stored in the quotient filter""" return list(self.iter_hashes()) - def resize(self, quotient: int) -> None: + def resize(self, quotient: Optional[int] = None) -> None: """Resize the quotient filter to use the new quotient size Args: int: The new quotient to use + Note: + If `None` is provided, the quotient filter will double in size (quotient + 1) Raises: ValueError: When the new quotient will not accommodate the elements already added""" + if quotient is None: + quotient = self._q + 1 + if self.elements_added >= (1 << quotient): raise ValueError("Unable to shrink since there will be too many elements in the quotient filter") + if quotient < 3 or quotient > 31: + raise ValueError( + f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided" + ) + hashes = self.get_hashes() for i in range(self._size): self._filter[i] = 0 - self.__set_params(quotient, self._hash_func) + self.__set_params(quotient, self._auto_resize, self._hash_func) for _h in hashes: self.add_alt(_h) diff --git a/tests/quotientfilter_test.py b/tests/quotientfilter_test.py index 2da96cd..5b9a0cd 100644 --- a/tests/quotientfilter_test.py +++ b/tests/quotientfilter_test.py @@ -38,14 +38,16 @@ def test_qf_init(self): self.assertEqual(qf.remainder, 24) self.assertEqual(qf.elements_added, 0) self.assertEqual(qf.num_elements, 256) # 2**qf.quotient + self.assertTrue(qf.auto_expand) - qf = QuotientFilter(quotient=24) + qf = QuotientFilter(quotient=24, auto_expand=False) self.assertEqual(qf.bits_per_elm, 8) self.assertEqual(qf.quotient, 24) self.assertEqual(qf.remainder, 8) self.assertEqual(qf.elements_added, 0) self.assertEqual(qf.num_elements, 16777216) # 2**qf.quotient + self.assertFalse(qf.auto_expand) def test_qf_add_check(self): "test that the qf is able to add and check elements" @@ -94,7 +96,7 @@ def test_qf_init_errors(self): def test_retrieve_hashes(self): """test retrieving hashes back from the quotient filter""" - qf = QuotientFilter(quotient=8) + qf = QuotientFilter(quotient=8, auto_expand=False) hashes = [] for i in range(255): hashes.append(qf._hash_func(str(i), 0)) # use the private function here.. @@ -107,7 +109,7 @@ def test_retrieve_hashes(self): def test_resize(self): """test resizing the quotient filter""" - qf = QuotientFilter(quotient=8) + qf = QuotientFilter(quotient=8, auto_expand=False) for i in range(200): qf.add(str(i)) @@ -128,3 +130,23 @@ def test_resize(self): # ensure everything is still accessable for i in range(200): self.assertTrue(qf.check(str(i))) + + def test_auto_resize(self): + """test resizing the quotient filter""" + qf = QuotientFilter(quotient=8, auto_expand=True) + self.assertEqual(qf.max_load_factor, 0.85) + self.assertEqual(qf.elements_added, 0) + self.assertEqual(qf.load_factor, 0 / qf.size) + self.assertEqual(qf.quotient, 8) + self.assertEqual(qf.remainder, 24) + self.assertEqual(qf.bits_per_elm, 32) + + for i in range(220): + qf.add(str(i)) + + self.assertEqual(qf.max_load_factor, 0.85) + self.assertEqual(qf.elements_added, 220) + self.assertEqual(qf.load_factor, 220 / qf.size) + self.assertEqual(qf.quotient, 9) + self.assertEqual(qf.remainder, 23) + self.assertEqual(qf.bits_per_elm, 32)