diff --git a/python/outlines_core/fsm/fsm.py b/python/outlines_core/fsm/fsm.py
deleted file mode 100644
index 4daf3c86..00000000
--- a/python/outlines_core/fsm/fsm.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import warnings
-from typing import TYPE_CHECKING, Iterable, NewType, Optional
-
-from outlines_core.fsm.guide import RegexGuide, StopAtEOSGuide
-
-if TYPE_CHECKING:
-    from outlines_core.models.tokenizer import Tokenizer
-
-FSMState = NewType("FSMState", int)
-
-
-class StopAtEosFSM(StopAtEOSGuide):
-    """FSM to generate text until EOS has been generated."""
-
-    def __init__(self, tokenizer: "Tokenizer"):
-        warnings.warn(
-            UserWarning(
-                "The `StopAtTokenFSM` interface is deprecated and will be removed on 2024-06-01. Please use `StopAtEOSGuide` instead."
-            )
-        )
-        super().__init__(tokenizer)
-
-    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
-        next_instruction = self.get_next_instruction(state)
-        return next_instruction.tokens
-
-    def next_state(self, state: FSMState, token_id: int) -> FSMState:
-        return FSMState(self.get_next_state(state, token_id))
-
-
-class RegexFSM(RegexGuide):
-    """FSM to generate text that is in the language of a regular expression."""
-
-    def __init__(self, regex_string: str, tokenizer):
-        warnings.warn(
-            UserWarning(
-                "The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead."
-            )
-        )
-        super().__init__(regex_string, tokenizer)
-
-    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
-        next_instruction = self.get_next_instruction(state)
-        return next_instruction.tokens
-
-    def next_state(self, state: FSMState, token_id: int) -> FSMState:
-        return FSMState(self.get_next_state(state, token_id))
diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py
deleted file mode 100644
index aeb7060c..00000000
--- a/tests/fsm/test_fsm.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import pytest
-from outlines_core.fsm.fsm import RegexFSM, StopAtEosFSM
-
-
-def assert_expected_tensor_ids(tensor, ids):
-    assert len(tensor) == len(ids)
-    norm_tensor = sorted(map(int, tensor))
-    norm_ids = sorted(map(int, tensor))
-    assert norm_tensor == norm_ids, (norm_tensor, norm_ids)
-
-
-def test_stop_at_eos():
-    class MockTokenizer:
-        vocabulary = {"a": 1, "eos": 2}
-        eos_token_id = 2
-
-    with pytest.warns(UserWarning):
-        fsm = StopAtEosFSM(MockTokenizer())
-
-    assert fsm.allowed_token_ids(fsm.start_state) is None
-    assert fsm.allowed_token_ids(fsm.final_state) == [2]
-    assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
-    assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
-    assert fsm.is_final_state(fsm.start_state) is False
-    assert fsm.is_final_state(fsm.final_state) is True
-
-
-def test_regex_vocabulary_error():
-    class MockTokenizer:
-        vocabulary = {"a": 1}
-        special_tokens = {"eos"}
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return token
-
-    regex_str = "[1-9]"
-
-    with pytest.raises(ValueError, match="The vocabulary"):
-        RegexFSM(regex_str, MockTokenizer())
-
-
-def test_regex():
-    class MockTokenizer:
-        vocabulary = {"1": 1, "a": 2, "eos": 3}
-        special_tokens = {"eos"}
-        eos_token_id = 3
-
-        def convert_token_to_string(self, token):
-            return token
-
-    regex_str = "[1-9]"
-    tokenizer = MockTokenizer()
-
-    with pytest.warns(UserWarning):
-        fsm = RegexFSM(regex_str, tokenizer)
-
-    assert fsm.states_to_token_maps == {0: {1: 1}}
-    assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1])
-    assert fsm.next_state(state=0, token_id=1) == 1
-    assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1
-
-    assert fsm.is_final_state(0) is False
-
-    for state in fsm.final_states:
-        assert fsm.is_final_state(state) is True
-
-
-def test_regex_final_state():
-    """Make sure that the FSM stays in the final state as we keep generating"""
-
-    class MockTokenizer:
-        vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104}
-        special_tokens = {"eos"}
-        eos_token_id = 104
-
-        def convert_token_to_string(self, token):
-            return token
-
-    regex_str = r"`\n(\.\n)?`\n"
-    tokenizer = MockTokenizer()
-
-    with pytest.warns(UserWarning):
-        fsm = RegexFSM(regex_str, tokenizer)
-
-    state = fsm.next_state(state=4, token_id=103)
-    assert state == 5
-    assert fsm.is_final_state(state)
-
-    state = fsm.next_state(state=5, token_id=103)
-    assert fsm.is_final_state(state)
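
Migration note: callers still importing the deleted shims can switch directly to the Guide API the shims forwarded to. The sketch below is not authoritative; it uses only the calls visible in this diff (RegexGuide(regex_string, tokenizer), get_next_instruction(state).tokens, get_next_state(state, token_id), is_final_state(state)), and the MockTokenizer stub and the use of 0 as the initial state are taken from the deleted test_regex, so treat the exact state numbering as an assumption rather than a guarantee.

from outlines_core.fsm.guide import RegexGuide


class MockTokenizer:
    # Mirrors the tokenizer stub from the deleted test_regex.
    vocabulary = {"1": 1, "a": 2, "eos": 3}
    special_tokens = {"eos"}
    eos_token_id = 3

    def convert_token_to_string(self, token):
        return token


tokenizer = MockTokenizer()
guide = RegexGuide("[1-9]", tokenizer)  # no deprecation warning, unlike RegexFSM

# RegexFSM.allowed_token_ids(state) was a thin wrapper around
# get_next_instruction(state).tokens; 0 as the start state follows the
# deleted test, which checked states_to_token_maps == {0: {1: 1}}.
tokens = guide.get_next_instruction(0).tokens
assert sorted(map(int, tokens)) == [1]

# RegexFSM.next_state(state, token_id) wrapped get_next_state(state, token_id).
state = guide.get_next_state(0, 1)
assert guide.is_final_state(state)  # "[1-9]" accepts after a single digit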