Skip to content

Commit

Permalink
add and and avoid constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Sep 28, 2024
1 parent 270ecc5 commit df70e73
Showing 1 changed file with 72 additions and 6 deletions.
78 changes: 72 additions & 6 deletions python/text_utils/constraints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import reduce
import numpy as np

from text_utils._internal import grammar
Expand Down Expand Up @@ -64,9 +65,8 @@ class ContinuationConstraint(Constraint):
def __init__(
self,
cont_index: continuations.MmapContinuationIndex,
prefix: bytes | None = None
):
self.prefix = prefix or bytes()
self.prefix = bytes()
self.indices, self.value = cont_index.get(self.prefix)
self.cont_index = cont_index

Expand All @@ -85,10 +85,76 @@ def is_match(self) -> bool:
return self.value is not None

def clone(self) -> 'ContinuationConstraint':
return ContinuationConstraint(
self.cont_index,
self.prefix
)
const = ContinuationConstraint(self.cont_index)
const.reset(self.prefix)
return const

def get_value(self) -> str | None:
return self.value


def array_intersection(*arrays: np.ndarray) -> np.ndarray:
"""
Returns the intersection of multiple arrays.
"""
assert len(arrays) > 0, "at least one array required"
return reduce(np.intersect1d, arrays)


class AndConstraint(Constraint):
def __init__(self, constraints: list[Constraint]):
assert len(constraints) > 0, "at least one constraint required"
self.constraints = constraints

def get(self) -> np.ndarray:
return array_intersection(*(c.get() for c in self.constraints))

def reset(self, input: bytes | None = None) -> None:
for c in self.constraints:
c.reset(input)

def next(self, index: int) -> None:
for c in self.constraints:
c.next(index)

def is_match(self) -> bool:
return all(c.is_match() for c in self.constraints)

def clone(self) -> 'AndConstraint':
return AndConstraint([c.clone() for c in self.constraints])


class AvoidConstraint(Constraint):
def __init__(
self,
avoid: list[bytes],
continuations: list[bytes],
eos_token_id: int
):
self.avoid = avoid
self.value = bytes()
self.continuations = continuations
self.eos_token_id = eos_token_id
self.all = np.arange(len(continuations))
self.all_but_eos = np.delete(self.all, self.eos_token_id)

def get(self) -> np.ndarray:
return self.all if self.is_match() else self.all_but_eos

def reset(self, input: bytes | None = None) -> None:
self.value = input or bytes()

def next(self, index: int) -> None:
self.value += self.continuations[index]

def is_match(self) -> bool:
return all(avoid != self.value for avoid in self.avoid)

def clone(self) -> 'AvoidConstraint':
const = AvoidConstraint(
self.avoid,
self.continuations,
self.eos_token_id
)
const.reset(self.value)
return const

0 comments on commit df70e73

Please sign in to comment.