From df70e737010e64acdbe473eed5e8f4823b887446 Mon Sep 17 00:00:00 2001 From: Sebastian Walter Date: Sat, 28 Sep 2024 14:24:16 +0200 Subject: [PATCH] add and and avoid constraint --- python/text_utils/constraints.py | 78 +++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 6 deletions(-) diff --git a/python/text_utils/constraints.py b/python/text_utils/constraints.py index ad06156..631c67e 100644 --- a/python/text_utils/constraints.py +++ b/python/text_utils/constraints.py @@ -1,3 +1,4 @@ +from functools import reduce import numpy as np from text_utils._internal import grammar @@ -64,9 +65,8 @@ class ContinuationConstraint(Constraint): def __init__( self, cont_index: continuations.MmapContinuationIndex, - prefix: bytes | None = None ): - self.prefix = prefix or bytes() + self.prefix = bytes() self.indices, self.value = cont_index.get(self.prefix) self.cont_index = cont_index @@ -85,10 +85,76 @@ def is_match(self) -> bool: return self.value is not None def clone(self) -> 'ContinuationConstraint': - return ContinuationConstraint( - self.cont_index, - self.prefix - ) + const = ContinuationConstraint(self.cont_index) + const.reset(self.prefix) + return const def get_value(self) -> str | None: return self.value + + +def array_intersection(*arrays: np.ndarray) -> np.ndarray: + """ + Returns the intersection of multiple arrays. + """ + assert len(arrays) > 0, "at least one array required" + return reduce(np.intersect1d, arrays) + + +class AndConstraint(Constraint): + def __init__(self, constraints: list[Constraint]): + assert len(constraints) > 0, "at least one constraint required" + self.constraints = constraints + + def get(self) -> np.ndarray: + return array_intersection(*(c.get() for c in self.constraints)) + + def reset(self, input: bytes | None = None) -> None: + for c in self.constraints: + c.reset(input) + + def next(self, index: int) -> None: + for c in self.constraints: + c.next(index) + + def is_match(self) -> bool: + return all(c.is_match() for c in self.constraints) + + def clone(self) -> 'AndConstraint': + return AndConstraint([c.clone() for c in self.constraints]) + + +class AvoidConstraint(Constraint): + def __init__( + self, + avoid: list[bytes], + continuations: list[bytes], + eos_token_id: int + ): + self.avoid = avoid + self.value = bytes() + self.continuations = continuations + self.eos_token_id = eos_token_id + self.all = np.arange(len(continuations)) + self.all_but_eos = np.delete(self.all, self.eos_token_id) + + def get(self) -> np.ndarray: + return self.all if self.is_match() else self.all_but_eos + + def reset(self, input: bytes | None = None) -> None: + self.value = input or bytes() + + def next(self, index: int) -> None: + self.value += self.continuations[index] + + def is_match(self) -> bool: + return all(avoid != self.value for avoid in self.avoid) + + def clone(self) -> 'AvoidConstraint': + const = AvoidConstraint( + self.avoid, + self.continuations, + self.eos_token_id + ) + const.reset(self.value) + return const