Skip to content

Commit

Permalink
quotient filter: expand
Browse files Browse the repository at this point in the history
  • Loading branch information
barrust committed Jan 13, 2024
1 parent e9115f0 commit 8ae5eed
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 14 deletions.
61 changes: 50 additions & 11 deletions probables/quotientfilter/quotientfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class QuotientFilter:
Args:
quotient (int): The size of the quotient to use
auto_expand (bool): Automatically expand or not
hash_function (function): Hashing strategy function to use `hf(key, number)`
Returns:
QuotientFilter: The initialized filter
Expand All @@ -35,21 +36,27 @@ class QuotientFilter:
"_is_continuation",
"_is_shifted",
"_filter",
"_max_load_factor",
"_auto_resize",
)

def __init__(self, quotient: int = 20, hash_function: Optional[SimpleHashT] = None): # needs to be parameterized
def __init__(
self, quotient: int = 20, auto_expand: bool = True, hash_function: Optional[SimpleHashT] = None
): # needs to be parameterized
if quotient < 3 or quotient > 31:
raise ValueError(
f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided"
)
self.__set_params(quotient, hash_function)

def __set_params(self, quotient, hash_function):
self._q = quotient
self._r = 32 - quotient
self._size = 1 << self._q # same as 2**q
self._elements_added = 0
self.__set_params(quotient, auto_expand, hash_function)

def __set_params(self, quotient: int, auto_expand: bool, hash_function: Optional[SimpleHashT]):
self._q: int = quotient
self._r: int = 32 - quotient
self._size: int = 1 << self._q # same as 2**q
self._elements_added: int = 0
self._auto_resize: bool = auto_expand
self._hash_func: SimpleHashT = fnv_1a_32 if hash_function is None else hash_function # type: ignore
self._max_load_factor: float = 0.85

# ensure we use the smallest type possible to reduce memory wastage
if self._r <= 8:
Expand Down Expand Up @@ -92,7 +99,7 @@ def elements_added(self) -> int:
return self._elements_added

@property
def bits_per_elm(self):
def bits_per_elm(self) -> int:
"""int: The number of bits used per element"""
return self._bits_per_elm

Expand All @@ -109,6 +116,26 @@ def load_factor(self) -> float:
"""float: The load factor of the filter"""
return self._elements_added / self._size

@property
def auto_expand(self) -> bool:
    """bool: Whether the quotient filter is allowed to expand automatically"""
    return self._auto_resize

@auto_expand.setter
def auto_expand(self, val: bool) -> None:
    """Enable or disable automatic expansion
    Args:
        val (bool): The new setting; any value is accepted and reduced to its truthiness"""
    # Store a real bool so the getter always returns True/False, never the raw argument
    self._auto_resize = True if val else False

@property
def max_load_factor(self) -> float:
    """float: The maximum allowed load factor after which auto expanding should occur"""
    return self._max_load_factor

@max_load_factor.setter
def max_load_factor(self, val: float) -> None:
    """Set the maximum load factor
    Args:
        val (float): The new threshold; must be greater than 0 and no larger than 1
    Raises:
        ValueError: When `val` is NaN or falls outside of (0, 1]"""
    threshold = float(val)
    # The insert path compares `load_factor >= max_load_factor`. A threshold of 0 or
    # less would make that true on every new insert (forcing a resize each time), and
    # NaN compares False against everything, which would silently disable expansion.
    # The chained comparison below is False for NaN, so it is rejected as well.
    if not 0.0 < threshold <= 1.0:
        raise ValueError(
            "Quotient filter: Invalid max load factor setting; "
            f"max load factor must be greater than 0 and at most 1; {val} was provided"
        )
    self._max_load_factor = threshold

def add(self, key: KeyT) -> None:
"""Add key to the quotient filter
Expand All @@ -125,6 +152,8 @@ def add_alt(self, _hash: int) -> None:
key_quotient = _hash >> self._r
key_remainder = _hash & ((1 << self._r) - 1)
if self._contained_at_loc(key_quotient, key_remainder) == -1:
if self._auto_resize and self.load_factor >= self._max_load_factor:
self.resize()
self._add(key_quotient, key_remainder)

def check(self, key: KeyT) -> bool:
Expand Down Expand Up @@ -193,21 +222,31 @@ def get_hashes(self) -> List[int]:
list(int): The hash values stored in the quotient filter"""
return list(self.iter_hashes())

def resize(self, quotient: int) -> None:
def resize(self, quotient: Optional[int] = None) -> None:
"""Resize the quotient filter to use the new quotient size
Args:
int: The new quotient to use
Note:
If `None` is provided, the quotient filter will double in size (quotient + 1)
Raises:
ValueError: When the new quotient will not accommodate the elements already added"""
if quotient is None:
quotient = self._q + 1

if self.elements_added >= (1 << quotient):
raise ValueError("Unable to shrink since there will be too many elements in the quotient filter")
if quotient < 3 or quotient > 31:
raise ValueError(
f"Quotient filter: Invalid quotient setting; quotient must be between 3 and 31; {quotient} was provided"
)

hashes = self.get_hashes()

for i in range(self._size):
self._filter[i] = 0

self.__set_params(quotient, self._hash_func)
self.__set_params(quotient, self._auto_resize, self._hash_func)

for _h in hashes:
self.add_alt(_h)
Expand Down
28 changes: 25 additions & 3 deletions tests/quotientfilter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ def test_qf_init(self):
self.assertEqual(qf.remainder, 24)
self.assertEqual(qf.elements_added, 0)
self.assertEqual(qf.num_elements, 256) # 2**qf.quotient
self.assertTrue(qf.auto_expand)

qf = QuotientFilter(quotient=24)
qf = QuotientFilter(quotient=24, auto_expand=False)

self.assertEqual(qf.bits_per_elm, 8)
self.assertEqual(qf.quotient, 24)
self.assertEqual(qf.remainder, 8)
self.assertEqual(qf.elements_added, 0)
self.assertEqual(qf.num_elements, 16777216) # 2**qf.quotient
self.assertFalse(qf.auto_expand)

def test_qf_add_check(self):
"test that the qf is able to add and check elements"
Expand Down Expand Up @@ -94,7 +96,7 @@ def test_qf_init_errors(self):

def test_retrieve_hashes(self):
"""test retrieving hashes back from the quotient filter"""
qf = QuotientFilter(quotient=8)
qf = QuotientFilter(quotient=8, auto_expand=False)
hashes = []
for i in range(255):
hashes.append(qf._hash_func(str(i), 0)) # use the private function here..
Expand All @@ -107,7 +109,7 @@ def test_retrieve_hashes(self):

def test_resize(self):
"""test resizing the quotient filter"""
qf = QuotientFilter(quotient=8)
qf = QuotientFilter(quotient=8, auto_expand=False)
for i in range(200):
qf.add(str(i))

Expand All @@ -128,3 +130,23 @@ def test_resize(self):
# ensure everything is still accessable
for i in range(200):
self.assertTrue(qf.check(str(i)))

def test_auto_resize(self):
    """test that the quotient filter expands automatically while elements are added"""
    qf = QuotientFilter(quotient=8, auto_expand=True)

    # state of a freshly created filter: empty, quotient of 8
    self.assertEqual(qf.max_load_factor, 0.85)
    self.assertEqual(qf.elements_added, 0)
    self.assertEqual(qf.load_factor, 0 / qf.size)
    self.assertEqual(qf.quotient, 8)
    self.assertEqual(qf.remainder, 24)
    self.assertEqual(qf.bits_per_elm, 32)

    # insert 220 distinct keys; the quotient is expected to have grown afterwards
    for num in range(220):
        qf.add(str(num))

    # state after the inserts: threshold unchanged, quotient grown by one
    self.assertEqual(qf.max_load_factor, 0.85)
    self.assertEqual(qf.elements_added, 220)
    self.assertEqual(qf.load_factor, 220 / qf.size)
    self.assertEqual(qf.quotient, 9)
    self.assertEqual(qf.remainder, 23)
    self.assertEqual(qf.bits_per_elm, 32)

0 comments on commit 8ae5eed

Please sign in to comment.