From 4a3bb7e076212ee748ae1abe27f0f514ee507ecc Mon Sep 17 00:00:00 2001 From: erikkaum Date: Thu, 22 Aug 2024 14:28:59 +0200 Subject: [PATCH] with maturin & incorporate latest changes --- .github/workflows/outlines_core_python_ci.yml | 169 --- .github/workflows/python_ci.yml | 323 +++++ .github/workflows/tests.yml | 102 -- Cargo.lock | 185 +-- Cargo.toml | 2 +- bindings/python/.gitignore | 72 ++ bindings/python/Cargo.toml | 19 +- bindings/python/Manifest.in | 9 - .../{benchmarks/__init__.py => README.md} | 0 bindings/python/benchmarks/asv.conf.json | 21 - .../python/benchmarks/bench_json_schema.py | 77 -- .../python/benchmarks/bench_numba_compile.py | 31 - .../python/benchmarks/bench_regex_guide.py | 37 - bindings/python/benchmarks/common.py | 8 - .../python}/outlines_core/__init__.py | 3 +- .../{src => }/outlines_core/fsm/__init__.py | 0 .../python}/outlines_core/fsm/fsm.py | 2 +- .../{src => }/outlines_core/fsm/guide.py | 2 +- .../outlines_core/fsm/json_schema.py | 2 +- .../python}/outlines_core/fsm/regex.py | 5 +- .../python}/outlines_core/fsm/types.py | 2 +- .../outlines_core/integrations/utils.py | 2 +- .../python}/outlines_core/models/__init__.py | 2 +- .../python}/outlines_core/models/tokenizer.py | 2 +- .../outlines_core/models/transformers.py | 2 +- .../python/{src => }/outlines_core/py.typed | 0 bindings/python/pyproject.toml | 127 +- bindings/python/rust/lib.rs | 12 - bindings/python/setup.cfg | 11 - bindings/python/src/{regex.rs => lib.rs} | 238 +--- bindings/python/src/outlines_core/__init__.py | 5 - bindings/python/src/outlines_core/fsm/fsm.py | 47 - .../python/src/outlines_core/fsm/regex.py | 0 .../python/src/outlines_core/fsm/types.py | 81 -- .../src/outlines_core/integrations/utils.py | 41 - .../src/outlines_core/models/__init__.py | 13 - .../src/outlines_core/models/tokenizer.py | 31 - .../src/outlines_core/models/transformers.py | 474 -------- bindings/python/tests/__init__.py | 0 bindings/python/tests/fsm/partial_python.lark | 314 ----- bindings/python/tests/fsm/test_fsm.py | 2 +- bindings/python/tests/fsm/test_regex.py | 373 ++---- justfile | 9 +- outlines-core/Cargo.lock | 7 - outlines-core/src/lib.rs | 5 +- outlines-core/src/regex.rs | 143 +++ pyproject.toml | 141 --- python/outlines_core/fsm/__init__.py | 0 python/outlines_core/fsm/guide.py | 295 ----- python/outlines_core/fsm/json_schema.py | 519 -------- python/outlines_core/py.typed | 0 setup.py | 20 - src/lib.rs | 23 - tests/fsm/test_fsm.py | 91 -- tests/fsm/test_guide.py | 189 --- tests/fsm/test_json_schema.py | 1040 ----------------- tests/fsm/test_regex.py | 524 --------- tests/fsm/test_types.py | 28 - tests/models/test_tokenizer.py | 7 - tests/models/test_transformers.py | 116 -- 60 files changed, 724 insertions(+), 5281 deletions(-) delete mode 100644 .github/workflows/outlines_core_python_ci.yml create mode 100644 .github/workflows/python_ci.yml delete mode 100644 .github/workflows/tests.yml create mode 100644 bindings/python/.gitignore delete mode 100644 bindings/python/Manifest.in rename bindings/python/{benchmarks/__init__.py => README.md} (100%) delete mode 100644 bindings/python/benchmarks/asv.conf.json delete mode 100644 bindings/python/benchmarks/bench_json_schema.py delete mode 100644 bindings/python/benchmarks/bench_numba_compile.py delete mode 100644 bindings/python/benchmarks/bench_regex_guide.py delete mode 100644 bindings/python/benchmarks/common.py rename {python => bindings/python}/outlines_core/__init__.py (53%) rename bindings/python/{src => }/outlines_core/fsm/__init__.py (100%) rename 
{python => bindings/python}/outlines_core/fsm/fsm.py (96%) rename bindings/python/{src => }/outlines_core/fsm/guide.py (99%) rename bindings/python/{src => }/outlines_core/fsm/json_schema.py (99%) rename {python => bindings/python}/outlines_core/fsm/regex.py (99%) rename {python => bindings/python}/outlines_core/fsm/types.py (99%) rename {python => bindings/python}/outlines_core/integrations/utils.py (98%) rename {python => bindings/python}/outlines_core/models/__init__.py (90%) rename {python => bindings/python}/outlines_core/models/tokenizer.py (98%) rename {python => bindings/python}/outlines_core/models/transformers.py (99%) rename bindings/python/{src => }/outlines_core/py.typed (100%) delete mode 100644 bindings/python/rust/lib.rs delete mode 100644 bindings/python/setup.cfg rename bindings/python/src/{regex.rs => lib.rs} (50%) delete mode 100644 bindings/python/src/outlines_core/__init__.py delete mode 100644 bindings/python/src/outlines_core/fsm/fsm.py delete mode 100644 bindings/python/src/outlines_core/fsm/regex.py delete mode 100644 bindings/python/src/outlines_core/fsm/types.py delete mode 100644 bindings/python/src/outlines_core/integrations/utils.py delete mode 100644 bindings/python/src/outlines_core/models/__init__.py delete mode 100644 bindings/python/src/outlines_core/models/tokenizer.py delete mode 100644 bindings/python/src/outlines_core/models/transformers.py delete mode 100644 bindings/python/tests/__init__.py delete mode 100644 bindings/python/tests/fsm/partial_python.lark delete mode 100644 outlines-core/Cargo.lock create mode 100644 outlines-core/src/regex.rs delete mode 100644 pyproject.toml delete mode 100644 python/outlines_core/fsm/__init__.py delete mode 100644 python/outlines_core/fsm/guide.py delete mode 100644 python/outlines_core/fsm/json_schema.py delete mode 100644 python/outlines_core/py.typed delete mode 100644 setup.py delete mode 100644 src/lib.rs delete mode 100644 tests/fsm/test_fsm.py delete mode 100644 tests/fsm/test_guide.py delete mode 100644 tests/fsm/test_json_schema.py delete mode 100644 tests/fsm/test_regex.py delete mode 100644 tests/fsm/test_types.py delete mode 100644 tests/models/test_tokenizer.py delete mode 100644 tests/models/test_transformers.py diff --git a/.github/workflows/outlines_core_python_ci.yml b/.github/workflows/outlines_core_python_ci.yml deleted file mode 100644 index 47b02406..00000000 --- a/.github/workflows/outlines_core_python_ci.yml +++ /dev/null @@ -1,169 +0,0 @@ -# This file is autogenerated by maturin v1.7.0 -# To update, run -# -# maturin generate-ci github -# -# name: CI - -# on: -# push: -# branches: -# - main -# - master -# tags: -# - '*' -# pull_request: -# workflow_dispatch: - -# permissions: -# contents: read - -# jobs: -# linux: -# runs-on: ${{ matrix.platform.runner }} -# strategy: -# matrix: -# platform: -# - runner: ubuntu-latest -# target: x86_64 -# - runner: ubuntu-latest -# target: x86 -# - runner: ubuntu-latest -# target: aarch64 -# - runner: ubuntu-latest -# target: armv7 -# - runner: ubuntu-latest -# target: s390x -# - runner: ubuntu-latest -# target: ppc64le -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v5 -# with: -# python-version: 3.x -# - name: Build wheels -# uses: PyO3/maturin-action@v1 -# with: -# target: ${{ matrix.platform.target }} -# args: --release --out dist --find-interpreter -# sccache: 'true' -# manylinux: auto -# - name: Upload wheels -# uses: actions/upload-artifact@v4 -# with: -# name: wheels-linux-${{ matrix.platform.target }} -# path: dist - -# 
musllinux: -# runs-on: ${{ matrix.platform.runner }} -# strategy: -# matrix: -# platform: -# - runner: ubuntu-latest -# target: x86_64 -# - runner: ubuntu-latest -# target: x86 -# - runner: ubuntu-latest -# target: aarch64 -# - runner: ubuntu-latest -# target: armv7 -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v5 -# with: -# python-version: 3.x -# - name: Build wheels -# uses: PyO3/maturin-action@v1 -# with: -# target: ${{ matrix.platform.target }} -# args: --release --out dist --find-interpreter -# sccache: 'true' -# manylinux: musllinux_1_2 -# - name: Upload wheels -# uses: actions/upload-artifact@v4 -# with: -# name: wheels-musllinux-${{ matrix.platform.target }} -# path: dist - -# windows: -# runs-on: ${{ matrix.platform.runner }} -# strategy: -# matrix: -# platform: -# - runner: windows-latest -# target: x64 -# - runner: windows-latest -# target: x86 -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v5 -# with: -# python-version: 3.x -# architecture: ${{ matrix.platform.target }} -# - name: Build wheels -# uses: PyO3/maturin-action@v1 -# with: -# target: ${{ matrix.platform.target }} -# args: --release --out dist --find-interpreter -# sccache: 'true' -# - name: Upload wheels -# uses: actions/upload-artifact@v4 -# with: -# name: wheels-windows-${{ matrix.platform.target }} -# path: dist - -# macos: -# runs-on: ${{ matrix.platform.runner }} -# strategy: -# matrix: -# platform: -# - runner: macos-12 -# target: x86_64 -# - runner: macos-14 -# target: aarch64 -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v5 -# with: -# python-version: 3.x -# - name: Build wheels -# uses: PyO3/maturin-action@v1 -# with: -# target: ${{ matrix.platform.target }} -# args: --release --out dist --find-interpreter -# sccache: 'true' -# - name: Upload wheels -# uses: actions/upload-artifact@v4 -# with: -# name: wheels-macos-${{ matrix.platform.target }} -# path: dist - -# sdist: -# runs-on: ubuntu-latest -# steps: -# - uses: actions/checkout@v4 -# - name: Build sdist -# uses: PyO3/maturin-action@v1 -# with: -# command: sdist -# args: --out dist -# - name: Upload sdist -# uses: actions/upload-artifact@v4 -# with: -# name: wheels-sdist -# path: dist - -# release: -# name: Release -# runs-on: ubuntu-latest -# if: "startsWith(github.ref, 'refs/tags/')" -# needs: [linux, musllinux, windows, macos, sdist] -# steps: -# - uses: actions/download-artifact@v4 -# - name: Publish to PyPI -# uses: PyO3/maturin-action@v1 -# env: -# MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} -# with: -# command: upload -# args: --non-interactive --skip-existing wheels-*/* diff --git a/.github/workflows/python_ci.yml b/.github/workflows/python_ci.yml new file mode 100644 index 00000000..c64a793b --- /dev/null +++ b/.github/workflows/python_ci.yml @@ -0,0 +1,323 @@ +name: CI & Tests + +on: + pull_request: + branches: [main] + push: + branches: [main] + +permissions: + contents: read + +jobs: + style: + name: Check the code style + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + - uses: pre-commit/action@v3.0.0 + + tests: + name: Run the tests + runs-on: ubuntu-latest + defaults: + run: + working-directory: bindings/python + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Set up test environment + 
run: | + python3 -m venv .venv + source .venv/bin/activate + pip install outlines-core --find-links dist --force-reinstall + pip install pytest + - name: Run tests + run: | + pytest --cov=src/outlines_core + - name: Upload coverage data + uses: actions/upload-artifact@v3 + with: + name: coverage-data + path: .coverage* + if-no-files-found: ignore + + coverage: + name: Combine & check coverage. + needs: tests + runs-on: ubuntu-latest + defaults: + run: + working-directory: bindings/python + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v4 + with: + cache: pip + python-version: "3.11" + + - name: Set up environment + run: | + pip install --upgrade "coverage[toml]>=5.1" diff-cover + + - uses: actions/download-artifact@v3 + with: + name: coverage-data + + - name: Fetch master for coverage diff + run: | + git fetch --no-tags --prune origin main + + - name: Combine coverage & fail if it's <100%. + run: | + # Combine coverage files (not needed now, but maybe later) + # python -m coverage combine + + # Produce an html report with absolute coverage information + cd bindings/python && python -m coverage html --skip-covered --skip-empty + + # Report relative coverage and write to the workflow's summary + cd bindings/python && python -m coverage xml + diff-cover coverage.xml --markdown-report=coverage.md --fail-under=100 || (cat coverage.md >> $GITHUB_STEP_SUMMARY && exit 1) + + - name: Upload HTML report if check failed. + uses: actions/upload-artifact@v3 + with: + name: html-report + path: htmlcov + if: ${{ failure() }} + + linux: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: ubuntu-latest + target: x86_64 + - runner: ubuntu-latest + target: x86 + - runner: ubuntu-latest + target: aarch64 + - runner: ubuntu-latest + target: armv7 + - runner: ubuntu-latest + target: s390x + - runner: ubuntu-latest + target: ppc64le + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-linux-${{ matrix.platform.target }} + path: dist + - name: pytest + if: ${{ startsWith(matrix.platform.target, 'x86_64') }} + shell: bash + run: | + set -e + python3 -m venv .venv + source .venv/bin/activate + pip install outlines-core --find-links dist --force-reinstall + pip install pytest + pytest + - name: pytest + if: ${{ !startsWith(matrix.platform.target, 'x86') && matrix.platform.target != 'ppc64' }} + uses: uraimo/run-on-arch-action@v2 + with: + arch: ${{ matrix.platform.target }} + distro: ubuntu22.04 + githubToken: ${{ github.token }} + install: | + apt-get update + apt-get install -y --no-install-recommends python3 python3-pip + pip3 install -U pip pytest + run: | + set -e + pip3 install outlines-core --find-links dist --force-reinstall + pytest + + musllinux: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: ubuntu-latest + target: x86_64 + - runner: ubuntu-latest + target: x86 + - runner: ubuntu-latest + target: aarch64 + - runner: ubuntu-latest + target: armv7 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out
dist --find-interpreter + sccache: 'true' + manylinux: musllinux_1_2 + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-musllinux-${{ matrix.platform.target }} + path: dist + - name: pytest + if: ${{ startsWith(matrix.platform.target, 'x86_64') }} + uses: addnab/docker-run-action@v3 + with: + image: alpine:latest + options: -v ${{ github.workspace }}:/io -w /io + run: | + set -e + apk add py3-pip py3-virtualenv + python3 -m virtualenv .venv + source .venv/bin/activate + pip install outlines-core --no-index --find-links dist --force-reinstall + pip install pytest + pytest + - name: pytest + if: ${{ !startsWith(matrix.platform.target, 'x86') }} + uses: uraimo/run-on-arch-action@v2 + with: + arch: ${{ matrix.platform.target }} + distro: alpine_latest + githubToken: ${{ github.token }} + install: | + apk add py3-virtualenv + run: | + set -e + python3 -m virtualenv .venv + source .venv/bin/activate + pip install pytest + pip install outlines-core --find-links dist --force-reinstall + pytest + + windows: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: windows-latest + target: x64 + - runner: windows-latest + target: x86 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + architecture: ${{ matrix.platform.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-windows-${{ matrix.platform.target }} + path: dist + - name: pytest + if: ${{ !startsWith(matrix.platform.target, 'aarch64') }} + shell: bash + run: | + set -e + python3 -m venv .venv + source .venv/Scripts/activate + pip install outlines-core --find-links dist --force-reinstall + pip install pytest + pytest + + macos: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: macos-12 + target: x86_64 + - runner: macos-14 + target: aarch64 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-macos-${{ matrix.platform.target }} + path: dist + - name: pytest + run: | + set -e + python3 -m venv .venv + source .venv/bin/activate + pip install outlines-core --find-links dist --force-reinstall + pip install pytest + pytest + + sdist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v4 + with: + name: wheels-sdist + path: dist + +# release: +# name: Release +# runs-on: ubuntu-latest +# if: "startsWith(github.ref, 'refs/tags/')" +# needs: [linux, musllinux, windows, macos, sdist] +# steps: +# - uses: actions/download-artifact@v4 +# - name: Publish to PyPI +# uses: PyO3/maturin-action@v1 +# env: +# PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} +# with: +# command: upload +# args: --non-interactive --skip-existing wheels-*/* \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml deleted file mode 100644 index 821a17fa..00000000 --- a/.github/workflows/tests.yml +++ /dev/null @@ -1,102 +0,0 @@ -name: 
Tests - -on: - pull_request: - branches: [main] - push: - branches: [main] - -jobs: - style: - name: Check the code style - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: "3.10" - - uses: pre-commit/action@v3.0.0 - - tests: - name: Run the tests - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10"] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Set up test environment - run: | - python -m pip install --upgrade pip - cd bindings/python && pip install '.[test]' - - name: Run tests - run: | -<<<<<<< HEAD - pytest --cov=outlines_core -======= - cd bindings/python && pytest --cov=src/outlines_core ->>>>>>> 39def93 (changes paths for tests) - - name: Upload coverage data - uses: actions/upload-artifact@v3 - with: - name: coverage-data - path: .coverage* - if-no-files-found: ignore - - coverage: - name: Combine & check coverage. - needs: tests - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - uses: actions/setup-python@v4 - with: - cache: pip - python-version: "3.11" - - - name: Set up environment - run: | - cd bindgins/python && pip install --upgrade "coverage[toml]>=5.1" diff-cover - - - uses: actions/download-artifact@v3 - with: - name: coverage-data - - - name: Fetch master for coverage diff - run: | - git fetch --no-tags --prune origin main - - - name: Combine coverage & fail if it's <100%. - run: | - # Combine coverage files (not needed now, but maybe later) - # python -m coverage combine - - # Produce an html report with absolute coverage information - cd bindgins/python && python -m coverage html --skip-covered --skip-empty - - # Report relative coverage and write to the workflow's summary - cd bindings/python && python -m coverage xml - diff-cover coverage.xml --markdown-report=coverage.md --fail-under=100 || (cat coverage.md >> $GITHUB_STEP_SUMMARY && exit 1) - - - name: Upload HTML report if check failed. 
- uses: actions/upload-artifact@v3 - with: - name: html-report - path: htmlcov - if: ${{ failure() }} - - build-wheel: - name: Build Wheel and Test SDist - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Build SDist and Wheel - run: ./.github/scripts/build_sdist_and_wheel.sh diff --git a/Cargo.lock b/Cargo.lock index 52fc1582..baa0c623 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3,176 +3,13 @@ version = 3 [[package]] -name = "autocfg" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "heck" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" - -[[package]] -name = "indoc" -version = "2.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" - -[[package]] -name = "libc" -version = "0.2.158" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" - -[[package]] -name = "memoffset" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "outlines-core-rs" -version = "0.1.0" +name = "_outlines_core_rs" +version = "0.1.0-dev.0" dependencies = [ + "outlines-core", "pyo3", ] -[[package]] -name = "portable-atomic" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" - -[[package]] -name = "proc-macro2" -version = "1.0.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "pyo3" -version = "0.22.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "once_cell", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.22.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.22.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.22.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" -dependencies = [ - 
"proc-macro2", - "pyo3-macros-backend", - "quote", - "syn", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.22.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" -dependencies = [ - "heck", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "syn" -version = "2.0.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "target-lexicon" -version = "0.12.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "unindent" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - [[package]] name = "autocfg" version = "1.3.0" @@ -199,9 +36,9 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" [[package]] name = "memoffset" @@ -300,14 +137,6 @@ dependencies = [ "syn", ] -[[package]] -name = "python-bindings" -version = "0.1.0" -dependencies = [ - "outlines-core", - "pyo3", -] - [[package]] name = "quote" version = "1.0.36" @@ -319,9 +148,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.74" +version = "2.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fceb41e3d546d0bd83421d3409b1460cc7444cd389341a4c880fe7a042cb3d7" +checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 34af0dec..2a0981a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,4 @@ "bindings/python", ] - resolver = "2" + resolver = "2" \ No newline at end of file diff --git a/bindings/python/.gitignore b/bindings/python/.gitignore new file mode 100644 index 00000000..c8f04429 --- /dev/null +++ b/bindings/python/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project 
+.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index a97f9dda..f7f2cc3c 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -1,18 +1,13 @@ [package] -name = "python-bindings" -version = "0.1.0" +name = "_outlines_core_rs" +version = "0.1.0-dev.0" edition = "2021" -[dependencies] -pyo3 = "0.22.0" -[profile.release-lto] -inherits = "release" -lto = true - +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] -name = "_lib" +name = "_outlines_core_rs" crate-type = ["cdylib"] -path = "rust/lib.rs" -[dependencies.outlines-core] -path = "../../outlines-core" +[dependencies] +pyo3 = "0.22.0" +outlines-core = { path = "../../outlines-core" } \ No newline at end of file diff --git a/bindings/python/Manifest.in b/bindings/python/Manifest.in deleted file mode 100644 index 02bee506..00000000 --- a/bindings/python/Manifest.in +++ /dev/null @@ -1,9 +0,0 @@ -graft src -graft rust -include Cargo.toml - -global-exclude */__pycache__/* -global-exclude *.pyc - -recursive-include outlines-core-lib * -recursive-exclude outlines-core/target * diff --git a/bindings/python/benchmarks/__init__.py b/bindings/python/README.md similarity index 100% rename from bindings/python/benchmarks/__init__.py rename to bindings/python/README.md diff --git a/bindings/python/benchmarks/asv.conf.json b/bindings/python/benchmarks/asv.conf.json deleted file mode 100644 index 3959e2f0..00000000 --- a/bindings/python/benchmarks/asv.conf.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "version": 1, - "project": "Outlines-core", - "project_url": "https://outlines-dev.github.io/outlines-core/", - "repo": "..", - "branches": [ - "HEAD" - ], - "build_command": [ - "pip install setuptools_rust", - "python -mpip install .[test]", - "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}", - ], - "environment_type": "virtualenv", - "show_commit_url": "https://github.com/outlines-dev/outlines-core/commit/", - "benchmark_dir": ".", - "env_dir": "env", - "results_dir": "results", - "html_dir": "html", - "build_cache_size": 8 -} diff --git a/bindings/python/benchmarks/bench_json_schema.py b/bindings/python/benchmarks/bench_json_schema.py deleted file mode 100644 index 47578cd3..00000000 --- a/bindings/python/benchmarks/bench_json_schema.py +++ /dev/null @@ -1,77 +0,0 @@ -from outlines_core.fsm.guide import RegexGuide -from outlines_core.fsm.json_schema import build_regex_from_schema - -from .common import setup_tokenizer # noqa: E402 - -simple_schema = """{ - "$defs": { - "Armor": { - "enum": ["leather", "chainmail", "plate"], - "title": "Armor", - "type": "string" - } - }, - "properties": { - "name": {"maxLength": 10, "title": "Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "armor": {"$ref": "#/$defs/Armor"}, - "strength": {"title": "Strength", "type": "integer"}\ - }, - "required": ["name", "age", "armor", "strength"], - "title": "Character", - "type": "object" - }""" - - -complex_schema = """{ - "$schema": "http://json-schema.org/draft-04/schema#", - "title": "Schema for a recording", - "type": "object", - "definitions": { - "artist": { - "type": "object", - "properties": { - "id": {"type": "number"}, - "name": {"type": "string"}, - "functions": { - "type": "array", - "items": {"type": 
"string"} - } - }, - "required": ["id", "name", "functions"] - } - }, - "properties": { - "id": {"type": "number"}, - "work": { - "type": "object", - "properties": { - "id": {"type": "number"}, - "name": {"type": "string"}, - "composer": {"$ref": "#/definitions/artist"} - } - }, - "recording_artists": { - "type": "array", - "items": {"$ref": "#/definitions/artist"} - } - }, - "required": ["id", "work", "recording_artists"] -}""" - -schemas = dict(simple_schema=simple_schema, complex_schema=complex_schema) - - -class JsonSchemaBenchmark: - params = schemas.keys() - - def setup(self, schema_name): - self.tokenizer = setup_tokenizer() - self.schema = schemas[schema_name] - - def time_json_schema_to_regex(self, schema_name): - build_regex_from_schema(self.schema) - - def time_json_schema_to_fsm(self, schema_name): - regex = build_regex_from_schema(self.schema) - RegexGuide(regex, self.tokenizer) diff --git a/bindings/python/benchmarks/bench_numba_compile.py b/bindings/python/benchmarks/bench_numba_compile.py deleted file mode 100644 index 6e479294..00000000 --- a/bindings/python/benchmarks/bench_numba_compile.py +++ /dev/null @@ -1,31 +0,0 @@ -import importlib - -import interegular -import numba -from outlines_core.fsm import regex - -from .common import setup_tokenizer - - -class NumbaCompileBenchmark: - def setup(self): - self.tokenizer = setup_tokenizer() - self.regex = regex - original_njit = numba.njit - - def mock_njit(*args, **kwargs): - kwargs["cache"] = False - return original_njit(*args, **kwargs) - - self.original_njit = original_njit - numba.njit = mock_njit - importlib.reload(self.regex) - self.regex_pattern, _ = self.regex.make_deterministic_fsm( - interegular.parse_pattern("a").to_fsm().reduce() - ) - - def teardown(self): - numba.njit = self.original_njit - - def time_compile_numba(self): - self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer) diff --git a/bindings/python/benchmarks/bench_regex_guide.py b/bindings/python/benchmarks/bench_regex_guide.py deleted file mode 100644 index 287d5f51..00000000 --- a/bindings/python/benchmarks/bench_regex_guide.py +++ /dev/null @@ -1,37 +0,0 @@ -from outlines_core.fsm.guide import RegexGuide - -from .common import setup_tokenizer - -regex_samples = { - "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", - "complex_phone": "\\+?\\d{1,4}?[-.\\s]?\\(?\\d{1,3}?\\)?[-.\\s]?\\d{1,4}[-.\\s]?\\d{1,4}[-.\\s]?\\d{1,9}", - "simple_phone": "\\+?[1-9][0-9]{7,14}", - "date": r"([1-9]|0[1-9]|1[0-9]|2[0-9]|3[0-1])(\.|-|/)([1-9]|0[1-9]|1[0-2])(\.|-|/)([0-9][0-9]|19[0-9][0-9]|20[0-9][0-9])|([0-9][0-9]|19[0-9][0-9]|20[0-9][0-9])(\.|-|/)([1-9]|0[1-9]|1[0-2])(\.|-|/)([1-9]|0[1-9]|1[0-9]|2[0-9]|3[0-1])", - "time": r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?", - "ip": r"(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)", - "url": r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?", - "ssn": r"\d{3}-\d{2}-\d{4}", - "complex_span_constrained_relation_extraction": "(['\"\\ ,]?((?:of|resulting|case|which|cultures|a|core|extreme|selflessness|spiritual|various|However|both|vary|in|other|secular|the|religious|among|moral|and|It|object|worldviews|altruism|traditional|material|aspect|or|life|beings|virtue|is|however|opposite|concern|an|practice|it|for|s|quality|religions|In|Altruism|animals|happiness|many|become|principle|human|selfishness|may|synonym)['\"\\ ,]?)+['\"\\ 
,]?\\s\\|\\s([^|\\(\\)\n]{1,})\\s\\|\\s['\"\\ ,]?((?:of|resulting|case|which|cultures|a|core|extreme|selflessness|spiritual|various|However|both|vary|in|other|secular|the|religious|among|moral|and|It|object|worldviews|altruism|traditional|material|aspect|or|life|beings|virtue|is|however|opposite|concern|an|practice|it|for|s|quality|religions|In|Altruism|animals|happiness|many|become|principle|human|selfishness|may|synonym)['\"\\ ,]?)+['\"\\ ,]?(\\s\\|\\s\\(([^|\\(\\)\n]{1,})\\s\\|\\s([^|\\(\\)\n]{1,})\\))*\\n)*", -} - - -class RegexGuideBenchmark: - params = regex_samples.keys() - - def setup(self, pattern_name): - self.tokenizer = setup_tokenizer() - self.pattern = regex_samples[pattern_name] - - def time_regex_to_guide(self, pattern_name): - RegexGuide(self.pattern, self.tokenizer) - - -class MemoryRegexGuideBenchmark: - params = ["simple_phone", "complex_span_constrained_relation_extraction"] - - def setup(self, pattern_name): - self.tokenizer = setup_tokenizer() - self.pattern = regex_samples[pattern_name] - - def peakmem_regex_to_guide(self, pattern_name): - RegexGuide(self.pattern, self.tokenizer) diff --git a/bindings/python/benchmarks/common.py b/bindings/python/benchmarks/common.py deleted file mode 100644 index d4dbec85..00000000 --- a/bindings/python/benchmarks/common.py +++ /dev/null @@ -1,8 +0,0 @@ -from outlines_core.fsm.guide import RegexGuide -from outlines_core.models.transformers import TransformerTokenizer -from transformers import AutoTokenizer - - -def setup_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("gpt2") - return TransformerTokenizer(tokenizer) diff --git a/python/outlines_core/__init__.py b/bindings/python/outlines_core/__init__.py similarity index 53% rename from python/outlines_core/__init__.py rename to bindings/python/outlines_core/__init__.py index ed2d5a9c..f4c0eadc 100644 --- a/python/outlines_core/__init__.py +++ b/bindings/python/outlines_core/__init__.py @@ -1,4 +1,5 @@ """Outlines is a Generative Model Programming Framework.""" import outlines_core.models +from _outlines_core_rs import __version__ -__all__ = ["models"] +__all__ = ["models", "__version__"] \ No newline at end of file diff --git a/bindings/python/src/outlines_core/fsm/__init__.py b/bindings/python/outlines_core/fsm/__init__.py similarity index 100% rename from bindings/python/src/outlines_core/fsm/__init__.py rename to bindings/python/outlines_core/fsm/__init__.py diff --git a/python/outlines_core/fsm/fsm.py b/bindings/python/outlines_core/fsm/fsm.py similarity index 96% rename from python/outlines_core/fsm/fsm.py rename to bindings/python/outlines_core/fsm/fsm.py index 4daf3c86..e28b86ab 100644 --- a/python/outlines_core/fsm/fsm.py +++ b/bindings/python/outlines_core/fsm/fsm.py @@ -44,4 +44,4 @@ def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: return next_instruction.tokens def next_state(self, state: FSMState, token_id: int) -> FSMState: - return FSMState(self.get_next_state(state, token_id)) + return FSMState(self.get_next_state(state, token_id)) \ No newline at end of file diff --git a/bindings/python/src/outlines_core/fsm/guide.py b/bindings/python/outlines_core/fsm/guide.py similarity index 99% rename from bindings/python/src/outlines_core/fsm/guide.py rename to bindings/python/outlines_core/fsm/guide.py index 8f7250ef..774f1d54 100644 --- a/bindings/python/src/outlines_core/fsm/guide.py +++ b/bindings/python/outlines_core/fsm/guide.py @@ -292,4 +292,4 @@ def is_final_state(self, state: int) -> bool: return state in self.final_states def 
copy(self): - return self + return self \ No newline at end of file diff --git a/bindings/python/src/outlines_core/fsm/json_schema.py b/bindings/python/outlines_core/fsm/json_schema.py similarity index 99% rename from bindings/python/src/outlines_core/fsm/json_schema.py rename to bindings/python/outlines_core/fsm/json_schema.py index b2924300..4145afa0 100644 --- a/bindings/python/src/outlines_core/fsm/json_schema.py +++ b/bindings/python/outlines_core/fsm/json_schema.py @@ -516,4 +516,4 @@ def get_schema_from_signature(fn: Callable) -> str: ) model = create_model(fn_name, **arguments) - return model.model_json_schema() + return model.model_json_schema() \ No newline at end of file diff --git a/python/outlines_core/fsm/regex.py b/bindings/python/outlines_core/fsm/regex.py similarity index 99% rename from python/outlines_core/fsm/regex.py rename to bindings/python/outlines_core/fsm/regex.py index 834b5880..20c42d38 100644 --- a/python/outlines_core/fsm/regex.py +++ b/bindings/python/outlines_core/fsm/regex.py @@ -22,8 +22,7 @@ _AnythingElseCls, anything_else, ) - -from .outlines_core_rs import ( # noqa: F401 +from _outlines_core_rs import ( # noqa: F401 FSMInfo, _walk_fsm, create_fsm_index_end_to_end, @@ -480,4 +479,4 @@ def create_fsm_index_tokenizer( if subset is not None: subset[tokenizer.eos_token_id] = state - return states_to_token_subsets, empty_token_ids + return states_to_token_subsets, empty_token_ids \ No newline at end of file diff --git a/python/outlines_core/fsm/types.py b/bindings/python/outlines_core/fsm/types.py similarity index 99% rename from python/outlines_core/fsm/types.py rename to bindings/python/outlines_core/fsm/types.py index 5695dee0..860202e4 100644 --- a/python/outlines_core/fsm/types.py +++ b/bindings/python/outlines_core/fsm/types.py @@ -78,4 +78,4 @@ def datetime_format_fn(sequence: str) -> datetime.datetime: else: raise NotImplementedError( f"The Python type {python_type} is not supported. Please open an issue." 
- ) + ) \ No newline at end of file diff --git a/python/outlines_core/integrations/utils.py b/bindings/python/outlines_core/integrations/utils.py similarity index 98% rename from python/outlines_core/integrations/utils.py rename to bindings/python/outlines_core/integrations/utils.py index 67c70685..500a27cd 100644 --- a/python/outlines_core/integrations/utils.py +++ b/bindings/python/outlines_core/integrations/utils.py @@ -38,4 +38,4 @@ def convert_token_to_string(token: Union[str, bytes]) -> str: tokenizer.convert_token_to_string = convert_token_to_string - return tokenizer + return tokenizer \ No newline at end of file diff --git a/python/outlines_core/models/__init__.py b/bindings/python/outlines_core/models/__init__.py similarity index 90% rename from python/outlines_core/models/__init__.py rename to bindings/python/outlines_core/models/__init__.py index c6277f62..7b1afd13 100644 --- a/python/outlines_core/models/__init__.py +++ b/bindings/python/outlines_core/models/__init__.py @@ -10,4 +10,4 @@ from .transformers import Transformers, TransformerTokenizer, mamba, transformers -LogitsGenerator = Union[Transformers] +LogitsGenerator = Union[Transformers] \ No newline at end of file diff --git a/python/outlines_core/models/tokenizer.py b/bindings/python/outlines_core/models/tokenizer.py similarity index 98% rename from python/outlines_core/models/tokenizer.py rename to bindings/python/outlines_core/models/tokenizer.py index 1a5708d8..addd5f63 100644 --- a/python/outlines_core/models/tokenizer.py +++ b/bindings/python/outlines_core/models/tokenizer.py @@ -28,4 +28,4 @@ def convert_token_to_string(self, token: str) -> str: represented by the special characted `Δ `. This prevents matching a raw token that includes `Δ ` with a string. """ - ... + ... 
\ No newline at end of file diff --git a/python/outlines_core/models/transformers.py b/bindings/python/outlines_core/models/transformers.py similarity index 99% rename from python/outlines_core/models/transformers.py rename to bindings/python/outlines_core/models/transformers.py index bc5ba7b6..e9cecb7b 100644 --- a/python/outlines_core/models/transformers.py +++ b/bindings/python/outlines_core/models/transformers.py @@ -471,4 +471,4 @@ def mamba( model_kwargs=model_kwargs, tokenizer_kwargs=tokenizer_kwargs, model_class=MambaForCausalLM, - ) + ) \ No newline at end of file diff --git a/bindings/python/src/outlines_core/py.typed b/bindings/python/outlines_core/py.typed similarity index 100% rename from bindings/python/src/outlines_core/py.typed rename to bindings/python/outlines_core/py.typed diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index 356cb06f..13e1385c 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -1,128 +1,15 @@ [build-system] -requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "setuptools-rust"] -build-backend = "setuptools.build_meta" +requires = ["maturin>=1.7,<2.0"] +build-backend = "maturin" [project] name = "outlines_core" -authors = [ - { name = "Outlines Developers" }, - { name = "Huggingface Developers" }, -] -description = "Structured Text Generation in Rust" requires-python = ">=3.8" -license = { text = "Apache-2.0" } -keywords = [ - "machine learning", - "deep learning", - "language models", - "structured generation", -] classifiers = [ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "Intended Audience :: Information Technology", - "Intended Audience :: Science/Research", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering :: Artificial Intelligence", -] -dependencies = [ - "interegular", - "numpy<2.0.0", - "cloudpickle", - "diskcache", - "pydantic>=2.0", - "numba", - "referencing", - "jsonschema", - "tqdm", - "datasets", - "typing_extensions", -] -dynamic = ["version"] - -[project.optional-dependencies] -test = [ - "pre-commit", - "pytest", - "pytest-benchmark", - "pytest-cov", - "pytest-mock", - "coverage[toml]>=5.1", - "diff-cover", - "accelerate", - "beartype<0.16.0", - "huggingface_hub", - "torch", - "transformers", - "pillow", -] - -[project.urls] -homepage = "https://github.com/outlines-dev/outlines-core" -documentation = "https://outlines-dev.github.io/outlines-core/" -repository = "https://github.com/outlines-dev/outlines-core/" - -[project.readme] -file = "README.md" -content-type = "text/markdown" - -[tool.setuptools] -packages = ["outlines_core"] -package-dir = {"" = "src"} - -[tool.setuptools.package-data] -"outlines" = ["py.typed"] - -[tool.setuptools_scm] -write_to = "src/outlines_core/_version.py" - -[tool.pytest.ini_options] -testpaths = ["tests"] -filterwarnings = [ - "error", - "ignore::numba.core.errors.NumbaPendingDeprecationWarning", - "ignore::pydantic.warnings.PydanticDeprecatedSince20", - "ignore::FutureWarning:transformers.*", - "ignore::FutureWarning:huggingface_hub.*", - "ignore::UserWarning", - "ignore::DeprecationWarning:pyairports.*", + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", ] -[tool.mypy] -exclude=["examples"] -enable_incomplete_feature = ["Unpack"] - -[[tool.mypy.overrides]] -module = [ - "jsonschema.*", - "numpy.*", - "cloudpickle.*", - 
"diskcache.*", - "pydantic.*", - "pytest", - "referencing.*", - "torch.*", - "transformers.*", - "huggingface_hub", - "interegular.*", - "datasets.*", - "numba.*", -] -ignore_missing_imports = true - -[tool.coverage.run] -omit = ["src/outlines_core/_version.py", "tests/*"] -branch = true - -[tool.coverage.report] -omit = ["tests/*"] -exclude_lines = ["pragma: no cover", "if TYPE_CHECKING:", "\\.\\.\\."] -show_missing = true - -[tool.diff_cover] -compare_branch = "origin/main" -diff_range_notation = ".." - -[[tool.setuptools-rust.ext-modules]] -target = "outlines_core._lib" # The last part of the name (e.g. "_lib") has to match lib.name in Cargo.toml, but you can add a prefix to nest it inside of a Python package. +[tool.maturin] +features = ["pyo3/extension-module"] diff --git a/bindings/python/rust/lib.rs b/bindings/python/rust/lib.rs deleted file mode 100644 index 0b0523bc..00000000 --- a/bindings/python/rust/lib.rs +++ /dev/null @@ -1,12 +0,0 @@ -use pyo3::pymodule; - -#[pymodule] -mod _lib { - use outlines_core as core_lib; - use pyo3::{pyfunction, PyResult}; - - #[pyfunction] - fn hello() -> PyResult { - Ok(core_lib::hello()) - } -} diff --git a/bindings/python/setup.cfg b/bindings/python/setup.cfg deleted file mode 100644 index c8153801..00000000 --- a/bindings/python/setup.cfg +++ /dev/null @@ -1,11 +0,0 @@ -[flake8] -max-line-length = 88 -select = C,E,F,W -ignore = E203,E231,E501,E741,W503,W504,C901,E731 -per-file-ignores = - **/__init__.py:F401,F403 -exclude = - normalai/_version.py - -[bdist_wheel] -py_limited_api=cp38 diff --git a/bindings/python/src/regex.rs b/bindings/python/src/lib.rs similarity index 50% rename from bindings/python/src/regex.rs rename to bindings/python/src/lib.rs index df7d36f6..556cb630 100644 --- a/bindings/python/src/regex.rs +++ b/bindings/python/src/lib.rs @@ -1,182 +1,26 @@ -use pyo3::prelude::*; -use pyo3::types::PyDict; +use ::outlines_core as core_lib; +use pyo3::{pyclass, pyfunction, pymethods, pymodule, types::{PyAnyMethods, PyDict, PyModule, PyModuleMethods}, wrap_pyfunction, Bound, PyResult, Python}; use std::collections::{HashMap, HashSet}; -pub fn walk_fsm_internal( - fsm_transitions: &HashMap<(u32, u32), u32>, - _fsm_initial: u32, - fsm_finals: &HashSet, - token_transition_keys: &[u32], - start_state: u32, - full_match: bool, -) -> Vec { - let mut state = start_state; - let mut accepted_states = Vec::new(); - let mut last_final_idx = 0; - - for (i, &trans_key) in token_transition_keys.iter().enumerate() { - match fsm_transitions.get(&(state, trans_key)) { - Some(&new_state) => { - state = new_state; - if fsm_finals.contains(&state) { - last_final_idx = i + 1; - } - accepted_states.push(state); - } - None => { - if !full_match && last_final_idx > 0 { - return accepted_states[..last_final_idx].to_vec(); - } - return Vec::new(); - } - } - } +#[pymodule] +fn _outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; + m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; + m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; + m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; + m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; + m.add_function(wrap_pyfunction!(flag, m)?)?; - if full_match && last_final_idx != token_transition_keys.len() { - return Vec::new(); - } + m.add_class::()?; - accepted_states + m.add("__version__", env!("CARGO_PKG_VERSION"))?; + Ok(()) } -pub fn state_scan_tokens_internal( - fsm_transitions: &HashMap<(u32, u32), 
u32>, - fsm_initial: u32, - fsm_finals: &HashSet<u32>, - vocabulary: &[(String, Vec<u32>)], - vocabulary_transition_keys: &[Vec<u32>], - start_state: u32, -) -> HashSet<(u32, u32)> { - let mut res = HashSet::new(); - - for (vocab_item, token_transition_keys) in - vocabulary.iter().zip(vocabulary_transition_keys.iter()) - { - let token_ids: Vec<u32> = vocab_item.1.clone(); - - let state_seq = walk_fsm_internal( - fsm_transitions, - fsm_initial, - fsm_finals, - token_transition_keys, - start_state, - false, - ); - - if state_seq.len() < token_transition_keys.len() { - continue; - } - - for &token_id in &token_ids { - res.insert((token_id, *state_seq.last().unwrap())); - } - } - - res -} - -pub fn get_token_transition_keys_internal( - alphabet_symbol_mapping: &HashMap<String, u32>, - alphabet_anything_value: u32, - token_str: &str, -) -> Vec<u32> { - let mut token_transition_keys = Vec::new(); - let mut i = 0; - let chars: Vec<char> = token_str.chars().collect(); - - while i < chars.len() { - let symbol; - if chars[i] == '\0' && i != chars.len() - 1 { - if i + 2 < chars.len() { - symbol = format!("\0{}{}", chars[i + 1], chars[i + 2]); - i += 3; - } else { - symbol = chars[i].to_string(); - i += 1; - } - } else { - symbol = chars[i].to_string(); - i += 1; - } - - let transition_key = *alphabet_symbol_mapping - .get(&symbol) - .unwrap_or(&alphabet_anything_value); - token_transition_keys.push(transition_key); - } - - token_transition_keys -} - -pub fn get_vocabulary_transition_keys_internal( - alphabet_symbol_mapping: &HashMap<String, u32>, - alphabet_anything_value: u32, - vocabulary: &[(String, Vec<u32>)], - frozen_tokens: &HashSet<String>, -) -> Vec<Vec<u32>> { - let mut vocab_transition_keys: Vec<Vec<u32>> = Vec::new(); - - for item in vocabulary.iter() { - let token_str = item.0.clone(); - - let mut token_transition_keys; - - // Since these tokens are not expanded into byte-level transitions, we - // can simply get their transition keys directly.
- if frozen_tokens.contains(&token_str) { - token_transition_keys = Vec::new(); - token_transition_keys.push( - *alphabet_symbol_mapping - .get(&token_str) - .unwrap_or(&alphabet_anything_value), - ) - } else { - token_transition_keys = get_token_transition_keys_internal( - alphabet_symbol_mapping, - alphabet_anything_value, - &token_str, - ); - } - - vocab_transition_keys.push(token_transition_keys); - } - - vocab_transition_keys -} - -#[pyclass] -pub struct FSMInfo { - #[pyo3(get)] - initial: u32, - #[pyo3(get)] - finals: HashSet<u32>, - #[pyo3(get)] - transitions: HashMap<(u32, u32), u32>, - #[pyo3(get)] - alphabet_anything_value: u32, - #[pyo3(get)] - alphabet_symbol_mapping: HashMap<String, u32>, -} - -#[pymethods] -impl FSMInfo { - #[new] - fn new( - initial: u32, - finals: HashSet<u32>, - transitions: HashMap<(u32, u32), u32>, - alphabet_anything_value: u32, - alphabet_symbol_mapping: HashMap<String, u32>, - ) -> Self { - Self { - initial, - finals, - transitions, - alphabet_anything_value, - alphabet_symbol_mapping, - } - } +#[pyfunction(name = "flag")] +pub fn flag(flag: bool) -> PyResult<bool> { + Ok(flag) } #[pyfunction(name = "_walk_fsm")] @@ -191,7 +35,7 @@ pub fn _walk_fsm( start_state: u32, full_match: bool, ) -> PyResult<Vec<u32>> { - Ok(walk_fsm_internal( + Ok(core_lib::regex::walk_fsm_internal( &fsm_transitions, fsm_initial, &fsm_finals, @@ -213,7 +57,7 @@ pub fn state_scan_tokens( vocabulary_transition_keys: Vec<Vec<u32>>, start_state: u32, ) -> PyResult<HashSet<(u32, u32)>> { - Ok(state_scan_tokens_internal( + Ok(core_lib::regex::state_scan_tokens_internal( &fsm_transitions, fsm_initial, &fsm_finals, @@ -230,7 +74,7 @@ pub fn get_token_transition_keys( alphabet_anything_value: u32, token_str: String, ) -> PyResult<Vec<u32>> { - Ok(get_token_transition_keys_internal( + Ok(core_lib::regex::get_token_transition_keys_internal( &alphabet_symbol_mapping, alphabet_anything_value, &token_str, @@ -247,7 +91,7 @@ pub fn get_vocabulary_transition_keys( vocabulary: Vec<(String, Vec<u32>)>, frozen_tokens: HashSet<String>, ) -> PyResult<Vec<Vec<u32>>> { - Ok(get_vocabulary_transition_keys_internal( + Ok(core_lib::regex::get_vocabulary_transition_keys_internal( &alphabet_symbol_mapping, alphabet_anything_value, &vocabulary, @@ -268,7 +112,7 @@ pub fn create_fsm_index_end_to_end<'py>( let mut seen: HashSet<u32> = HashSet::new(); let mut next_states: HashSet<u32> = HashSet::from_iter(vec![fsm_info.initial]); - let vocabulary_transition_keys = get_vocabulary_transition_keys_internal( + let vocabulary_transition_keys = core_lib::regex::get_vocabulary_transition_keys_internal( &fsm_info.alphabet_symbol_mapping, fsm_info.alphabet_anything_value, &vocabulary, @@ -279,7 +123,7 @@ pub fn create_fsm_index_end_to_end<'py>( next_states.remove(&start_state); // TODO: Return Pydict directly at construction - let token_ids_end_states = state_scan_tokens_internal( + let token_ids_end_states = core_lib::regex::state_scan_tokens_internal( &fsm_info.transitions, fsm_info.initial, &fsm_info.finals, @@ -287,9 +131,9 @@ pub fn create_fsm_index_end_to_end<'py>( &vocabulary_transition_keys, start_state, ); - + for (token_id, end_state) in token_ids_end_states { - if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) { + if let Ok(existing_dict) = states_to_token_subsets.get_item(start_state) { existing_dict.set_item(token_id, end_state).unwrap(); } else { let new_dict = PyDict::new_bound(py); @@ -309,3 +153,37 @@ pub fn create_fsm_index_end_to_end<'py>( Ok(states_to_token_subsets) } + +#[pyclass] +pub struct FSMInfo { + #[pyo3(get)] + initial: u32, + #[pyo3(get)] + finals: HashSet<u32>, + #[pyo3(get)] +
transitions: HashMap<(u32, u32), u32>, + #[pyo3(get)] + alphabet_anything_value: u32, + #[pyo3(get)] + alphabet_symbol_mapping: HashMap<String, u32>, +} + +#[pymethods] +impl FSMInfo { + #[new] + fn new( + initial: u32, + finals: HashSet<u32>, + transitions: HashMap<(u32, u32), u32>, + alphabet_anything_value: u32, + alphabet_symbol_mapping: HashMap<String, u32>, + ) -> Self { + Self { + initial, + finals, + transitions, + alphabet_anything_value, + alphabet_symbol_mapping, + } + } +} diff --git a/bindings/python/src/outlines_core/__init__.py b/bindings/python/src/outlines_core/__init__.py deleted file mode 100644 index a8b21aaa..00000000 --- a/bindings/python/src/outlines_core/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Outlines is a Generative Model Programming Framework.""" - -from ._lib import hello - -__all__ = ["models", "hello"] diff --git a/bindings/python/src/outlines_core/fsm/fsm.py b/bindings/python/src/outlines_core/fsm/fsm.py deleted file mode 100644 index 4daf3c86..00000000 --- a/bindings/python/src/outlines_core/fsm/fsm.py +++ /dev/null @@ -1,47 +0,0 @@ -import warnings -from typing import TYPE_CHECKING, Iterable, NewType, Optional - -from outlines_core.fsm.guide import RegexGuide, StopAtEOSGuide - -if TYPE_CHECKING: - from outlines_core.models.tokenizer import Tokenizer - -FSMState = NewType("FSMState", int) - - -class StopAtEosFSM(StopAtEOSGuide): - """FSM to generate text until EOS has been generated.""" - - def __init__(self, tokenizer: "Tokenizer"): - warnings.warn( - UserWarning( - "The `StopAtTokenFSM` interface is deprecated and will be removed on 2024-06-01. Please use `StopAtEOSGuide` instead." - ) - ) - super().__init__(tokenizer) - - def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: - next_instruction = self.get_next_instruction(state) - return next_instruction.tokens - - def next_state(self, state: FSMState, token_id: int) -> FSMState: - return FSMState(self.get_next_state(state, token_id)) - - -class RegexFSM(RegexGuide): - """FSM to generate text that is in the language of a regular expression.""" - - def __init__(self, regex_string: str, tokenizer): - warnings.warn( - UserWarning( - "The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead." - ) - ) - super().__init__(regex_string, tokenizer) - - def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: - next_instruction = self.get_next_instruction(state) - return next_instruction.tokens - - def next_state(self, state: FSMState, token_id: int) -> FSMState: - return FSMState(self.get_next_state(state, token_id)) diff --git a/bindings/python/src/outlines_core/fsm/regex.py b/bindings/python/src/outlines_core/fsm/regex.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bindings/python/src/outlines_core/fsm/types.py b/bindings/python/src/outlines_core/fsm/types.py deleted file mode 100644 index 5695dee0..00000000 --- a/bindings/python/src/outlines_core/fsm/types.py +++ /dev/null @@ -1,81 +0,0 @@ -import datetime -from enum import EnumMeta -from typing import Any, Protocol, Tuple, Type - -from typing_extensions import _AnnotatedAlias, get_args - -INTEGER = r"[+-]?(0|[1-9][0-9]*)" -BOOLEAN = "(True|False)" -FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?" -DATE = r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])" -TIME = r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])" -DATETIME = rf"({DATE})(\s)({TIME})" - - -class FormatFunction(Protocol): - def __call__(self, sequence: str) -> Any: - ...
- - -def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]: - # If it is a custom type - if isinstance(python_type, _AnnotatedAlias): - json_schema = get_args(python_type)[1].json_schema - type_class = get_args(python_type)[0] - - custom_regex_str = json_schema["pattern"] - - def custom_format_fn(sequence: str) -> Any: - return type_class(sequence) - - return custom_regex_str, custom_format_fn - - if isinstance(python_type, EnumMeta): - values = python_type.__members__.keys() - enum_regex_str: str = "(" + "|".join(values) + ")" - - def enum_format_fn(sequence: str) -> str: - return str(sequence) - - return enum_regex_str, enum_format_fn - - if python_type == float: - - def float_format_fn(sequence: str) -> float: - return float(sequence) - - return FLOAT, float_format_fn - elif python_type == int: - - def int_format_fn(sequence: str) -> int: - return int(sequence) - - return INTEGER, int_format_fn - elif python_type == bool: - - def bool_format_fn(sequence: str) -> bool: - return bool(sequence) - - return BOOLEAN, bool_format_fn - elif python_type == datetime.date: - - def date_format_fn(sequence: str) -> datetime.date: - return datetime.datetime.strptime(sequence, "%Y-%m-%d").date() - - return DATE, date_format_fn - elif python_type == datetime.time: - - def time_format_fn(sequence: str) -> datetime.time: - return datetime.datetime.strptime(sequence, "%H:%M:%S").time() - - return TIME, time_format_fn - elif python_type == datetime.datetime: - - def datetime_format_fn(sequence: str) -> datetime.datetime: - return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S") - - return DATETIME, datetime_format_fn - else: - raise NotImplementedError( - f"The Python type {python_type} is not supported. Please open an issue." - ) diff --git a/bindings/python/src/outlines_core/integrations/utils.py b/bindings/python/src/outlines_core/integrations/utils.py deleted file mode 100644 index 67c70685..00000000 --- a/bindings/python/src/outlines_core/integrations/utils.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Union - -from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase - - -def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase: - """Adapt a tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of `transformers`. In - addition we need to handle the missing spaces to Llama's tokenizer to be able to - compile FSMs for this model. - - Parameters - ---------- - tokenizer - The tokenizer of the model. - - Returns - ------- - PreTrainedTokenizerBase - The adapted tokenizer. - """ - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: Union[str, bytes]) -> str: - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if ( - type(token) is str - and token.startswith(SPIECE_UNDERLINE) - or token == "<0x20>" - ): - return " " + string - - return string - - tokenizer.convert_token_to_string = convert_token_to_string - - return tokenizer diff --git a/bindings/python/src/outlines_core/models/__init__.py b/bindings/python/src/outlines_core/models/__init__.py deleted file mode 100644 index c6277f62..00000000 --- a/bindings/python/src/outlines_core/models/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Module that contains all the models integrated in outlines. 
- -We group the models in submodules by provider instead of theme (completion, chat -completion, diffusers, etc.) and use routing functions everywhere else in the -codebase. - -""" - -from typing import Union - -from .transformers import Transformers, TransformerTokenizer, mamba, transformers - -LogitsGenerator = Union[Transformers] diff --git a/bindings/python/src/outlines_core/models/tokenizer.py b/bindings/python/src/outlines_core/models/tokenizer.py deleted file mode 100644 index 1a5708d8..00000000 --- a/bindings/python/src/outlines_core/models/tokenizer.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union - -import numpy as np -from numpy.typing import NDArray - - -class Tokenizer(Hashable, Protocol): - eos_token: str - eos_token_id: int - pad_token_id: int - vocabulary: Dict[str, int] - special_tokens: Set[str] - - def encode( - self, prompt: Union[str, List[str]] - ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]: - """Translate the input prompts into arrays of token ids and attention mask.""" - ... - - def decode(self, token_ids: NDArray[np.int64]) -> List[str]: - """Translate an array of token ids to a string or list of strings.""" - ... - - def convert_token_to_string(self, token: str) -> str: - """Convert a token to its equivalent string. - - This is for instance useful for BPE tokenizers where whitespaces are - represented by the special characted `Δ `. This prevents matching a raw - token that includes `Δ ` with a string. - """ - ... diff --git a/bindings/python/src/outlines_core/models/transformers.py b/bindings/python/src/outlines_core/models/transformers.py deleted file mode 100644 index bc5ba7b6..00000000 --- a/bindings/python/src/outlines_core/models/transformers.py +++ /dev/null @@ -1,474 +0,0 @@ -import dataclasses -import inspect -from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union - -from datasets.fingerprint import Hasher -from outlines_core.models.tokenizer import Tokenizer - -if TYPE_CHECKING: - import torch - from transformers import PreTrainedModel, PreTrainedTokenizer - -__all__ = ["transformers"] - - -KVCacheType = Tuple[Tuple["torch.DoubleTensor", "torch.DoubleTensor"], ...] - - -@dataclasses.dataclass(frozen=True) -class GenerationParameters: - """Generation parameters used in Outlines' public API.""" - - max_tokens: Optional[int] - stop_at: Optional[Union[str, List[str]]] - seed: Optional[int] - - -@dataclasses.dataclass(frozen=True) -class SamplingParameters: - """Sampling parameters available in Outlines.""" - - sampler: str - num_samples: int = 1 - top_p: Optional[float] = None - top_k: Optional[int] = None - temperature: Optional[float] = None - - -def get_llama_tokenizer_types(): - """Get all the Llama tokenizer types/classes that need work-arounds. - - When they can't be imported, a dummy class is created. 
- - """ - try: - from transformers.models.llama import LlamaTokenizer - except ImportError: - - class LlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.llama import LlamaTokenizerFast - except ImportError: - - class LlamaTokenizerFast: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizer - except ImportError: - - class CodeLlamaTokenizer: # type: ignore - pass - - try: - from transformers.models.code_llama import CodeLlamaTokenizerFast - except ImportError: - - class CodeLlamaTokenizerFast: # type: ignore - pass - - return ( - LlamaTokenizer, - LlamaTokenizerFast, - CodeLlamaTokenizer, - CodeLlamaTokenizerFast, - ) - - -class TransformerTokenizer(Tokenizer): - """Represents a tokenizer for models in the `transformers` library.""" - - def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs): - self.tokenizer = tokenizer - self.eos_token_id = self.tokenizer.eos_token_id - self.eos_token = self.tokenizer.eos_token - - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.pad_token_id = self.eos_token_id - else: - self.pad_token_id = self.tokenizer.pad_token_id - self.pad_token = self.tokenizer.pad_token - - self.special_tokens = set(self.tokenizer.all_special_tokens) - - self.vocabulary = self.tokenizer.get_vocab() - self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) - - def encode( - self, prompt: Union[str, List[str]], **kwargs - ) -> Tuple["torch.LongTensor", "torch.LongTensor"]: - kwargs["padding"] = True - kwargs["return_tensors"] = "pt" - output = self.tokenizer(prompt, **kwargs) - return output["input_ids"], output["attention_mask"] - - def decode(self, token_ids: "torch.LongTensor") -> List[str]: - text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) - return text - - def convert_token_to_string(self, token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = self.tokenizer.convert_tokens_to_string([token]) - - if self.is_llama: - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def __eq__(self, other): - if isinstance(other, type(self)): - if hasattr(self, "model_name") and hasattr(self, "kwargs"): - return ( - other.model_name == self.model_name and other.kwargs == self.kwargs - ) - else: - return other.tokenizer == self.tokenizer - return NotImplemented - - def __hash__(self): - return hash(Hasher.hash(self.tokenizer)) - - def __getstate__(self): - state = {"tokenizer": self.tokenizer} - return state - - def __setstate__(self, state): - self.__init__(state["tokenizer"]) - - -class Transformers: - """Represents a `transformers` model.""" - - def __init__( - self, - model: "PreTrainedModel", - tokenizer: "PreTrainedTokenizer", - ): - self.model = model - self.tokenizer = TransformerTokenizer(tokenizer) - - def forward( - self, - input_ids: "torch.LongTensor", - attention_mask: "torch.LongTensor", - past_key_values: Optional[Tuple] = None, - ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]: - """Compute a forward pass through the transformer model. - - Parameters - ---------- - input_ids - The input token ids. Must be one or two dimensional. - attention_mask - The attention mask. Must be one or two dimensional. - past_key_values - A tuple of tuples containing the cached key and value tensors for each - attention head. 
- - Returns - ------- - The computed logits and the new cached key and value tensors. - - """ - try: - import torch - except ImportError: - ImportError( - "The `torch` library needs to be installed to use `transformers` models." - ) - assert 0 < input_ids.ndim < 3 - - if past_key_values: - input_ids = input_ids[..., -1].unsqueeze(-1) - - with torch.inference_mode(): - output = self.model( - input_ids, - attention_mask=attention_mask, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - past_key_values=past_key_values, - ) - - return output.logits, output.past_key_values - - def __call__( - self, - input_ids: "torch.LongTensor", - attention_mask: "torch.LongTensor", - past_key_values: Optional[Tuple] = None, - ) -> "torch.FloatTensor": - logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) - next_token_logits = logits[..., -1, :] - - return next_token_logits, kv_cache - - def generate( - self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - logits_processor, - sampling_parameters: SamplingParameters, - ) -> Union[str, List[str], List[List[str]]]: - """Generate text using `transformers`. - - Arguments - --------- - prompts - A prompt or list of prompts. - generation_parameters - An instance of `GenerationParameters` that contains the prompt, - the maximum number of tokens, stop sequences and seed. All the - arguments to `SequenceGeneratorAdapter`'s `__cal__` method. - logits_processor - The logits processor to use when generating text. - sampling_parameters - An instance of `SamplingParameters`, a dataclass that contains - the name of the sampler to use and related parameters as available - in Outlines. - - Returns - ------- - The generated text - """ - if isinstance(prompts, str): - # convert to 2d - input_ids, attention_mask = self.tokenizer.encode([prompts]) - else: - input_ids, attention_mask = self.tokenizer.encode(prompts) - - inputs = { - "input_ids": input_ids.to(self.model.device), - "attention_mask": attention_mask.to(self.model.device), - } - if ( - "attention_mask" - not in inspect.signature(self.model.forward).parameters.keys() - ): - del inputs["attention_mask"] - - generation_kwargs = self._get_generation_kwargs( - prompts, - generation_parameters, - logits_processor, - sampling_parameters, - ) - generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) - - # if single str input and single sample per input, convert to a 1D output - if isinstance(prompts, str): - generated_ids = generated_ids.squeeze(0) - - return self._decode_generation(generated_ids) - - def stream( - self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - logits_processor, - sampling_parameters: SamplingParameters, - ) -> Iterator[Union[str, List[str]]]: - """ - Temporary stream stand-in which implements stream() signature - and equivalent behaviour but isn't yielded until generation completes. 
- - TODO: implement following completion of https://github.com/huggingface/transformers/issues/30810 - """ - if isinstance(prompts, str): - # convert to 2d - input_ids, attention_mask = self.tokenizer.encode([prompts]) - else: - input_ids, attention_mask = self.tokenizer.encode(prompts) - inputs = { - "input_ids": input_ids.to(self.model.device), - "attention_mask": attention_mask.to(self.model.device), - } - if ( - "attention_mask" - not in inspect.signature(self.model.forward).parameters.keys() - ): - del inputs["attention_mask"] - - generation_kwargs = self._get_generation_kwargs( - prompts, - generation_parameters, - logits_processor, - sampling_parameters, - ) - generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) - - # if single str input and single sample per input, convert to a 1D output - if isinstance(prompts, str): - generated_ids = generated_ids.squeeze(0) - - for i in range(generated_ids.size(-1)): - output_group_ids = generated_ids.select(-1, i).unsqueeze(-1) - yield self._decode_generation(output_group_ids) - - def _get_generation_kwargs( - self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - logits_processor, - sampling_parameters: SamplingParameters, - ) -> dict: - """ - Conert outlines generation parameters into model.generate kwargs - """ - from transformers import GenerationConfig, LogitsProcessorList, set_seed - - max_new_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) - sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( - sampling_parameters - ) - if max_new_tokens is None: - max_new_tokens = int(2**30) - - # global seed, not desirable - if seed is not None: - set_seed(seed) - - if logits_processor is not None: - logits_processor_list = LogitsProcessorList([logits_processor]) - else: - logits_processor_list = None - - generation_config = GenerationConfig( - max_new_tokens=max_new_tokens, - stop_strings=stop_at, - num_return_sequences=(num_samples or 1), - top_p=top_p, - top_k=top_k, - temperature=temperature, - do_sample=(sampler == "multinomial"), - num_beams=(num_samples if sampler == "beam_search" else 1), - eos_token_id=self.tokenizer.eos_token_id, - pad_token_id=self.tokenizer.pad_token_id, - ) - - return dict( - logits_processor=logits_processor_list, - generation_config=generation_config, - tokenizer=self.tokenizer.tokenizer, - ) - - def _generate_output_seq( - self, prompts, inputs, generation_config, **generation_kwargs - ): - input_ids = inputs["input_ids"] - output_ids = self.model.generate( - **inputs, generation_config=generation_config, **generation_kwargs - ) - - # encoder-decoder returns output_ids only, decoder-only returns full seq ids - if self.model.config.is_encoder_decoder: - generated_ids = output_ids - else: - generated_ids = output_ids[:, input_ids.shape[1] :] - - # if batch list inputs AND multiple samples per input, convert generated_id to 3D view - num_samples = generation_config.num_return_sequences or 1 - - if num_samples > 1 and isinstance(prompts, list): - batch_size = input_ids.size(0) - num_return_sequences = generation_config.num_return_sequences or 1 - generated_ids = generated_ids.view(batch_size, num_return_sequences, -1) - - return generated_ids - - def _decode_generation(self, generated_ids: "torch.Tensor"): - if len(generated_ids.shape) == 1: - return self.tokenizer.decode([generated_ids])[0] - elif len(generated_ids.shape) == 2: - return self.tokenizer.decode(generated_ids) - elif len(generated_ids.shape) == 3: - return [ - 
self.tokenizer.decode(generated_ids[i]) - for i in range(len(generated_ids)) - ] - else: - raise TypeError( - f"Generated outputs aren't 1D, 2D or 3D, but instead are {generated_ids.shape}" - ) - - -def transformers( - model_name: str, - device: Optional[str] = None, - model_kwargs: dict = {}, - tokenizer_kwargs: dict = {}, - model_class=None, - tokenizer_class=None, -): - """Instantiate a model from the `transformers` library and its tokenizer. - - Parameters - ---------- - model_name - The name of the model as listed on Hugging Face's model page. - device - The device(s) on which the model should be loaded. This overrides - the `device_map` entry in `model_kwargs` when provided. - model_kwargs - A dictionary that contains the keyword arguments to pass to the - `from_pretrained` method when loading the model. - tokenizer_kwargs - A dictionary that contains the keyword arguments to pass to the - `from_pretrained` method when loading the tokenizer. - - Returns - ------- - A `TransformersModel` model instance. - - """ - if model_class is None or tokenizer_class is None: - try: - from transformers import AutoModelForCausalLM, AutoTokenizer - except ImportError: - raise ImportError( - "The `transformers` library needs to be installed in order to use `transformers` models." - ) - if model_class is None: - model_class = AutoModelForCausalLM - if tokenizer_class is None: - tokenizer_class = AutoTokenizer - - if device is not None: - model_kwargs["device_map"] = device - - model = model_class.from_pretrained(model_name, **model_kwargs) - - tokenizer_kwargs.setdefault("padding_side", "left") - tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs) - - return Transformers(model, tokenizer) - - -def mamba( - model_name: str, - device: Optional[str] = None, - model_kwargs: dict = {}, - tokenizer_kwargs: dict = {}, -): - try: - from transformers import MambaForCausalLM - - except ImportError: - raise ImportError( - "The `mamba_ssm`, `torch` and `transformer` libraries needs to be installed in order to use Mamba." - ) - - return transformers( - model_name=model_name, - device=device, - model_kwargs=model_kwargs, - tokenizer_kwargs=tokenizer_kwargs, - model_class=MambaForCausalLM, - ) diff --git a/bindings/python/tests/__init__.py b/bindings/python/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bindings/python/tests/fsm/partial_python.lark b/bindings/python/tests/fsm/partial_python.lark deleted file mode 100644 index 973e5963..00000000 --- a/bindings/python/tests/fsm/partial_python.lark +++ /dev/null @@ -1,314 +0,0 @@ -// Python 3 grammar for Lark -// -// This grammar should parse all python 3.x code successfully. -// -// Adapted from: https://docs.python.org/3/reference/grammar.html -// -// This version is actually a subset of Lark's Python grammar without the -// regex look-arounds in the string terminals. -// -// Start symbols for the grammar: -// single_input is a single interactive statement; -// file_input is a module or sequence of commands read from an input file; -// eval_input is the input for the eval() functions. -// NB: compound_stmt in single_input is followed by extra NEWLINE! 
-// - -single_input: _NEWLINE | simple_stmt | compound_stmt _NEWLINE -file_input: (_NEWLINE | stmt)* -eval_input: testlist _NEWLINE* - -decorator: "@" dotted_name [ "(" [arguments] ")" ] _NEWLINE -decorators: decorator+ -decorated: decorators (classdef | funcdef | async_funcdef) - -async_funcdef: "async" funcdef -funcdef: "def" name "(" [parameters] ")" ["->" test] ":" suite - -parameters: paramvalue ("," paramvalue)* ["," SLASH ("," paramvalue)*] ["," [starparams | kwparams]] - | starparams - | kwparams - -SLASH: "/" // Otherwise the it will completely disappear and it will be undisguisable in the result -starparams: (starparam | starguard) poststarparams -starparam: "*" typedparam -starguard: "*" -poststarparams: ("," paramvalue)* ["," kwparams] -kwparams: "**" typedparam ","? - -?paramvalue: typedparam ("=" test)? -?typedparam: name (":" test)? - - -lambdef: "lambda" [lambda_params] ":" test -lambdef_nocond: "lambda" [lambda_params] ":" test_nocond -lambda_params: lambda_paramvalue ("," lambda_paramvalue)* ["," [lambda_starparams | lambda_kwparams]] - | lambda_starparams - | lambda_kwparams -?lambda_paramvalue: name ("=" test)? -lambda_starparams: "*" [name] ("," lambda_paramvalue)* ["," [lambda_kwparams]] -lambda_kwparams: "**" name ","? - - -?stmt: simple_stmt | compound_stmt -?simple_stmt: small_stmt (";" small_stmt)* [";"] _NEWLINE -?small_stmt: (expr_stmt | assign_stmt | del_stmt | pass_stmt | flow_stmt | import_stmt | global_stmt | nonlocal_stmt | assert_stmt) -expr_stmt: testlist_star_expr -assign_stmt: annassign | augassign | assign - -annassign: testlist_star_expr ":" test ["=" test] -assign: testlist_star_expr ("=" (yield_expr|testlist_star_expr))+ -augassign: testlist_star_expr augassign_op (yield_expr|testlist) -!augassign_op: "+=" | "-=" | "*=" | "@=" | "/=" | "%=" | "&=" | "|=" | "^=" | "<<=" | ">>=" | "**=" | "//=" -?testlist_star_expr: test_or_star_expr - | test_or_star_expr ("," test_or_star_expr)+ ","? -> tuple - | test_or_star_expr "," -> tuple - -// For normal and annotated assignments, additional restrictions enforced by the interpreter -del_stmt: "del" exprlist -pass_stmt: "pass" -?flow_stmt: break_stmt | continue_stmt | return_stmt | raise_stmt | yield_stmt -break_stmt: "break" -continue_stmt: "continue" -return_stmt: "return" [testlist] -yield_stmt: yield_expr -raise_stmt: "raise" [test ["from" test]] -import_stmt: import_name | import_from -import_name: "import" dotted_as_names -// note below: the ("." | "...") is necessary because "..." is tokenized as ELLIPSIS -import_from: "from" (dots? dotted_name | dots) "import" ("*" | "(" import_as_names ")" | import_as_names) -!dots: "."+ -import_as_name: name ["as" name] -dotted_as_name: dotted_name ["as" name] -import_as_names: import_as_name ("," import_as_name)* [","] -dotted_as_names: dotted_as_name ("," dotted_as_name)* -dotted_name: name ("." 
name)* -global_stmt: "global" name ("," name)* -nonlocal_stmt: "nonlocal" name ("," name)* -assert_stmt: "assert" test ["," test] - -?compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | match_stmt - | with_stmt | funcdef | classdef | decorated | async_stmt -async_stmt: "async" (funcdef | with_stmt | for_stmt) -if_stmt: "if" test ":" suite elifs ["else" ":" suite] -elifs: elif_* -elif_: "elif" test ":" suite -while_stmt: "while" test ":" suite ["else" ":" suite] -for_stmt: "for" exprlist "in" testlist ":" suite ["else" ":" suite] -try_stmt: "try" ":" suite except_clauses ["else" ":" suite] [finally] - | "try" ":" suite finally -> try_finally -finally: "finally" ":" suite -except_clauses: except_clause+ -except_clause: "except" [test ["as" name]] ":" suite -// NB compile.c makes sure that the default except clause is last - - -with_stmt: "with" with_items ":" suite -with_items: with_item ("," with_item)* -with_item: test ["as" name] - -match_stmt: "match" test ":" _NEWLINE _INDENT case+ _DEDENT - -case: "case" pattern ["if" test] ":" suite - -?pattern: sequence_item_pattern "," _sequence_pattern -> sequence_pattern - | as_pattern -?as_pattern: or_pattern ("as" NAME)? -?or_pattern: closed_pattern ("|" closed_pattern)* -?closed_pattern: literal_pattern - | NAME -> capture_pattern - | "_" -> any_pattern - | attr_pattern - | "(" as_pattern ")" - | "[" _sequence_pattern "]" -> sequence_pattern - | "(" (sequence_item_pattern "," _sequence_pattern)? ")" -> sequence_pattern - | "{" (mapping_item_pattern ("," mapping_item_pattern)* ","?)?"}" -> mapping_pattern - | "{" (mapping_item_pattern ("," mapping_item_pattern)* ",")? "**" NAME ","? "}" -> mapping_star_pattern - | class_pattern - -literal_pattern: inner_literal_pattern - -?inner_literal_pattern: "None" -> const_none - | "True" -> const_true - | "False" -> const_false - | STRING -> string - | number - -attr_pattern: NAME ("." NAME)+ -> value - -name_or_attr_pattern: NAME ("." NAME)* -> value - -mapping_item_pattern: (literal_pattern|attr_pattern) ":" as_pattern - -_sequence_pattern: (sequence_item_pattern ("," sequence_item_pattern)* ","?)? -?sequence_item_pattern: as_pattern - | "*" NAME -> star_pattern - -class_pattern: name_or_attr_pattern "(" [arguments_pattern ","?] ")" -arguments_pattern: pos_arg_pattern ["," keyws_arg_pattern] - | keyws_arg_pattern -> no_pos_arguments - -pos_arg_pattern: as_pattern ("," as_pattern)* -keyws_arg_pattern: keyw_arg_pattern ("," keyw_arg_pattern)* -keyw_arg_pattern: NAME "=" as_pattern - - - -suite: simple_stmt | _NEWLINE _INDENT stmt+ _DEDENT - -?test: or_test ("if" or_test "else" test)? - | lambdef - | assign_expr - -assign_expr: name ":=" test - -?test_nocond: or_test | lambdef_nocond - -?or_test: and_test ("or" and_test)* -?and_test: not_test_ ("and" not_test_)* -?not_test_: "not" not_test_ -> not_test - | comparison -?comparison: expr (comp_op expr)* -star_expr: "*" expr - -?expr: or_expr -?or_expr: xor_expr ("|" xor_expr)* -?xor_expr: and_expr ("^" and_expr)* -?and_expr: shift_expr ("&" shift_expr)* -?shift_expr: arith_expr (_shift_op arith_expr)* -?arith_expr: term (_add_op term)* -?term: factor (_mul_op factor)* -?factor: _unary_op factor | power - -!_unary_op: "+"|"-"|"~" -!_add_op: "+"|"-" -!_shift_op: "<<"|">>" -!_mul_op: "*"|"@"|"/"|"%"|"//" -// <> isn't actually a valid comparison operator in Python. 
It's here for the -// sake of a __future__ import described in PEP 401 (which really works :-) -!comp_op: "<"|">"|"=="|">="|"<="|"<>"|"!="|"in"|"not" "in"|"is"|"is" "not" - -?power: await_expr ("**" factor)? -?await_expr: AWAIT? atom_expr -AWAIT: "await" - -?atom_expr: atom_expr "(" [arguments] ")" -> funccall - | atom_expr "[" subscriptlist "]" -> getitem - | atom_expr "." name -> getattr - | atom - -?atom: "(" yield_expr ")" - | "(" _tuple_inner? ")" -> tuple - | "(" comprehension{test_or_star_expr} ")" -> tuple_comprehension - | "[" _exprlist? "]" -> list - | "[" comprehension{test_or_star_expr} "]" -> list_comprehension - | "{" _dict_exprlist? "}" -> dict - | "{" comprehension{key_value} "}" -> dict_comprehension - | "{" _exprlist "}" -> set - | "{" comprehension{test} "}" -> set_comprehension - | name -> var - | number - | string_concat - | "(" test ")" - | "..." -> ellipsis - | "None" -> const_none - | "True" -> const_true - | "False" -> const_false - - -?string_concat: string+ - -_tuple_inner: test_or_star_expr (("," test_or_star_expr)+ [","] | ",") - -?test_or_star_expr: test - | star_expr - -?subscriptlist: subscript - | subscript (("," subscript)+ [","] | ",") -> subscript_tuple -?subscript: test | ([test] ":" [test] [sliceop]) -> slice -sliceop: ":" [test] -?exprlist: (expr|star_expr) - | (expr|star_expr) (("," (expr|star_expr))+ [","]|",") -?testlist: test | testlist_tuple -testlist_tuple: test (("," test)+ [","] | ",") -_dict_exprlist: (key_value | "**" expr) ("," (key_value | "**" expr))* [","] - -key_value: test ":" test - -_exprlist: test_or_star_expr ("," test_or_star_expr)* [","] - -classdef: "class" name ["(" [arguments] ")"] ":" suite - - - -arguments: argvalue ("," argvalue)* ("," [ starargs | kwargs])? - | starargs - | kwargs - | comprehension{test} - -starargs: stararg ("," stararg)* ("," argvalue)* ["," kwargs] -stararg: "*" test -kwargs: "**" test ("," argvalue)* - -?argvalue: test ("=" test)? - - -comprehension{comp_result}: comp_result comp_fors [comp_if] -comp_fors: comp_for+ -comp_for: [ASYNC] "for" exprlist "in" or_test -ASYNC: "async" -?comp_if: "if" test_nocond - -// not used in grammar, but may appear in "node" passed from Parser to Compiler -encoding_decl: name - -yield_expr: "yield" [testlist] - | "yield" "from" test -> yield_from - -number: DEC_NUMBER | HEX_NUMBER | BIN_NUMBER | OCT_NUMBER | FLOAT_NUMBER | IMAG_NUMBER -string: STRING // | LONG_STRING - -// Other terminals - -_NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+ - -%ignore /[\t \f]+/ // WS -%ignore /\\[\t \f]*\r?\n/ // LINE_CONT -%ignore COMMENT -%declare _INDENT _DEDENT - - -// Python terminals - -!name: NAME | "match" | "case" -NAME: /[^\W\d]\w*/ -COMMENT: /#[^\n]*/ - -// We only need a usable approximation for something like this until we fully -// implement look-arounds, or use a regex/FSM library that supports them. -// STRING: /([ubf]?r?|r[ubf])("(?!"").*?(? 
BetterFSM: + new_fsm = make_byte_level_fsm(fsm, keep_utf8) + return BetterFSM( + alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), + states=new_fsm.states, + initial=new_fsm.initial, + finals=new_fsm.finals, + map=new_fsm.map, + ) + + +def test_walk_fsm(): regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - res = tuple(function(regex_fsm, "0", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust(regex_fsm, "0", regex_fsm.initial, full_match=True) + ) assert res == (1,) - res = tuple(function(regex_fsm, "00", regex_fsm.initial, full_match=False)) + res = tuple( + walk_fsm_from_token_str_rust( + regex_fsm, "00", regex_fsm.initial, full_match=False + ) + ) assert res == (1,) - res = tuple(function(regex_fsm, "!", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust(regex_fsm, "!", regex_fsm.initial, full_match=True) + ) assert res == tuple() - res = tuple(function(regex_fsm, "00", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust( + regex_fsm, "00", regex_fsm.initial, full_match=True + ) + ) assert res == tuple() # This should fail, because state `1` reads nothing - res = tuple(function(regex_fsm, "0", 1, full_match=True)) + res = tuple(walk_fsm_from_token_str_rust(regex_fsm, "0", 1, full_match=True)) assert res == tuple() regex_pattern = interegular.parse_pattern("0|[1-9][2-9]+") regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust(regex_fsm, "1", regex_fsm.initial, full_match=True) + ) assert res == tuple() - res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=False)) + res = tuple( + walk_fsm_from_token_str_rust( + regex_fsm, "1", regex_fsm.initial, full_match=False + ) + ) assert res == (2,) - res = tuple(function(regex_fsm, "12", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust( + regex_fsm, "12", regex_fsm.initial, full_match=True + ) + ) assert res == (2, 3) pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)") fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) - res = tuple(function(fsm, "x ", fsm.initial, full_match=False)) + res = tuple(walk_fsm_from_token_str_rust(fsm, "x ", fsm.initial, full_match=False)) assert res == (2,) start_state = list(fsm.finals)[0] - res = tuple(function(fsm, "!", start_state, full_match=False)) + res = tuple(walk_fsm_from_token_str_rust(fsm, "!", start_state, full_match=False)) assert res == tuple() -@pytest.mark.parametrize( - "function", - [ - walk_fsm_from_token_str, - walk_fsm_from_token_str_numba, - ], -) @pytest.mark.parametrize( "transform", [ @@ -135,20 +137,20 @@ def test_walk_fsm(function): to_bytes, ], ) -def test_walk_fsm_multi_bytes(function, transform): +def test_walk_fsm_multi_bytes(transform): regex_pattern = interegular.parse_pattern("πŸ˜‚|[πŸ˜‡-😍][😈-😍]*") str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) res = tuple( - function( + walk_fsm_from_token_str_rust( regex_fsm, merge_symbols(transform("πŸ˜‚")), regex_fsm.initial, full_match=True ) ) assert res[-1:] == (1,) res = tuple( - function( + walk_fsm_from_token_str_rust( regex_fsm, merge_symbols(transform("πŸ˜‚πŸ˜‚")), regex_fsm.initial, @@ -158,14 +160,14 @@ def test_walk_fsm_multi_bytes(function, 
transform): assert res[-1:] == (1,) res = tuple( - function( + walk_fsm_from_token_str_rust( regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True ) ) assert res == tuple() res = tuple( - function( + walk_fsm_from_token_str_rust( regex_fsm, merge_symbols(transform("πŸ˜‚πŸ˜‚")), regex_fsm.initial, @@ -175,161 +177,6 @@ def test_walk_fsm_multi_bytes(function, transform): assert res == tuple() -def test_get_sub_fsms_from_seq(): - name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") - name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) - - def_pattern = interegular.parse_pattern("def") - def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) - - match_pattern = interegular.parse_pattern("match") - match_fsm, _ = make_deterministic_fsm(match_pattern.to_fsm().reduce()) - - peq_pattern = interegular.parse_pattern(r"\+=") - peq_fsm, _ = make_deterministic_fsm(peq_pattern.to_fsm().reduce()) - - plus_pattern = interegular.parse_pattern(r"\+") - plus_fsm, _ = make_deterministic_fsm(plus_pattern.to_fsm().reduce()) - - fsms = [def_fsm, match_fsm, name_fsm, peq_fsm, plus_fsm] - - fsm, fsms_to_trans_finals = fsm_union(fsms) - - assert fsms_to_trans_finals == { - 0: ({(0, 3), (3, 9), (9, 10)}, {10}, {0: {0}, 1: {3}, 2: {9}, 3: {10}}), - 1: ( - {(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)}, - {8}, - {0: {0}, 1: {4}, 2: {5}, 3: {6}, 4: {7}, 5: {8}}, - ), - 2: ( - { - (0, 2), - (0, 3), - (0, 4), - (2, 2), - (3, 2), - (3, 9), - (4, 2), - (4, 5), - (5, 2), - (5, 6), - (6, 2), - (6, 7), - (7, 2), - (7, 8), - (8, 2), - (9, 2), - (9, 10), - (10, 2), - }, - {2, 3, 4, 5, 6, 7, 8, 9, 10}, - {0: {0}, 1: {2, 3, 4, 5, 6, 7, 8, 9, 10}}, - ), - 3: ({(0, 1), (1, 11)}, {11}, {0: {0}, 1: {1}, 2: {11}}), - 4: ({(0, 1)}, {1}, {0: {0}, 1: {1}}), - } - - assert not fsm.accepts("1a") - assert fsm.accepts("a1") - assert fsm.accepts("def") - assert fsm.accepts("match") - assert fsm.accepts("+=") - assert fsm.accepts("+") - - state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial) - state_seq.insert(0, fsm.fsm_info.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, False, True), (2, True, True)] - - # Make sure the old-to-new state map is correct - def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial) - def_state_seq.insert(0, fsm.fsm_info.initial) - - def_old_to_new_states = fsms_to_trans_finals[0][2] - assert all( - new_state in def_old_to_new_states[old_state] - for old_state, new_state in zip(def_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(2, True, True)] - - name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial) - name_state_seq.insert(0, fsm.initial) - - name_old_to_new_states = fsms_to_trans_finals[2][2] - assert all( - new_state in name_old_to_new_states[old_state] - for old_state, new_state in zip(name_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(1, False, True), (2, True, True)] - - match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial) - match_state_seq.insert(0, fsm.initial) - - match_old_to_new_states = fsms_to_trans_finals[1][2] - assert all( - new_state in match_old_to_new_states[old_state] - for old_state, new_state in 
zip(match_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(2, True, True)] - - state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, True, False), (2, True, True)] - - state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(3, True, False), (4, False, True)] - - state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(3, False, True)] - - # Test some overlapping patterns - join_fsms = [ - interegular.parse_pattern(r"JOIN").to_fsm().reduce(), - interegular.parse_pattern(r"JOIN LEFT").to_fsm().reduce(), - ] - fsm, fsms_to_trans_finals = fsm_union(join_fsms) - - # Matching "OI" - state_seq = [1, 2, 3] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, True, False), (1, True, False)] - - # Matching "N" - state_seq = [3, 4] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, False, True), (1, True, False)] - - # Matching " " - state_seq = [4, 5] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(1, True, False)] - - def test_create_fsm_index_end_to_end(): regex_str = "0|[1-9][0-9]*" @@ -337,29 +184,26 @@ def test_create_fsm_index_end_to_end(): regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) vocabulary = { - "blah": numba.typed.List([0]), - "1a": numba.typed.List([1]), - "2": numba.typed.List([2]), - "0": numba.typed.List([3]), - "": numba.typed.List([4]), + "blah": [0], + "1a": [1], + "2": [2], + "0": [3], + "": [4], } - vocabulary_nb = numba.typed.List.empty_list( - numba.types.Tuple( - ( - numba.types.unicode_type, - numba.int64[:], - ) - ) - ) + vocabulary_nb = [] for token_tuple, token_ids in vocabulary.items(): token = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) vocabulary_nb.append((token, token_ids_np)) - res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb) + res = create_fsm_index_end_to_end( + regex_fsm.fsm_info, + vocabulary_nb, + frozenset(), + ) - assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}} + assert res == {0: {2: 2, 3: 1}, 2: {2: 2, 3: 2}} def test_create_fsm_index_end_to_end_multi_byte(): @@ -370,35 +214,30 @@ def test_create_fsm_index_end_to_end_multi_byte(): byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) vocabulary = { - "blah": numba.typed.List([0]), - "😈a": numba.typed.List([1]), - "πŸ˜‡": numba.typed.List([2]), - "😍": numba.typed.List([3]), - merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]), # '😍' - " 😍": numba.typed.List([5]), - merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]), # ' 😍' - merge_symbols((" ", "F0", "9F", "98")): numba.typed.List( - [7] - ), # ' 😍' incomplete - "": numba.typed.List([8]), + "blah": [0], + "😈a": [1], + "πŸ˜‡": [2], + "😍": [3], + merge_symbols(("F0", "9F", "98", "8D")): [4], # '😍' + " 😍": [5], + merge_symbols((" ", "F0", "9F", "98", "8D")): [6], # ' 😍' + merge_symbols((" ", "F0", "9F", "98")): [7], # ' 😍' incomplete + "": [8], } - 
vocabulary_nb = numba.typed.List.empty_list( - numba.types.Tuple( - ( - numba.types.unicode_type, - numba.int64[:], - ) - ) - ) + vocabulary_nb = [] for token_tuple, token_ids in vocabulary.items(): token_tuple_np = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) vocabulary_nb.append((token_tuple_np, token_ids_np)) - res = create_fsm_index_end_to_end(byte_fsm.fsm_info, vocabulary_nb) + res = create_fsm_index_end_to_end( + byte_fsm.fsm_info, + vocabulary_nb, + frozenset(), + ) - assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}} + assert res == {0: {5: 3, 6: 3, 7: 7, 2: 2}, 3: {2: 3, 3: 3, 4: 3}} @pytest.mark.parametrize( @@ -510,7 +349,6 @@ def test_regex_index_performance(): tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer = TransformerTokenizer(tokenizer) - # Pre-compile Numba functions res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer) assert len(res) > 1 @@ -597,21 +435,19 @@ def convert_token_to_string(self, token): token_trans_keys = get_vocabulary_transition_keys( regex_fsm.fsm_info.alphabet_symbol_mapping, regex_fsm.fsm_info.alphabet_anything_value, - vocabulary, - numba.typed.List.empty_list(numba.types.unicode_type), + list(vocabulary.items()), + frozenset(), ) token_str_to_tranition_keys = { token_str: trans_key_seq - for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys) } # `a` and `b` both are workable, but `z` has distinct transition rules assert interegular_fsm.accepts("zaz") assert interegular_fsm.accepts("zbz") - assert (token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"]).all() - assert not ( - token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"] - ).all() + assert token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"] + assert not token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"] def test_token_trans_keys_walk_fsm(): @@ -635,13 +471,13 @@ def convert_token_to_string(self, token): token_trans_keys = get_vocabulary_transition_keys( regex_fsm.fsm_info.alphabet_symbol_mapping, regex_fsm.fsm_info.alphabet_anything_value, - vocabulary, - numba.typed.List.empty_list(numba.types.unicode_type), + list(vocabulary.items()), + frozenset(), ) token_str_trans_key_seq = { token_str: trans_key_seq - for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys) } # verify initial state valid only for "ab" and "ac" using transition key seq @@ -653,42 +489,13 @@ def convert_token_to_string(self, token): regex_fsm.fsm_info.initial, regex_fsm.fsm_info.finals, token_trans_key_seq, - regex_fsm.fsm_info.initial, + regex_fsm.initial, False, ) is_accepted = len(state_seq) >= len(token_trans_key_seq) assert should_accept == is_accepted -def test_numba_leading_null_byte_UnicodeCharSeq_remains_broken(): - """Assert numba UnicodeCharSeq w/ leading \x00 is still broken""" - # EXPLANATION: - # https://github.com/outlines_core-dev/outlines/pull/930#issuecomment-2143535968 - - # from https://github.com/numba/numba/issues/9542 - d = numba.typed.typeddict.Dict.empty(numba.types.UnicodeCharSeq(1), numba.int64) - d["δΈ€"] = 10 # \xe4\xb8\x80 - with pytest.raises(KeyError): - str(d) - - # most characters are fine, but "\x00" is converted to "" - l = np.fromiter(["\x99", "\x00"], dtype=np.dtype("U2")) - assert str(l[0]) == "\x99" # fine - assert str(l[1]) == "" # 1-byte 
null converted to 0-bytes - - -@pytest.mark.parametrize("input_key", ["δΈ€", "\x00"]) -def test_numba_leading_null_byte_unicode_type_sane(input_key): - """Assert numba unicode_type w/ leading \x00 is working""" - # EXPLANATION: - # https://github.com/outlines_core-dev/outlines/pull/930#issuecomment-2143535968 - - # from https://github.com/numba/numba/issues/9542 - d = numba.typed.typeddict.Dict.empty(numba.types.unicode_type, numba.int64) - d["δΈ€"] = 10 # \xe4\xb8\x80 - str(d) # assert successfully interprets - - @pytest.mark.parametrize( "rare_token", [ @@ -714,4 +521,4 @@ def test_reduced_vocabulary_with_rare_tokens(rare_token): tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") tokenizer = adapt_tokenizer(tokenizer=tokenizer) tokenizer.vocabulary[rare_token] = max(tokenizer.vocabulary.values()) + 1 - reduced_vocabulary(tokenizer) + reduced_vocabulary(tokenizer) \ No newline at end of file diff --git a/justfile b/justfile index e2588cab..959ace76 100644 --- a/justfile +++ b/justfile @@ -5,12 +5,7 @@ build-core: cd outlines-core && cargo build --release dev-python: - cd bindings/python && pip install -e . + cd bindings/python && maturin develop build-python: - cd bindings/python && \ - ln -sf ../../outlines-core outlines-core-lib && \ - sed -i '' 's|path = "../../outlines-core"|path = "outlines-core-lib"|' Cargo.toml && \ - python -m build && \ - rm outlines-core-lib && \ - sed -i '' 's|path = "outlines-core-lib"|path = "../../outlines-core"|' Cargo.toml + cd bindings/python && maturin build --release diff --git a/outlines-core/Cargo.lock b/outlines-core/Cargo.lock deleted file mode 100644 index 36977b45..00000000 --- a/outlines-core/Cargo.lock +++ /dev/null @@ -1,7 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 3 - -[[package]] -name = "outlines-core" -version = "0.1.0" diff --git a/outlines-core/src/lib.rs b/outlines-core/src/lib.rs index 8957cb2b..5380c919 100644 --- a/outlines-core/src/lib.rs +++ b/outlines-core/src/lib.rs @@ -1,4 +1 @@ - -pub fn hello() -> String { - "world ".to_string() -} +pub mod regex; \ No newline at end of file diff --git a/outlines-core/src/regex.rs b/outlines-core/src/regex.rs new file mode 100644 index 00000000..1e60e57c --- /dev/null +++ b/outlines-core/src/regex.rs @@ -0,0 +1,143 @@ +use std::collections::{HashMap, HashSet}; + +pub fn walk_fsm_internal( + fsm_transitions: &HashMap<(u32, u32), u32>, + _fsm_initial: u32, + fsm_finals: &HashSet, + token_transition_keys: &[u32], + start_state: u32, + full_match: bool, +) -> Vec { + let mut state = start_state; + let mut accepted_states = Vec::new(); + let mut last_final_idx = 0; + + for (i, &trans_key) in token_transition_keys.iter().enumerate() { + match fsm_transitions.get(&(state, trans_key)) { + Some(&new_state) => { + state = new_state; + if fsm_finals.contains(&state) { + last_final_idx = i + 1; + } + accepted_states.push(state); + } + None => { + if !full_match && last_final_idx > 0 { + return accepted_states[..last_final_idx].to_vec(); + } + return Vec::new(); + } + } + } + + if full_match && last_final_idx != token_transition_keys.len() { + return Vec::new(); + } + + accepted_states +} + +pub fn state_scan_tokens_internal( + fsm_transitions: &HashMap<(u32, u32), u32>, + fsm_initial: u32, + fsm_finals: &HashSet, + vocabulary: &[(String, Vec)], + vocabulary_transition_keys: &[Vec], + start_state: u32, +) -> HashSet<(u32, u32)> { + let mut res = HashSet::new(); + + for (vocab_item, token_transition_keys) in + vocabulary.iter().zip(vocabulary_transition_keys.iter()) + { + let token_ids: Vec = vocab_item.1.clone(); + + let state_seq = walk_fsm_internal( + fsm_transitions, + fsm_initial, + fsm_finals, + token_transition_keys, + start_state, + false, + ); + + if state_seq.len() < token_transition_keys.len() { + continue; + } + + for &token_id in &token_ids { + res.insert((token_id, *state_seq.last().unwrap())); + } + } + + res +} + +pub fn get_token_transition_keys_internal( + alphabet_symbol_mapping: &HashMap, + alphabet_anything_value: u32, + token_str: &str, +) -> Vec { + let mut token_transition_keys = Vec::new(); + let mut i = 0; + let chars: Vec = token_str.chars().collect(); + + while i < chars.len() { + let symbol; + if chars[i] == '\0' && i != chars.len() - 1 { + if i + 2 < chars.len() { + symbol = format!("\0{}{}", chars[i + 1], chars[i + 2]); + i += 3; + } else { + symbol = chars[i].to_string(); + i += 1; + } + } else { + symbol = chars[i].to_string(); + i += 1; + } + + let transition_key = *alphabet_symbol_mapping + .get(&symbol) + .unwrap_or(&alphabet_anything_value); + token_transition_keys.push(transition_key); + } + + token_transition_keys +} + +pub fn get_vocabulary_transition_keys_internal( + alphabet_symbol_mapping: &HashMap, + alphabet_anything_value: u32, + vocabulary: &[(String, Vec)], + frozen_tokens: &HashSet, +) -> Vec> { + let mut vocab_transition_keys: Vec> = Vec::new(); + + for item in vocabulary.iter() { + let token_str = item.0.clone(); + + let mut token_transition_keys; + + // Since these tokens are not expanded into byte-level transitions, we + // can simply get their transition keys directly. 
+ if frozen_tokens.contains(&token_str) { + token_transition_keys = Vec::new(); + token_transition_keys.push( + *alphabet_symbol_mapping + .get(&token_str) + .unwrap_or(&alphabet_anything_value), + ) + } else { + token_transition_keys = get_token_transition_keys_internal( + alphabet_symbol_mapping, + alphabet_anything_value, + &token_str, + ); + } + + vocab_transition_keys.push(token_transition_keys); + } + + vocab_transition_keys +} diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index a2db8a08..00000000 --- a/pyproject.toml +++ /dev/null @@ -1,141 +0,0 @@ -[build-system] -requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "setuptools-rust"] -build-backend = "setuptools.build_meta" - -[project] -name = "outlines_core" -authors= [{name = "Outlines Developers"}] -description = "Structured Text Generation in Rust" -requires-python = ">=3.8" -license = {text = "Apache-2.0"} -keywords=[ - "machine learning", - "deep learning", - "language models", - "structured generation", -] -classifiers = [ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "Intended Audience :: Information Technology", - "Intended Audience :: Science/Research", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering :: Artificial Intelligence", -] -dependencies = [ - "interegular", - "numpy<2.0.0", - "cloudpickle", - "diskcache", - "pydantic>=2.0", - "referencing", - "jsonschema", - "tqdm", - "datasets", - "typing_extensions", -] -dynamic = ["version"] - -[project.optional-dependencies] -test = [ - "pre-commit", - "pytest", - "pytest-benchmark", - "pytest-cov", - "pytest-mock", - "coverage[toml]>=5.1", - "diff-cover", - "accelerate", - "beartype<0.16.0", - "huggingface_hub", - "torch", - "transformers", - "pillow", -] - -[project.urls] -homepage = "https://github.com/outlines-dev/outlines-core" -documentation = "https://outlines-dev.github.io/outlines-core/" -repository = "https://github.com/outlines-dev/outlines-core/" - -[project.readme] -file="README.md" -content-type = "text/markdown" - -[tool.setuptools] -packages = ["outlines_core"] -package-dir = {"" = "python"} - -[tool.setuptools.package-data] -"outlines" = ["py.typed"] - -[tool.setuptools_scm] -write_to = "python/outlines_core/_version.py" - -[tool.pytest.ini_options] -testpaths = ["tests"] -filterwarnings = [ - "error", - "ignore::pydantic.warnings.PydanticDeprecatedSince20", - "ignore::FutureWarning:transformers.*", - "ignore::FutureWarning:huggingface_hub.*", - "ignore::UserWarning", -] -addopts = [ - "--import-mode=importlib" -] - -[tool.mypy] -exclude=["examples"] -enable_incomplete_feature = ["Unpack"] - -[[tool.mypy.overrides]] -module = [ - "jsonschema.*", - "numpy.*", - "cloudpickle.*", - "diskcache.*", - "pydantic.*", - "pytest", - "referencing.*", - "torch.*", - "transformers.*", - "huggingface_hub", - "interegular.*", - "datasets.*", - "setuptools.*", - "setuptools_rust.*", - # TODO: Add type info for the Rust extension - "outlines_core.fsm.outlines_core_rs.*", -] -ignore_missing_imports = true - -[tool.coverage.run] -omit = [ - "python/outlines_core/_version.py", - "tests/*", -] -branch = true - -[tool.coverage.report] -omit = [ - "tests/*", -] -exclude_lines = [ - "pragma: no cover", - "if TYPE_CHECKING:", - "\\.\\.\\.", -] -show_missing = true - -[tool.coverage.paths] -source = [ - "outlines_core", - "**/site-packages/outlines_core", -] - - -[tool.diff_cover] -compare_branch = "origin/main" -diff_range_notation = ".." 
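The new `outlines-core/src/regex.rs` above stores the DFA as a flat map keyed by `(state, transition_key)` and walks it one transition key at a time, remembering the last position at which a final state was reached. A small pure-Python sketch of the same walk semantics for readers of the Rust (the helper name and toy FSM are illustrative only, not part of the crate or the bindings):

def walk_fsm(transitions, finals, transition_keys, start_state, full_match=True):
    # Mirrors walk_fsm_internal: follow (state, key) transitions and remember
    # the last index at which a final state was reached.
    state = start_state
    accepted, last_final_idx = [], 0
    for i, key in enumerate(transition_keys):
        next_state = transitions.get((state, key))
        if next_state is None:
            # Dead end: a partial match is only returned when full_match=False.
            return accepted[:last_final_idx] if (not full_match and last_final_idx) else []
        state = next_state
        if state in finals:
            last_final_idx = i + 1
        accepted.append(state)
    if full_match and last_final_idx != len(transition_keys):
        return []
    return accepted

# Toy DFA accepting "ab": 0 --key 0--> 1 --key 1--> 2, with state 2 final.
transitions = {(0, 0): 1, (1, 1): 2}
assert walk_fsm(transitions, {2}, [0, 1], 0) == [1, 2]
assert walk_fsm(transitions, {2}, [0], 0, full_match=True) == []
assert walk_fsm(transitions, {2}, [0], 0, full_match=False) == [1]

This is the same behaviour the updated `test_walk_fsm` in `bindings/python/tests/fsm/test_regex.py` exercises through the Rust-backed functions.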
diff --git a/python/outlines_core/fsm/__init__.py b/python/outlines_core/fsm/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py deleted file mode 100644 index 8f7250ef..00000000 --- a/python/outlines_core/fsm/guide.py +++ /dev/null @@ -1,295 +0,0 @@ -from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Callable, - Dict, - List, - Optional, - Protocol, - Set, - Tuple, - Union, -) - -import interegular -import torch -from outlines_core.fsm.regex import ( - create_fsm_index_tokenizer, - make_byte_level_fsm, - make_deterministic_fsm, -) - -if TYPE_CHECKING: - from outlines_core.models.tokenizer import Tokenizer - - -@dataclass(frozen=True) -class Write: - """Write instruction. - - Attributes - ---------- - tokens - The sequence of tokens to be added to the current sequence by the - generation process. - - """ - - tokens: List[int] - - -@dataclass(frozen=True) -class Generate: - """Generate instruction - - Attributes - ---------- - tokens - The tokens that lead to a valid completion if generated. A value - of ``None`` indicates that all tokens are allowed. - """ - - tokens: Optional[List[int]] - - -Instruction = Union[Write, Generate] - - -class Guide(Protocol): - """Base definition of a generation guide. - - A generation guide defines the behavior of a finite-state machine that guides - a text generation procedure. Unlike the DFAs built from regular expressions - guides can also emit a `Write` instructions which tells the model that it can - append a sequence of tokens (or token word) instead of generating it. - - """ - - def get_next_instruction(self, state: int) -> Instruction: - ... - - def get_next_state(self, state: int, token_id: int) -> int: - ... - - def is_final_state(self, state: int) -> bool: - ... - - def copy(self) -> "Guide": - ... - - -class StopAtEOSGuide(Guide): - """Guide to generate tokens until the EOS token has been generated.""" - - final_state = 1 - start_state = 0 - - def __init__(self, tokenizer: "Tokenizer"): - """Initialize the generation guide. - - model - The logit generator used to generate the next token. - - """ - self.eos_token_id = tokenizer.eos_token_id - self.vocabulary = tokenizer.vocabulary.values() - - def get_next_instruction(self, state: int) -> Instruction: - if self.is_final_state(state): - return Write([self.eos_token_id]) - return Generate(None) - - def get_next_state(self, state: int, token_id: int) -> int: - if token_id == self.eos_token_id or state == self.final_state: - return self.final_state - - return self.start_state - - def is_final_state(self, state: int): - return state == self.final_state - - def copy(self): - return self - - -def create_states_mapping( - regex_string: str, - tokenizer: "Tokenizer", - regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern, - frozen_tokens: List[str] = [], -) -> Tuple[Dict[int, Dict[int, int]], Set[int], set]: - """Create the variables related to the mapping between states and tokens - The parameters of the function are used for caching purpose. - - Parameters - ---------- - regex_string: (`str`): - The regular expression string to generate a states mapping for. - tokenizer: (`Tokenizer`): - The model's tokenizer. - regex_parser: (`Callable[[str], interegular.Pattern]`, *optional*): - A function that parses a regex string into an `interegular` Pattern object. 
- frozen_tokens: (`List[str]`, *optional*): - A list of tokens that should be kept as-is when expanding the token-level FSM - into a byte-level FSM. Defaults to an empty list. - - Returns - ------- - states_to_token_maps: (`Dict[int, Dict[int, int]]`): - A mapping from states to a mapping from token ids originating from that state - to the next state to transition to given that token. The structure is as follows: - (origin_state -> (token_id -> next_state)) - empty_token_ids: (`Set[int]`): - A set of token ids that correspond to empty strings. - final_states: (`set`): - A set of final states in the FSM. - """ - regex_pattern = regex_parser(regex_string) - byte_fsm = make_byte_level_fsm( - regex_pattern.to_fsm().reduce(), keep_utf8=True, frozen_tokens=frozen_tokens - ) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer, frozen_tokens=frozen_tokens - ) - - # We make sure that it is possible to generate strings in the language - # of the regular expression with the tokens present in the model's - # vocabulary. - if not any( - regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values() - ): - raise ValueError( - "The vocabulary does not allow us to build a sequence that matches the input regex" - ) - - return states_to_token_maps, empty_token_ids, regex_fsm.finals - - -class RegexGuide(Guide): - """Guide to generate text in the language of a regular expression.""" - - initial_state = 0 - - def __init__(self, regex_string: str, tokenizer: "Tokenizer"): - ( - self.states_to_token_maps, - self.empty_token_ids, - fsm_finals, - ) = create_states_mapping(regex_string, tokenizer) - self.eos_token_id = tokenizer.eos_token_id - self.final_states = fsm_finals | {-1} - self._cache_state_to_token_tensor() - - def get_next_instruction(self, state: int) -> Instruction: - """Return the next instruction for guided generation. - - The initialization of the guide builds an index which maps FSM states to a - map from authorized tokens to the state in which the guide needs to move - if said token is generated. Therefore the authorized tokens at the - current state are the keys of the map returned by the value of the index - for current state. - - If the current state is not contained in the end this means that we are - in a final state of the guide. We only authorize EOS tokens in the final - state. - - Parameters - ---------- - state - The current state of the guide. - - Returns - ------- - A `Generate` instance that contains the model and the allowed token ids. - - """ - next_tokens_mask = self.states_to_token_mask.get(state) - if next_tokens_mask is None: - return Write(torch.tensor([self.eos_token_id])) - - return Generate(next_tokens_mask) - - def get_next_state(self, state: int, token_id: int) -> int: - """Update the state of the guide. - - We use the index to determine to which state the guide should transition - given the token that was just generated. - - Parameters - ---------- - state - The current state of the guide. - token_id - The id of the token that was just generated. - - Returns - ------- - The new state of the guide. 
- - """ - if token_id == self.eos_token_id or state not in self.states_to_token_maps: - return -1 - - last_token_to_end_state = self.states_to_token_maps[state] - next_state = last_token_to_end_state.get(token_id) - if next_state is None: - next_state = -1 - - return next_state - - @classmethod - def from_interegular_fsm( - cls, interegular_fsm: interegular.fsm.FSM, tokenizer: "Tokenizer" - ): - from_interegular_instance = cls.__new__(cls) - - def create_states_mapping_from_interegular_fsm( - fsm: interegular.fsm.FSM, - ) -> Tuple[dict, set]: - """Create the variables related to the mapping between states and tokens - The parameters of the function are used for caching purpose - """ - byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer - ) - - # We make sure that it is possible to generate strings in the language - # of the regular expression with the tokens present in the model's - # vocabulary. - if not any( - regex_fsm.finals.intersection(v.values()) - for v in states_to_token_maps.values() - ): - raise ValueError( - "The vocabulary does not allow us to build a sequence that matches the input regex" - ) - - return states_to_token_maps, empty_token_ids - - ( - from_interegular_instance.states_to_token_maps, - from_interegular_instance.empty_token_ids, - ) = create_states_mapping_from_interegular_fsm(interegular_fsm) - from_interegular_instance.eos_token_id = tokenizer.eos_token_id - from_interegular_instance._cache_state_to_token_tensor() - return from_interegular_instance - - def _cache_state_to_token_tensor(self): - """ - cache state -> token int tensor - this increases performance of mask construction substantially - """ - self.states_to_token_mask = { - state: torch.tensor(list(next_tokens_to_end_states.keys())) - for state, next_tokens_to_end_states in self.states_to_token_maps.items() - } - - def is_final_state(self, state: int) -> bool: - """Determine whether the current state of the guide is a final state.""" - return state in self.final_states - - def copy(self): - return self diff --git a/python/outlines_core/fsm/json_schema.py b/python/outlines_core/fsm/json_schema.py deleted file mode 100644 index b2924300..00000000 --- a/python/outlines_core/fsm/json_schema.py +++ /dev/null @@ -1,519 +0,0 @@ -import inspect -import json -import re -import warnings -from typing import Callable, Optional, Tuple - -from jsonschema.protocols import Validator -from pydantic import create_model -from referencing import Registry, Resource -from referencing._core import Resolver -from referencing.jsonschema import DRAFT202012 - -# allow `\"`, `\\`, or any character which isn't a control sequence -STRING_INNER = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])' -STRING = f'"{STRING_INNER}*"' - -INTEGER = r"(-)?(0|[1-9][0-9]*)" -NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" -BOOLEAN = r"(true|false)" -NULL = r"null" -WHITESPACE = r"[ ]?" 
- -type_to_regex = { - "string": STRING, - "integer": INTEGER, - "number": NUMBER, - "boolean": BOOLEAN, - "null": NULL, -} - -DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"' -DATE = r'"(?:\d{4})-(?:0[1-9]|1[0-2])-(?:0[1-9]|[1-2][0-9]|3[0-1])"' -TIME = r'"(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\\.[0-9]+)?(Z)?"' -UUID = r'"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"' - -format_to_regex = { - "uuid": UUID, - "date-time": DATE_TIME, - "date": DATE, - "time": TIME, -} - - -def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): - """Turn a JSON schema into a regex that matches any JSON object that follows - this schema. - - JSON Schema is a declarative language that allows to annotate JSON documents - with types and descriptions. These schemas can be generated from any Python - datastructure that has type annotation: namedtuples, dataclasses, Pydantic - models. And by ensuring that the generation respects the schema we ensure - that the output can be parsed into these objects. - This function parses the provided schema and builds a generation schedule which - mixes deterministic generation (fixed strings), and sampling with constraints. - - Parameters - ---------- - schema - A string that represents a JSON Schema. - whitespace_pattern - Pattern to use for JSON syntactic whitespace (doesn't impact string literals) - Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` - - Returns - ------- - A generation schedule. A list of strings that represent the JSON - schema's structure and regular expression that define the structure of - the fields. - - References - ---------- - .. [0] JSON Schema. https://json-schema.org/ - - """ - - schema = json.loads(schema) - Validator.check_schema(schema) - - # Build reference resolver - schema = Resource(contents=schema, specification=DRAFT202012) - uri = schema.id() if schema.id() is not None else "" - registry = Registry().with_resource(uri=uri, resource=schema) - resolver = registry.resolver() - - content = schema.contents - return to_regex(resolver, content, whitespace_pattern) - - -def _get_num_items_pattern(min_items, max_items, whitespace_pattern): - # Helper function for arrays and objects - min_items = int(min_items or 0) - if max_items is None: - return rf"{{{max(min_items - 1, 0)},}}" - else: - max_items = int(max_items) - if max_items < 1: - return None - return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}" - - -def validate_quantifiers( - min_bound: Optional[str], max_bound: Optional[str], start_offset: int = 0 -) -> Tuple[str, str]: - """ - Ensures that the bounds of a number are valid. Bounds are used as quantifiers in the regex. - - Parameters - ---------- - min_bound - The minimum value that the number can take. - max_bound - The maximum value that the number can take. - start_offset - Number of elements that are already present in the regex but still need to be counted. - ex: if the regex is already "(-)?(0|[1-9][0-9])", we will always have at least 1 digit, so the start_offset is 1. - - Returns - ------- - min_bound - The minimum value that the number can take. - max_bound - The maximum value that the number can take. - - Raises - ------ - ValueError - If the minimum bound is greater than the maximum bound. - - TypeError or ValueError - If the minimum bound is not an integer or None. - or - If the maximum bound is not an integer or None. 
- """ - min_bound = "" if min_bound is None else str(int(min_bound) - start_offset) - max_bound = "" if max_bound is None else str(int(max_bound) - start_offset) - if min_bound and max_bound: - if int(max_bound) < int(min_bound): - raise ValueError("max bound must be greater than or equal to min bound") - return min_bound, max_bound - - -def to_regex( - resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None -): - """Translate a JSON Schema instance into a regex that validates the schema. - - Note - ---- - Many features of JSON schema are missing: - - Handle `additionalProperties` keyword - - Handle types defined as a list - - Handle constraints on numbers - - Handle special patterns: `date`, `uri`, etc. - - This does not support recursive definitions. - - Parameters - ---------- - resolver - An object that resolves references to other instances within a schema - instance - The instance to translate - whitespace_pattern - Pattern to use for JSON syntactic whitespace (doesn't impact string literals) - Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` - """ - - # set whitespace pattern - if whitespace_pattern is None: - whitespace_pattern = WHITESPACE - - if instance == {}: - # JSON Schema Spec: Empty object means unconstrained, any json type is legal - types = [ - {"type": "boolean"}, - {"type": "null"}, - {"type": "number"}, - {"type": "integer"}, - {"type": "string"}, - {"type": "array"}, - {"type": "object"}, - ] - regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] - regexes = [rf"({r})" for r in regexes] - return rf"{'|'.join(regexes)}" - - elif "properties" in instance: - regex = "" - regex += r"\{" - properties = instance["properties"] - required_properties = instance.get("required", []) - is_required = [item in required_properties for item in properties] - # If at least one property is required, we include the one in the lastest position - # without any comma. - # For each property before it (optional or required), we add with a comma after the property. - # For each property after it (optional), we add with a comma before the property. - if any(is_required): - last_required_pos = max([i for i, value in enumerate(is_required) if value]) - for i, (name, value) in enumerate(properties.items()): - subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) - if i < last_required_pos: - subregex = f"{subregex}{whitespace_pattern}," - elif i > last_required_pos: - subregex = f"{whitespace_pattern},{subregex}" - regex += subregex if is_required[i] else f"({subregex})?" - # If no property is required, we have to create a possible pattern for each property in which - # it's the last one necessarilly present. Then, we add the others as optional before and after - # following the same strategy as described above. - # The whole block is made optional to allow the case in which no property is returned. - else: - property_subregexes = [] - for i, (name, value) in enumerate(properties.items()): - subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) - property_subregexes.append(subregex) - possible_patterns = [] - for i in range(len(property_subregexes)): - pattern = "" - for subregex in property_subregexes[:i]: - pattern += f"({subregex}{whitespace_pattern},)?" 
- pattern += property_subregexes[i] - for subregex in property_subregexes[i + 1 :]: - pattern += f"({whitespace_pattern},{subregex})?" - possible_patterns.append(pattern) - regex += f"({'|'.join(possible_patterns)})?" - - regex += f"{whitespace_pattern}" + r"\}" - - return regex - - # To validate against allOf, the given data must be valid against all of the - # given subschemas. - elif "allOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"] - ] - subregexes_str = [f"{subregex}" for subregex in subregexes] - return rf"({''.join(subregexes_str)})" - - # To validate against `anyOf`, the given data must be valid against - # any (one or more) of the given subschemas. - elif "anyOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"] - ] - return rf"({'|'.join(subregexes)})" - - # To validate against oneOf, the given data must be valid against exactly - # one of the given subschemas. - elif "oneOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] - ] - - xor_patterns = [f"(?:{subregex})" for subregex in subregexes] - - return rf"({'|'.join(xor_patterns)})" - - # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx - elif "prefixItems" in instance: - element_patterns = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"] - ] - comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}" - tuple_inner = comma_split_pattern.join(element_patterns) - return rf"\[{whitespace_pattern}{tuple_inner}{whitespace_pattern}\]" - - # The enum keyword is used to restrict a value to a fixed set of values. It - # must be an array with at least one element, where each element is unique. - elif "enum" in instance: - choices = [] - for choice in instance["enum"]: - if type(choice) in [int, float, bool, type(None), str]: - choices.append(re.escape(json.dumps(choice))) - else: - raise TypeError(f"Unsupported data type in enum: {type(choice)}") - return f"({'|'.join(choices)})" - - elif "const" in instance: - const = instance["const"] - if type(const) in [int, float, bool, type(None), str]: - const = re.escape(json.dumps(const)) - else: - raise TypeError(f"Unsupported data type in const: {type(const)}") - return const - - elif "$ref" in instance: - path = f"{instance['$ref']}" - instance = resolver.lookup(path).contents - return to_regex(resolver, instance, whitespace_pattern) - - # The type keyword may either be a string or an array: - # - If it's a string, it is the name of one of the basic types. - # - If it is an array, it must be an array of strings, where each string is - # the name of one of the basic types, and each element is unique. In this - # case, the JSON snippet is valid if it matches any of the given types. 
- elif "type" in instance: - instance_type = instance["type"] - if instance_type == "string": - if "maxLength" in instance or "minLength" in instance: - max_items = instance.get("maxLength", "") - min_items = instance.get("minLength", "") - try: - if int(max_items) < int(min_items): - raise ValueError( - "maxLength must be greater than or equal to minLength" - ) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume) - except ValueError: - pass - return f'"{STRING_INNER}{{{min_items},{max_items}}}"' - elif "pattern" in instance: - pattern = instance["pattern"] - if pattern[0] == "^" and pattern[-1] == "$": - return rf'("{pattern[1:-1]}")' - else: - return rf'("{pattern}")' - elif "format" in instance: - format = instance["format"] - if format == "date-time": - return format_to_regex["date-time"] - elif format == "uuid": - return format_to_regex["uuid"] - elif format == "date": - return format_to_regex["date"] - elif format == "time": - return format_to_regex["time"] - else: - raise NotImplementedError( - f"Format {format} is not supported by Outlines" - ) - else: - return type_to_regex["string"] - - elif instance_type == "number": - bounds = { - "minDigitsInteger", - "maxDigitsInteger", - "minDigitsFraction", - "maxDigitsFraction", - "minDigitsExponent", - "maxDigitsExponent", - } - if bounds.intersection(set(instance.keys())): - min_digits_integer, max_digits_integer = validate_quantifiers( - instance.get("minDigitsInteger"), - instance.get("maxDigitsInteger"), - start_offset=1, - ) - min_digits_fraction, max_digits_fraction = validate_quantifiers( - instance.get("minDigitsFraction"), instance.get("maxDigitsFraction") - ) - min_digits_exponent, max_digits_exponent = validate_quantifiers( - instance.get("minDigitsExponent"), instance.get("maxDigitsExponent") - ) - integers_quantifier = ( - f"{{{min_digits_integer},{max_digits_integer}}}" - if min_digits_integer or max_digits_integer - else "*" - ) - fraction_quantifier = ( - f"{{{min_digits_fraction},{max_digits_fraction}}}" - if min_digits_fraction or max_digits_fraction - else "+" - ) - exponent_quantifier = ( - f"{{{min_digits_exponent},{max_digits_exponent}}}" - if min_digits_exponent or max_digits_exponent - else "+" - ) - return rf"((-)?(0|[1-9][0-9]{integers_quantifier}))(\.[0-9]{fraction_quantifier})?([eE][+-][0-9]{exponent_quantifier})?" - return type_to_regex["number"] - - elif instance_type == "integer": - if "minDigits" in instance or "maxDigits" in instance: - min_digits, max_digits = validate_quantifiers( - instance.get("minDigits"), instance.get("maxDigits"), start_offset=1 - ) - return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})" - return type_to_regex["integer"] - - elif instance_type == "array": - num_repeats = _get_num_items_pattern( - instance.get("minItems"), instance.get("maxItems"), whitespace_pattern - ) - if num_repeats is None: - return rf"\[{whitespace_pattern}\]" - - allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else "" - - if "items" in instance: - items_regex = to_regex(resolver, instance["items"], whitespace_pattern) - return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]" - else: - # Here we need to make the choice to exclude generating list of objects - # if the specification of the object is not given, even though a JSON - # object that contains an object here would be valid under the specification. 
- legal_types = [ - {"type": "boolean"}, - {"type": "null"}, - {"type": "number"}, - {"type": "integer"}, - {"type": "string"}, - ] - depth = instance.get("depth", 2) - if depth > 0: - legal_types.append({"type": "object", "depth": depth - 1}) - legal_types.append({"type": "array", "depth": depth - 1}) - - regexes = [ - to_regex(resolver, t, whitespace_pattern) for t in legal_types - ] - return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]" - - elif instance_type == "object": - # pattern for json object with values defined by instance["additionalProperties"] - # enforces value type constraints recursively, "minProperties", and "maxProperties" - # doesn't enforce "required", "dependencies", "propertyNames" "any/all/on Of" - num_repeats = _get_num_items_pattern( - instance.get("minProperties"), - instance.get("maxProperties"), - whitespace_pattern, - ) - if num_repeats is None: - return rf"\{{{whitespace_pattern}\}}" - - allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else "" - - additional_properties = instance.get("additionalProperties") - - if additional_properties is None or additional_properties is True: - # JSON Schema behavior: If the additionalProperties of an object is - # unset or True, it is unconstrained object. - # We handle this by setting additionalProperties to anyOf: {all types} - - legal_types = [ - {"type": "string"}, - {"type": "number"}, - {"type": "boolean"}, - {"type": "null"}, - ] - - # We set the object depth to 2 to keep the expression finite, but the "depth" - # key is not a true component of the JSON Schema specification. - depth = instance.get("depth", 2) - if depth > 0: - legal_types.append({"type": "object", "depth": depth - 1}) - legal_types.append({"type": "array", "depth": depth - 1}) - additional_properties = {"anyOf": legal_types} - - value_pattern = to_regex( - resolver, additional_properties, whitespace_pattern - ) - key_value_pattern = ( - f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" - ) - key_value_successor_pattern = ( - f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}" - ) - multiple_key_value_pattern = f"({key_value_pattern}({key_value_successor_pattern}){num_repeats}){allow_empty}" - - return ( - r"\{" - + whitespace_pattern - + multiple_key_value_pattern - + whitespace_pattern - + r"\}" - ) - - elif instance_type == "boolean": - return type_to_regex["boolean"] - - elif instance_type == "null": - return type_to_regex["null"] - - elif isinstance(instance_type, list): - # Here we need to make the choice to exclude generating an object - # if the specification of the object is not give, even though a JSON - # object that contains an object here would be valid under the specification. - regexes = [ - to_regex(resolver, {"type": t}, whitespace_pattern) - for t in instance_type - if t != "object" - ] - return rf"({'|'.join(regexes)})" - - raise NotImplementedError( - f"""Could not translate the instance {instance} to a - regular expression. Make sure it is valid to the JSON Schema specification. If - it is, please open an issue on the Outlines repository""" - ) - - -def get_schema_from_signature(fn: Callable) -> str: - """Turn a function signature into a JSON schema. - - Every JSON object valid to the output JSON Schema can be passed - to `fn` using the ** unpacking syntax. 
- - """ - signature = inspect.signature(fn) - arguments = {} - for name, arg in signature.parameters.items(): - if arg.annotation == inspect._empty: - raise ValueError("Each argument must have a type annotation") - else: - arguments[name] = (arg.annotation, ...) - - try: - fn_name = fn.__name__ - except Exception as e: - fn_name = "Arguments" - warnings.warn( - f"The function name could not be determined. Using default name 'Arguments' instead. For debugging, here is exact error:\n{e}", - category=UserWarning, - ) - model = create_model(fn_name, **arguments) - - return model.model_json_schema() diff --git a/python/outlines_core/py.typed b/python/outlines_core/py.typed deleted file mode 100644 index e69de29b..00000000 diff --git a/setup.py b/setup.py deleted file mode 100644 index 4c414e6b..00000000 --- a/setup.py +++ /dev/null @@ -1,20 +0,0 @@ -import os - -from setuptools import setup -from setuptools_rust import Binding, RustExtension - -CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) - - -rust_extensions = [ - RustExtension( - "outlines_core.fsm.outlines_core_rs", - f"{CURRENT_DIR}/Cargo.toml", - binding=Binding.PyO3, - rustc_flags=["--crate-type=cdylib"], - ), -] - -setup( - rust_extensions=rust_extensions, -) diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index 534b0bb7..00000000 --- a/src/lib.rs +++ /dev/null @@ -1,23 +0,0 @@ -mod regex; - -use pyo3::prelude::*; -use pyo3::wrap_pyfunction; -use regex::_walk_fsm; -use regex::create_fsm_index_end_to_end; -use regex::get_token_transition_keys; -use regex::get_vocabulary_transition_keys; -use regex::state_scan_tokens; -use regex::FSMInfo; - -#[pymodule] -fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; - m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; - m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; - - m.add_class::()?; - - Ok(()) -} diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py deleted file mode 100644 index aeb7060c..00000000 --- a/tests/fsm/test_fsm.py +++ /dev/null @@ -1,91 +0,0 @@ -import pytest -from outlines_core.fsm.fsm import RegexFSM, StopAtEosFSM - - -def assert_expected_tensor_ids(tensor, ids): - assert len(tensor) == len(ids) - norm_tensor = sorted(map(int, tensor)) - norm_ids = sorted(map(int, tensor)) - assert norm_tensor == norm_ids, (norm_tensor, norm_ids) - - -def test_stop_at_eos(): - class MockTokenizer: - vocabulary = {"a": 1, "eos": 2} - eos_token_id = 2 - - with pytest.warns(UserWarning): - fsm = StopAtEosFSM(MockTokenizer()) - - assert fsm.allowed_token_ids(fsm.start_state) is None - assert fsm.allowed_token_ids(fsm.final_state) == [2] - assert fsm.next_state(fsm.start_state, 2) == fsm.final_state - assert fsm.next_state(fsm.start_state, 1) == fsm.start_state - assert fsm.is_final_state(fsm.start_state) is False - assert fsm.is_final_state(fsm.final_state) is True - - -def test_regex_vocabulary_error(): - class MockTokenizer: - vocabulary = {"a": 1} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - - with pytest.raises(ValueError, match="The vocabulary"): - RegexFSM(regex_str, MockTokenizer()) - - -def test_regex(): - class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, 
token): - return token - - regex_str = "[1-9]" - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = RegexFSM(regex_str, tokenizer) - - assert fsm.states_to_token_maps == {0: {1: 1}} - assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1]) - assert fsm.next_state(state=0, token_id=1) == 1 - assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - for state in fsm.final_states: - assert fsm.is_final_state(state) is True - - -def test_regex_final_state(): - """Make sure that the FSM stays in the final state as we keep generating""" - - class MockTokenizer: - vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104} - special_tokens = {"eos"} - eos_token_id = 104 - - def convert_token_to_string(self, token): - return token - - regex_str = r"`\n(\.\n)?`\n" - tokenizer = MockTokenizer() - - with pytest.warns(UserWarning): - fsm = RegexFSM(regex_str, tokenizer) - - state = fsm.next_state(state=4, token_id=103) - assert state == 5 - assert fsm.is_final_state(state) - - state = fsm.next_state(state=5, token_id=103) - assert fsm.is_final_state(state) diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py deleted file mode 100644 index 0bd28d4f..00000000 --- a/tests/fsm/test_guide.py +++ /dev/null @@ -1,189 +0,0 @@ -import pytest -from outlines_core.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write - - -def assert_expected_tensor_ids(tensor, ids): - assert len(tensor) == len(ids) - norm_tensor = sorted(map(int, tensor)) - norm_ids = sorted(map(int, tensor)) - assert norm_tensor == norm_ids, (norm_tensor, norm_ids) - - -def test_stop_at_eos(): - class MockTokenizer: - vocabulary = {"a": 1, "eos": 2} - eos_token_id = 2 - - fsm = StopAtEOSGuide(MockTokenizer()) - - instruction = fsm.get_next_instruction(fsm.start_state) - assert isinstance(instruction, Generate) - assert instruction.tokens is None - - instruction = fsm.get_next_instruction(fsm.final_state) - assert isinstance(instruction, Write) - assert instruction.tokens == [2] - - assert fsm.get_next_state(fsm.start_state, 2) == fsm.final_state - assert fsm.get_next_state(fsm.start_state, 1) == fsm.start_state - assert fsm.is_final_state(fsm.start_state) is False - assert fsm.is_final_state(fsm.final_state) is True - - -def test_regex_vocabulary_error(): - class MockTokenizer: - vocabulary = {"a": 1} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - - with pytest.raises(ValueError, match="The vocabulary"): - RegexGuide(regex_str, MockTokenizer()) - - -def test_regex(): - class MockTokenizer: - vocabulary = {"1": 1, "a": 2, "eos": 3} - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - return token - - regex_str = "[1-9]" - tokenizer = MockTokenizer() - fsm = RegexGuide(regex_str, tokenizer) - - assert fsm.states_to_token_maps == {0: {1: 1}} - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [1]) - - assert fsm.get_next_state(state=0, token_id=1) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - for state in fsm.final_states: - assert fsm.is_final_state(state) is True - - -def test_regex_multi_byte_llama_like(): - class MockTokenizer: - vocabulary = { - "1": 1, - "a": 2, - "eos": 3, - "😍": 4, - "<0xF0>": 5, - "<0x9F>": 6, - "<0x98>": 7, - "<0x88>": 8, # 😈 - 
"\ufffd": 9, - "\ufffd\ufffd": 10, - } - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - if token[0] == "<": - return "\ufffd" - return token - - regex_str = "[😁-😎]" - tokenizer = MockTokenizer() - fsm = RegexGuide(regex_str, tokenizer) - - assert fsm.states_to_token_maps == { - 0: {5: 1, 4: 2}, - 1: {6: 3}, - 3: {7: 4}, - 4: {8: 2}, - } - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [5, 4]) - - assert fsm.get_next_state(state=0, token_id=5) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - for state in fsm.final_states: - assert fsm.is_final_state(state) is True - - -def test_regex_multi_byte_gpt2_like(): - class MockTokenizer: - vocabulary = { - "1": 1, - "a": 2, - "eos": 3, - "😍": 4, - " ": 5, - "\ufffd": 6, - "\ufffd\ufffd": 7, - "ðŁĺ": 8, - "Δͺ": 9, # '😈' - "Δ Γ°": 10, - "ŁĺΔͺ": 11, # ' 😈' - } - special_tokens = {"eos"} - eos_token_id = 3 - - def convert_token_to_string(self, token): - if self.vocabulary[token] >= 8: - return "\ufffd" - return token - - regex_str = " [😁-😎]" - tokenizer = MockTokenizer() - fsm = RegexGuide(regex_str, tokenizer) - - assert fsm.states_to_token_maps == { - 0: {5: 1, 10: 2}, - 1: {8: 5, 4: 3}, - 2: {11: 3}, - 5: {9: 3}, - } - - instruction = fsm.get_next_instruction(0) - assert isinstance(instruction, Generate) - assert_expected_tensor_ids(instruction.tokens, [5, 10]) - - assert fsm.get_next_state(state=0, token_id=5) == 1 - assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1 - - assert fsm.is_final_state(0) is False - - for state in fsm.final_states: - assert fsm.is_final_state(state) is True - - -def test_regex_final_state(): - """Make sure that the FSM stays in the final state as we keep generating""" - - class MockTokenizer: - vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104} - special_tokens = {"eos"} - eos_token_id = 104 - - def convert_token_to_string(self, token): - return token - - regex_str = r"`\n(\.\n)?`\n" - tokenizer = MockTokenizer() - fsm = RegexGuide(regex_str, tokenizer) - - state = fsm.get_next_state(state=4, token_id=103) - assert state == 5 - assert fsm.is_final_state(state) - - state = fsm.get_next_state(state=5, token_id=103) - assert fsm.is_final_state(state) diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py deleted file mode 100644 index 3fa3d79c..00000000 --- a/tests/fsm/test_json_schema.py +++ /dev/null @@ -1,1040 +0,0 @@ -import json -import re -from typing import List, Literal, Union - -import interegular -import pytest -from outlines_core.fsm.json_schema import ( - BOOLEAN, - DATE, - DATE_TIME, - INTEGER, - NULL, - NUMBER, - STRING, - STRING_INNER, - TIME, - UUID, - WHITESPACE, - build_regex_from_schema, - get_schema_from_signature, - to_regex, -) -from pydantic import BaseModel, Field, constr - - -def test_function_basic(): - def test_function(foo: str, bar: List[int]): - pass - - result = get_schema_from_signature(test_function) - assert result["type"] == "object" - assert list(result["properties"].keys()) == ["foo", "bar"] - assert result["properties"]["foo"]["type"] == "string" - assert result["properties"]["bar"]["type"] == "array" - assert result["properties"]["bar"]["items"]["type"] == "integer" - - -def test_function_no_type(): - def test_function(foo, bar: List[int]): - pass - - with pytest.raises(ValueError): - get_schema_from_signature(test_function) - - 
-def test_from_pydantic(): - class User(BaseModel): - user_id: int - name: str - maxlength_name: constr(max_length=10) - minlength_name: constr(min_length=10) - value: float - is_true: bool - - schema = json.dumps(User.model_json_schema()) - schedule = build_regex_from_schema(schema) - assert isinstance(schedule, str) - - -@pytest.mark.parametrize( - "pattern,does_match", - [ - ({"integer": "0"}, True), - ({"integer": "1"}, True), - ({"integer": "-1"}, True), - ({"integer": "01"}, False), - ({"integer": "1.3"}, False), - ({"integer": "t"}, False), - ], -) -def test_match_integer(pattern, does_match): - step = {"title": "Foo", "type": "integer"} - regex = to_regex(None, step) - assert regex == INTEGER - - value = pattern["integer"] - match = re.fullmatch(regex, value) - if does_match: - assert match[0] == value - assert match.span() == (0, len(value)) - else: - assert match is None - - -@pytest.mark.parametrize( - "pattern,does_match", - [ - ({"number": "1"}, True), - ({"number": "0"}, True), - ({"number": "01"}, False), - ({"number": ".3"}, False), - ({"number": "1.3"}, True), - ({"number": "-1.3"}, True), - ({"number": "1.3e9"}, False), - ({"number": "1.3e+9"}, True), - ], -) -def test_match_number(pattern, does_match): - step = {"title": "Foo", "type": "number"} - regex = to_regex(None, step) - assert regex == NUMBER - - value = pattern["number"] - match = re.fullmatch(regex, value) - if does_match: - assert match[0] == value - assert match.span() == (0, len(value)) - else: - assert match is None - - -@pytest.mark.parametrize( - "schema,regex,examples", - [ - # String - ( - {"title": "Foo", "type": "string"}, - STRING, - [ - ("unquotedstring", False), - ('"(parenthesized_string)"', True), - ('"malformed) parenthesis (((() string"', True), - ('"quoted_string"', True), - (r'"escape_\character"', False), - (r'"double_\\escape"', True), - (r'"\n"', False), - (r'"\\n"', True), - (r'"unescaped " quote"', False), - (r'"escaped \" quote"', True), - ], - ), - # String with maximum length - ( - {"title": "Foo", "type": "string", "maxLength": 3}, - f'"{STRING_INNER}{{,3}}"', - [('"ab"', True), ('"a""', False), ('"abcd"', False)], - ), - # String with minimum length - ( - {"title": "Foo", "type": "string", "minLength": 3}, - f'"{STRING_INNER}{{3,}}"', - [('"ab"', False), ('"abcd"', True), ('"abc""', False)], - ), - # String with both minimum and maximum length - ( - {"title": "Foo", "type": "string", "minLength": 3, "maxLength": 5}, - f'"{STRING_INNER}{{3,5}}"', - [('"ab"', False), ('"abcd"', True), ('"abcdef""', False)], - ), - # String defined by a regular expression - ( - {"title": "Foo", "type": "string", "pattern": r"^[a-z]$"}, - r'("[a-z]")', - [('"a"', True), ('"1"', False)], - ), - # Boolean - ( - {"title": "Foo", "type": "boolean"}, - BOOLEAN, - [ - ("true", True), - ("false", True), - ("null", False), - ("0", False), - ], - ), - # Null - ( - {"title": "Foo", "type": "null"}, - NULL, - [ - ("null", True), - ("true", False), - ("0", False), - ], - ), - # Const string - ( - {"title": "Foo", "const": "Marc", "type": "string"}, - '"Marc"', - [('"Marc"', True), ('"Jean"', False), ('"John"', False)], - ), - # Make sure strings are escaped with regex escaping - ( - {"title": "Foo", "const": ".*", "type": "string"}, - r'"\.\*"', - [('".*"', True), (r'"\s*"', False), (r'"\.\*"', False)], - ), - # Make sure strings are escaped with JSON escaping - ( - {"title": "Foo", "const": '"', "type": "string"}, - r'"\\""', - [('"\\""', True), ('"""', False)], - ), - # Const integer - ( - {"title": "Foo", 
"const": 0, "type": "integer"}, - "0", - [("0", True), ("1", False), ("a", False)], - ), - # Const float - ( - {"title": "Foo", "const": 0.2, "type": "float"}, - r"0\.2", - [("0.2", True), ("032", False)], - ), - # Const boolean - ( - {"title": "Foo", "const": True, "type": "boolean"}, - "true", - [("true", True), ("True", False)], - ), - # Const null - ( - {"title": "Foo", "const": None, "type": "null"}, - "null", - [("null", True), ("None", False), ("", False)], - ), - # Enum string - ( - {"title": "Foo", "enum": ["Marc", "Jean"], "type": "string"}, - '("Marc"|"Jean")', - [('"Marc"', True), ('"Jean"', True), ('"John"', False)], - ), - # Make sure strings are escaped with regex and JSON escaping - ( - {"title": "Foo", "enum": [".*", r"\s*"], "type": "string"}, - r'("\.\*"|"\\\\s\*")', - [('".*"', True), (r'"\\s*"', True), (r'"\.\*"', False)], - ), - # Enum integer - ( - {"title": "Foo", "enum": [0, 1], "type": "integer"}, - "(0|1)", - [("0", True), ("1", True), ("a", False)], - ), - # Enum mix of types - ( - {"title": "Foo", "enum": [6, 5.3, "potato", True, None]}, - r'(6|5\.3|"potato"|true|null)', - [ - ("6", True), - ("5.3", True), - ('"potato"', True), - ("true", True), - ("null", True), - ("523", False), - ("True", False), - ("None", False), - ], - ), - # integer - ( - { - "title": "Foo", - "type": "object", - "properties": {"count": {"title": "Count", "type": "integer"}}, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}', - [('{ "count": 100 }', True)], - ), - # integer with minimum digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": {"title": "Count", "type": "integer", "minDigits": 3} - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,})[ ]?\\}', - [('{ "count": 10 }', False), ('{ "count": 100 }', True)], - ), - # integer with maximum digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": {"title": "Count", "type": "integer", "maxDigits": 3} - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{,2})[ ]?\\}', - [('{ "count": 100 }', True), ('{ "count": 1000 }', False)], - ), - # integer with minimum and maximum digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "integer", - "minDigits": 3, - "maxDigits": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,4})[ ]?\\}', - [ - ('{ "count": 10 }', False), - ('{ "count": 100 }', True), - ('{ "count": 10000 }', True), - ('{ "count": 100000 }', False), - ], - ), - # number - ( - { - "title": "Foo", - "type": "object", - "properties": {"count": {"title": "Count", "type": "number"}}, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}', - [('{ "count": 100 }', True), ('{ "count": 100.5 }', True)], - ), - # number with min and max integer digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "number", - "minDigitsInteger": 3, - "maxDigitsInteger": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}', - [ - ('{ "count": 10.005 }', False), - ('{ "count": 100.005 }', True), - ('{ "count": 10000.005 }', True), - ('{ "count": 100000.005 }', False), - ], - ), - # number with min and max fraction digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": 
"number", - "minDigitsFraction": 3, - "maxDigitsFraction": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]{3,5})?([eE][+-][0-9]+)?[ ]?\\}', - [ - ('{ "count": 1.05 }', False), - ('{ "count": 1.005 }', True), - ('{ "count": 1.00005 }', True), - ('{ "count": 1.000005 }', False), - ], - ), - # number with min and max exponent digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "number", - "minDigitsExponent": 3, - "maxDigitsExponent": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]{3,5})?[ ]?\\}', - [ - ('{ "count": 1.05e1 }', False), - ('{ "count": 1.05e+001 }', True), - ('{ "count": 1.05e-00001 }', True), - ('{ "count": 1.05e0000001 }', False), - ], - ), - # number with min and max integer, fraction and exponent digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "number", - "minDigitsInteger": 3, - "maxDigitsInteger": 5, - "minDigitsFraction": 3, - "maxDigitsFraction": 5, - "minDigitsExponent": 3, - "maxDigitsExponent": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]{3,5})?([eE][+-][0-9]{3,5})?[ ]?\\}', - [ - ('{ "count": 1.05e1 }', False), - ('{ "count": 100.005e+001 }', True), - ('{ "count": 10000.00005e-00001 }', True), - ('{ "count": 100000.000005e0000001 }', False), - ], - ), - # array - ( - {"title": "Foo", "type": "array", "items": {"type": "number"}}, - rf"\[{WHITESPACE}(({NUMBER})(,{WHITESPACE}({NUMBER})){{0,}})?{WHITESPACE}\]", - [("[1e+9,1.3]", True), ("[]", True), ("[1", False)], - ), - # array with a set length of 1 - ( - { - "title": "Foo", - "type": "array", - "items": {"type": "integer"}, - "minItems": 1, - "maxItems": 1, - }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{0,0}}){WHITESPACE}\]", - [("[1]", True), ("[1,2]", False), ('["a"]', False), ("[]", False)], - ), - # array with a set length greather than 1 - ( - { - "title": "Foo", - "type": "array", - "items": {"type": "integer"}, - "minItems": 3, - "maxItems": 3, - }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{2,2}}){WHITESPACE}\]", - [("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)], - ), - # array with length 0 - ( - { - "title": "Foo", - "type": "array", - "items": {"type": "integer"}, - "minItems": 0, - "maxItems": 0, - }, - rf"\[{WHITESPACE}\]", - [("[1]", False), ("[]", True), ("[1,2,3]", False), ("[1,2,3,4]", False)], - ), - # object - ( - { - "title": "TestSchema", - "type": "object", - "properties": { - "test_dict": { - "title": "Test Dict", - "additionalProperties": {"type": "string"}, - "type": "object", - } - }, - "required": ["test_dict"], - }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", - [ - ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), - ("""{ "test_dict":{"foo":"bar" }}""", True), - ("""{ "test_dict":{}}""", True), - ("""{ "WRONG_KEY":{}}""", False), - ("""{ "test_dict":{"wrong_type" 1}}""", False), - ], - ), - # object containing object - ( - { - "title": "TestSchema", - "type": "object", - "properties": { - "test_dict": { - "title": "Test Dict", - "additionalProperties": { - "additionalProperties": {"type": "integer"}, - "type": "object", - }, - "type": "object", - 
} - }, - "required": ["test_dict"], - }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", - [ - ( - """{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""", - True, - ), - ( - """{"test_dict": {"anykey": {"anykey": 123}, "anykey2": {"bif": 456}}}""", - True, - ), - ("""{"test_dict": {}}""", True), - ("""{"test_dict": {"dict of empty dicts are ok": {} }}""", True), - ( - """{"test_dict": {"anykey": {"ONLY Dict[Dict]": 123}, "No Dict[int]" 1: }}""", - False, - ), - ], - ), - # oneOf - ( - { - "title": "Foo", - "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}], - }, - rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))', - [ - ("12.3", True), - ("true", True), - ('"a"', True), - ("null", False), - ("", False), - ("12true", False), - ('1.3"a"', False), - ('12.3true"a"', False), - ], - ), - # anyOf - ( - { - "title": "Foo", - "anyOf": [{"type": "string"}, {"type": "integer"}], - }, - rf"({STRING}|{INTEGER})", - [("12", True), ('"a"', True), ('1"a"', False)], - ), - # allOf - ( - { - "title": "Foo", - "allOf": [{"type": "string"}, {"type": "integer"}], - }, - rf"({STRING}{INTEGER})", - [('"a"1', True), ('"a"', False), ('"1"', False)], - ), - # Tuple / prefixItems - ( - { - "title": "Foo", - "prefixItems": [{"type": "string"}, {"type": "integer"}], - }, - rf"\[{WHITESPACE}{STRING}{WHITESPACE},{WHITESPACE}{INTEGER}{WHITESPACE}\]", - [('["a", 1]', True), ('["a", 1, 1]', False), ("[]", False)], - ), - # Nested schema - ( - { - "title": "Bar", - "type": "object", - "properties": { - "fuzz": { - "title": "Foo", - "type": "object", - "properties": {"spam": {"title": "Spam", "type": "integer"}}, - "required": ["spam"], - } - }, - "required": ["fuzz"], - }, - f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}', - [('{ "fuzz": { "spam": 100 }}', True)], - ), - # Schema with a reference - ( - { - "title": "User", - "type": "object", - "properties": { - "user_id": {"title": "User Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "a": {"$ref": "#/properties/name"}, - }, - "required": ["user_id", "name", "a"], - }, - f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}', - [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], - ), - ( - { - "title": "User", - "type": "object", - "$defs": {"name": {"title": "Name2", "type": "string"}}, - "properties": { - "user_id": {"title": "User Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "name2": {"$ref": "#/$defs/name"}, - }, - "required": ["user_id", "name", "name2"], - }, - f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}', - [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], - ), - ( - { - "$id": "customer", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "Customer", - "type": "object", - "properties": { - "name": {"type": "string"}, - "last_name": {"type": "string"}, - "address": {"$ref": "customer#/$defs/address"}, - }, - "required": [ - 
"name", - "first_name", - "last_name", - "address", - "shipping_address", - "billing_address", - ], - "$defs": { - "address": { - "title": "Address", - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "city": {"type": "string"}, - }, - "required": ["street_address", "city", "state"], - "definitions": { - "state": { - "type": "object", - "title": "State", - "properties": {"name": {"type": "string"}}, - "required": ["name"], - } - }, - } - }, - }, - f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}', - [ - ( - '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}', - True, - ) - ], - ), - # Optional properties - # Last required property in first position - ( - { - "properties": { - "name": {"type": "string"}, - "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - "weapon": {"anyOf": [{"type": "string"}, {"type": "null"}]}, - }, - "required": ["name"], - "title": "Character", - "type": "object", - }, - f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}', - [ - ('{ "name" : "Player" }', True), - ('{ "name" : "Player", "weapon" : "sword" }', True), - ('{ "age" : 10, "weapon" : "sword" }', False), - ], - ), - # Last required property in middle position - ( - { - "properties": { - "name": {"type": "string"}, - "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - "weapon": {"type": "string"}, - "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - }, - "required": ["name", "weapon"], - "title": "Character", - "type": "object", - }, - f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', - [ - ('{ "name" : "Player" , "weapon" : "sword" }', True), - ( - '{ "name" : "Player", "age" : 10, "weapon" : "sword" , "strength" : 10 }', - True, - ), - ('{ "weapon" : "sword" }', False), - ], - ), - # Last required property in last position - ( - { - "properties": { - "name": {"anyOf": [{"type": "string"}, {"type": "null"}]}, - "age": {"type": "integer"}, - "armor": {"type": "string"}, - "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - "weapon": {"title": "Weapon", "type": "string"}, - }, - "required": ["age", "armor", "weapon"], - "title": "Character", - "type": "object", - }, - f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}', - [ - ( - '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }', - True, - ), - ('{ "age" : 10, "armor" : "plate", "weapon" : "sword" }', True), - ( - '{ "name" : "Kahlhanbeh", "armor" : "plate", "weapon" : "sword" }', - False, - ), - ], - ), - # All properties are optional - ( - { - "properties": { - "name": {"anyOf": [{"type": "string"}, {"type": "null"}]}, - "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - }, - "title": "Character", - "type": "object", - }, - f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ 
]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', - [ - ('{ "name" : "Player" }', True), - ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True), - ('{ "age" : 10, "strength" : 10 }', True), - ("{ }", True), - ], - ), - ], -) -def test_match(schema, regex, examples): - interegular.parse_pattern(regex) - schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) - assert test_regex == regex - - for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - if match is None: - raise ValueError(f"Expected match for '{string}'") - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None - - -@pytest.mark.parametrize( - "schema,regex,examples", - [ - # UUID - ( - {"title": "Foo", "type": "string", "format": "uuid"}, - UUID, - [ - ("123e4567-e89b-12d3-a456-426614174000", False), - ('"123e4567-e89b-12d3-a456-426614174000"', True), - ('"123e4567-e89b-12d3-a456-42661417400"', False), - ('"123e4567-e89b-12d3-a456-42661417400g"', False), - ('"123e4567-e89b-12d3-a456-42661417400-"', False), - ('""', False), - ], - ), - # DATE-TIME - ( - {"title": "Foo", "type": "string", "format": "date-time"}, - DATE_TIME, - [ - ("2018-11-13T20:20:39Z", False), - ('"2018-11-13T20:20:39Z"', True), - ('"2016-09-18T17:34:02.666Z"', True), - ('"2008-05-11T15:30:00Z"', True), - ('"2021-01-01T00:00:00"', True), - ('"2022-01-10 07:19:30"', False), # missing T - ('"2022-12-10T10-04-29"', False), # incorrect separator - ('"2023-01-01"', False), - ], - ), - # DATE - ( - {"title": "Foo", "type": "string", "format": "date"}, - DATE, - [ - ("2018-11-13", False), - ('"2018-11-13"', True), - ('"2016-09-18"', True), - ('"2008-05-11"', True), - ('"2015-13-01"', False), # incorrect month - ('"2022-01"', False), # missing day - ('"2022/12/01"', False), # incorrect separator" - ], - ), - # TIME - ( - {"title": "Foo", "type": "string", "format": "time"}, - TIME, - [ - ("20:20:39Z", False), - ('"20:20:39Z"', True), - ('"15:30:00Z"', True), - ('"25:30:00"', False), # incorrect hour - ('"15:30"', False), # missing seconds - ('"15:30:00.000"', False), # missing Z - ('"15-30-00"', False), # incorrect separator - ('"15:30:00+01:00"', False), # incorrect separator - ], - ), - ], -) -def test_format(schema, regex, examples): - interegular.parse_pattern(regex) - schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) - assert test_regex == regex - - for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None - - -@pytest.mark.parametrize( - "schema,examples", - [ - # NESTED UUID - ( - { - "title": "Foo", - "type": "object", - "properties": {"uuid": {"type": "string", "format": "uuid"}}, - }, - [ - ('{"uuid": "123e4567-e89b-12d3-a456-426614174000"}', True), - ('{"uuid":"123e4567-e89b-12d3-a456-42661417400"}', False), - ('{"uuid":"123e4567-e89b-12d3-a456-42661417400g"}', False), - ('{"uuid":"123e4567-e89b-12d3-a456-42661417400-"}', False), - ( - '{"uuid":123e4567-e89b-12d3-a456-426614174000}', - False, - ), # missing quotes for value - ('{"uuid":""}', False), - ], - ), - # NESTED DATE-TIME - ( - { - "title": "Foo", - "type": "object", - "properties": {"dateTime": {"type": "string", "format": "date-time"}}, - }, - [ - ('{"dateTime": "2018-11-13T20:20:39Z"}', True), - ('{"dateTime":"2016-09-18T17:34:02.666Z"}', True), - 
('{"dateTime":"2008-05-11T15:30:00Z"}', True), - ('{"dateTime":"2021-01-01T00:00:00"}', True), - ('{"dateTime":"2022-01-10 07:19:30"}', False), # missing T - ('{"dateTime":"2022-12-10T10-04-29"}', False), # incorrect separator - ( - '{"dateTime":2018-11-13T20:20:39Z}', - False, - ), # missing quotes for value - ('{"dateTime":"2023-01-01"}', False), - ], - ), - # NESTED DATE - ( - { - "title": "Foo", - "type": "object", - "properties": {"date": {"type": "string", "format": "date"}}, - }, - [ - ('{"date": "2018-11-13"}', True), - ('{"date":"2016-09-18"}', True), - ('{"date":"2008-05-11"}', True), - ('{"date":"2015-13-01"}', False), # incorrect month - ('{"date":"2022-01"}', False), # missing day - ('{"date":"2022/12/01"}', False), # incorrect separator" - ('{"date":2018-11-13}', False), # missing quotes for value - ], - ), - # NESTED TIME - ( - { - "title": "Foo", - "type": "object", - "properties": {"time": {"type": "string", "format": "time"}}, - }, - [ - ('{"time": "20:20:39Z"}', True), - ('{"time":"15:30:00Z"}', True), - ('{"time":"25:30:00"}', False), # incorrect hour - ('{"time":"15:30"}', False), # missing seconds - ('{"time":"15:30:00.000"}', False), # missing Z - ('{"time":"15-30-00"}', False), # incorrect separator - ('{"time":"15:30:00+01:00"}', False), # incorrect separator - ('{"time":20:20:39Z}', False), # missing quotes for value - ], - ), - # Unconstrained Object - ( - { - "title": "Foo", - "type": "object", - }, - [ - ("{}", True), - ('{"a": 1, "b": null}', True), - ('{"a": {"z": {"g": 4}}, "b": null}', True), - ("1234", False), # not an object - ('["a", "a"]', False), # not an array - ], - ), - # Unconstrained Array - ( - { - "type": "array", - }, - [ - ("[1, {}, false]", True), - ("[{}]", True), - ('[{"a": {"z": "q"}, "b": null}]', True), - ('[{"a": [1, 2, true], "b": null}]', True), - ('[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2]]]', True), - # too deep, default unconstrained depth limit = 2 - ( - '[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2, [3]]]]', - False, - ), - ('[{"a": {"z": {"g": 4}}, "b": null}]', False), - ("[[[[1]]]]", False), - # not an array - ("{}", False), - ('{"a": 1, "b": null}', False), - ('{"a": {"z": {"g": 4}}, "b": null}', False), - ("1234", False), # not an array - ('{"a": "a"}', False), # not an array - ], - ), - # No schema / unconstrained value - ( - {}, - [ - ('"aaabbuecuh"', True), # string - ("5.554", True), # number - ("true", True), # boolean - ("null", True), # null - ("5999", True), # integer - ('["a", "b"]', True), # array - ('{"key": {"k2": "value"}}', True), # nested object - ("this isnt valid json", False), - ], - ), - ], -) -def test_format_without_regex(schema, examples): - schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) - for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None - - -@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"]) -def test_json_schema_custom_whitespace_pattern(whitespace_pattern): - """assert whitespace_pattern setting respected""" - - class MockModel(BaseModel): - foo: int - bar: str - - schema = json.dumps(MockModel.model_json_schema()) - - # assert any ws pattern can be used - if whitespace_pattern == "abc": - build_regex_from_schema(schema, whitespace_pattern) - return - - pattern = build_regex_from_schema(schema, whitespace_pattern) - - mock_result_mult_ws = ( - """{ "foo" : 4, \n\n\n "bar": "baz baz 
baz bar"\n\n}""" - ) - mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}""" - - match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws) - if whitespace_pattern is None: - assert match_default_ws - else: - assert re.fullmatch(pattern, mock_result_mult_ws) - - -def test_one_of_doesnt_produce_illegal_lookaround(): - """Reproduces failure in https://github.com/outlines-dev/outlines/issues/823""" - - class Cat(BaseModel): - pet_type: Literal["cat"] - meows: int - - class Dog(BaseModel): - pet_type: Literal["dog"] - barks: float - - class Model(BaseModel): - pet: Union[Cat, Dog] = Field(..., discriminator="pet_type") - n: int - - json_schema = Model.schema_json() - - json_schema = Model.schema_json() - pattern = build_regex_from_schema(json_schema, whitespace_pattern=None) - - # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() - interegular.parse_pattern(pattern).to_fsm() diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py deleted file mode 100644 index b711ba48..00000000 --- a/tests/fsm/test_regex.py +++ /dev/null @@ -1,524 +0,0 @@ -import interegular -import numpy as np -import pytest -from outlines_core.fsm.regex import ( - BetterAlphabet, - BetterFSM, - _walk_fsm, - create_fsm_index_end_to_end, - create_fsm_index_tokenizer, - get_token_transition_keys, - get_vocabulary_transition_keys, - make_byte_level_fsm, - make_deterministic_fsm, - reduced_vocabulary, -) -from outlines_core.integrations.utils import adapt_tokenizer -from outlines_core.models.transformers import TransformerTokenizer -from transformers import AutoTokenizer - - -def identity(s): - return s - - -def to_bytes(s): - return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")] - - -def merge_symbols(byte_hexs): - return "".join(["\x00" + b if len(b) == 2 else b for b in byte_hexs]) - - -def token_str_to_trans_key(fsm, input_string): - return get_token_transition_keys( - fsm.fsm_info.alphabet_symbol_mapping, - fsm.fsm_info.alphabet_anything_value, - input_string, - ) - - -def walk_fsm_from_token_str_rust( - fsm, - input_string: str, - start_state: int, - full_match: bool = True, -): - return _walk_fsm( - fsm.fsm_info.transitions, - fsm.fsm_info.initial, - fsm.fsm_info.finals, - token_str_to_trans_key(fsm, input_string), - start_state, - full_match=full_match, - ) - - -def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: - new_fsm = make_byte_level_fsm(fsm, keep_utf8) - return BetterFSM( - alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), - states=new_fsm.states, - initial=new_fsm.initial, - finals=new_fsm.finals, - map=new_fsm.map, - ) - - -def test_walk_fsm(): - regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "0", regex_fsm.initial, full_match=True) - ) - assert res == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "00", regex_fsm.initial, full_match=False - ) - ) - assert res == (1,) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "!", regex_fsm.initial, full_match=True) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "00", regex_fsm.initial, full_match=True - ) - ) - assert res == tuple() - - # This should fail, because state `1` reads nothing - res = tuple(walk_fsm_from_token_str_rust(regex_fsm, "0", 1, full_match=True)) - assert res == tuple() - - regex_pattern = 
interegular.parse_pattern("0|[1-9][2-9]+") - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - res = tuple( - walk_fsm_from_token_str_rust(regex_fsm, "1", regex_fsm.initial, full_match=True) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "1", regex_fsm.initial, full_match=False - ) - ) - assert res == (2,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, "12", regex_fsm.initial, full_match=True - ) - ) - assert res == (2, 3) - - pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)") - fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) - - res = tuple(walk_fsm_from_token_str_rust(fsm, "x ", fsm.initial, full_match=False)) - assert res == (2,) - - start_state = list(fsm.finals)[0] - res = tuple(walk_fsm_from_token_str_rust(fsm, "!", start_state, full_match=False)) - assert res == tuple() - - -@pytest.mark.parametrize( - "transform", - [ - identity, - to_bytes, - ], -) -def test_walk_fsm_multi_bytes(transform): - regex_pattern = interegular.parse_pattern("πŸ˜‚|[πŸ˜‡-😍][😈-😍]*") - str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, merge_symbols(transform("πŸ˜‚")), regex_fsm.initial, full_match=True - ) - ) - assert res[-1:] == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, - merge_symbols(transform("πŸ˜‚πŸ˜‚")), - regex_fsm.initial, - full_match=False, - ) - ) - assert res[-1:] == (1,) - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True - ) - ) - assert res == tuple() - - res = tuple( - walk_fsm_from_token_str_rust( - regex_fsm, - merge_symbols(transform("πŸ˜‚πŸ˜‚")), - regex_fsm.initial, - full_match=True, - ) - ) - assert res == tuple() - - -def test_create_fsm_index_end_to_end(): - regex_str = "0|[1-9][0-9]*" - - regex_pattern = interegular.parse_pattern(regex_str) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - vocabulary = { - "blah": [0], - "1a": [1], - "2": [2], - "0": [3], - "": [4], - } - - vocabulary_nb = [] - for token_tuple, token_ids in vocabulary.items(): - token = merge_symbols(token_tuple) - token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token, token_ids_np)) - - res = create_fsm_index_end_to_end( - regex_fsm.fsm_info, - vocabulary_nb, - frozenset(), - ) - - assert res == {0: {2: 2, 3: 1}, 2: {2: 2, 3: 2}} - - -def test_create_fsm_index_end_to_end_multi_byte(): - regex_str = "πŸ˜‡| [😈-😍][πŸ˜‡-😎]*" - - regex_pattern = interegular.parse_pattern(regex_str) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) - - vocabulary = { - "blah": [0], - "😈a": [1], - "πŸ˜‡": [2], - "😍": [3], - merge_symbols(("F0", "9F", "98", "8D")): [4], # '😍' - " 😍": [5], - merge_symbols((" ", "F0", "9F", "98", "8D")): [6], # ' 😍' - merge_symbols((" ", "F0", "9F", "98")): [7], # ' 😍' incomplete - "": [8], - } - - vocabulary_nb = [] - for token_tuple, token_ids in vocabulary.items(): - token_tuple_np = merge_symbols(token_tuple) - token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token_tuple_np, token_ids_np)) - - res = create_fsm_index_end_to_end( - byte_fsm.fsm_info, - vocabulary_nb, - frozenset(), - ) - - assert res == {0: {5: 3, 6: 3, 7: 7, 2: 2}, 3: {2: 3, 3: 3, 4: 
3}} - - -@pytest.mark.parametrize( - "hf_tokenizer_uri", - [ - "gpt2", - "microsoft/phi-2", - "Qwen/Qwen1.5-0.5B-Chat", - "NousResearch/Hermes-2-Pro-Llama-3-8B", - ], -) -def test_create_fsm_index_tokenizer(hf_tokenizer_uri): - # The combined regular expressions of a lexer state in a Python grammar - regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" - - regex_pattern = interegular.parse_pattern(regex_str) - # Not reduced, so that there are many states - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) - bytes_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) - - num_fsm_states = len(regex_fsm.states) - assert num_fsm_states == 220 - - num_bytes_fsm_states = len(bytes_fsm.states) - assert num_bytes_fsm_states == 235 - - tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri) - tokenizer = TransformerTokenizer(tokenizer) - - states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( - bytes_fsm, tokenizer - ) - - assert not empty_token_ids - assert len(states_to_token_subsets) / num_fsm_states > 0.94 - - -@pytest.mark.parametrize( - "regex,string,should_accept", - [ - ("[a-c]+", "πŸ˜€", False), - ("[^a-c]+", "πŸ˜€", True), - ("πŸ˜€+", "πŸ˜€πŸ˜€πŸ˜€", True), - ("πŸ˜€+", "a", False), - ("[πŸ˜€-😍]{2}", "😈😈", True), - ("[πŸ˜€-😍]{2}", "aa", False), - ("[^πŸ˜€-😍]{2}", "aa", True), - ("[^πŸ˜€-😍]{2}", "😈😈", False), - ("[^πŸ˜€-😍]{2}", "😎😎", True), - ("[^πŸ˜€-😍]{2}", "πŸ˜ŽπŸ˜“", True), - ("[^πŸ˜€-😍]{2}", "😎😈", False), - ("[πŸ˜€-πŸ™Œ]{2}", "😎😈", True), - ("[^πŸ˜€-πŸ™Œ]{2}", "😎😈", False), - ("[^πŸ˜€-πŸ™Œ]{2}", "πŸ™πŸ™", True), - ("[^πŸ˜€-πŸ™Œ]{2}", "πŸ™πŸ˜Ž", False), - ], -) -def test_make_byte_level_fsm(regex, string, should_accept): - str_fsm = interegular.parse_pattern(regex).to_fsm() - str_accepts = str_fsm.accepts(string) - assert str_accepts == should_accept - - byte_fsm = make_byte_level_fsm(str_fsm) - byte_accepts = byte_fsm.accepts(to_bytes(string)) # type: ignore - assert byte_accepts == str_accepts - - mix_fsm = make_byte_level_fsm(str_fsm, keep_utf8=True) - mix_accepts = mix_fsm.accepts(to_bytes(string)) # type: ignore - assert mix_accepts == str_accepts - - mix_accepts_utf8 = mix_fsm.accepts(string) # type: ignore - assert mix_accepts_utf8 == str_accepts - - def advance(fsm, state, seq): - for symbol in seq: - if state is None: - return None - key = fsm.alphabet[symbol] - state = fsm.map[state].get(key) - return state - - # verify each state along the pattern - str_state = str_fsm.initial - byte_state = byte_fsm.initial - mix_state = byte_fsm.initial - for symbol in string: - str_state = advance(str_fsm, str_state, symbol) - byte_state = advance(byte_fsm, byte_state, to_bytes(symbol)) - 
mix_state_utf8 = advance(mix_fsm, mix_state, symbol) - mix_state = advance(mix_fsm, mix_state, to_bytes(symbol)) - assert byte_state == str_state - assert mix_state == str_state - assert mix_state_utf8 == str_state - - -@pytest.mark.skip(reason="Only for local profiling") -def test_regex_index_performance(): - from line_profiler import LineProfiler # type: ignore [import] - - regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" - - regex_pattern = interegular.parse_pattern(regex_str) - # Not reduced, so that there are many states - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) - - num_fsm_states = len(regex_fsm.states) - assert num_fsm_states == 220 - - tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer = TransformerTokenizer(tokenizer) - - res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer) - assert len(res) > 1 - - profiler = LineProfiler(create_fsm_index_end_to_end) - - profiler.runctx( - "create_fsm_index_tokenizer(regex_fsm, tokenizer)", - globals(), - locals(), - ) - profiler.dump_stats("line-profiler-create_fsm_index.pkl") - profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True) - - -@pytest.mark.skip(reason="Only for local profiling") -def test_json_index_performance(): - import json - from enum import Enum - - import outlines_core - from line_profiler import LineProfiler # type: ignore [import] - from pydantic import BaseModel, constr - - class Weapon(str, Enum): - sword = "sword" - axe = "axe" - mace = "mace" - spear = "spear" - bow = "bow" - crossbow = "crossbow" - - class Armor(str, Enum): - leather = "leather" - chainmail = "chainmail" - plate = "plate" - - class Character(BaseModel): - name: constr(max_length=10) - # TODO: Add support for conint - age: int # conint(int, ge=18, le=100) - armor: Armor - weapon: Weapon - # TODO: Add support for conint - strength: int # conint(int, ge=0, le=100) - - model = outlines_core.models.transformers("gpt2", device="cuda") - json_schema = json.dumps(Character.model_json_schema()) - - def build_regex(): - regex_str = outlines_core.index.json_schema.build_regex_from_object(json_schema) - outlines_core.generate.regex(model, regex_str) - - profiler = LineProfiler(create_fsm_index_end_to_end) - profiler.add_function(create_fsm_index_tokenizer) - profiler.add_function(outlines_core.index.index.RegexFSM.__init__) - - profiler.runctx( - "build_regex()", - globals(), - locals(), - ) - profiler.dump_stats("line-profiler-build-json-regex.pkl") - profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True) - - -def test_token_trans_keys_identical(): - """assert two tokens w/ identical behavior wrt FSM have same trans key seq""" - - 
class MockTokenizer: - vocabulary = {"a": 1, "b": 2, "z": 3, "eos": 4} - special_tokens = {"eos"} - eos_token_id = 4 - - def convert_token_to_string(self, token): - return token - - tokenizer = MockTokenizer() - - pattern = r"z[ab]z" - regex_pattern = interegular.parse_pattern(pattern) - interegular_fsm = regex_pattern.to_fsm().reduce() - regex_fsm, _ = make_deterministic_fsm(interegular_fsm) - vocabulary, _ = reduced_vocabulary(tokenizer) - token_trans_keys = get_vocabulary_transition_keys( - regex_fsm.fsm_info.alphabet_symbol_mapping, - regex_fsm.fsm_info.alphabet_anything_value, - list(vocabulary.items()), - frozenset(), - ) - - token_str_to_tranition_keys = { - token_str: trans_key_seq - for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys) - } - # `a` and `b` both are workable, but `z` has distinct transition rules - assert interegular_fsm.accepts("zaz") - assert interegular_fsm.accepts("zbz") - assert token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"] - assert not token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"] - - -def test_token_trans_keys_walk_fsm(): - """assert _walk_fsm works using transition keys""" - - class MockTokenizer: - vocabulary = {"ab": 1, "ac": 2, "az": 3, "eos": 4} - special_tokens = {"eos"} - eos_token_id = 4 - - def convert_token_to_string(self, token): - return token - - tokenizer = MockTokenizer() - - pattern = r"a[bc]z" - regex_pattern = interegular.parse_pattern(pattern) - interegular_fsm = regex_pattern.to_fsm().reduce() - regex_fsm, _ = make_deterministic_fsm(interegular_fsm) - vocabulary, _ = reduced_vocabulary(tokenizer) - token_trans_keys = get_vocabulary_transition_keys( - regex_fsm.fsm_info.alphabet_symbol_mapping, - regex_fsm.fsm_info.alphabet_anything_value, - list(vocabulary.items()), - frozenset(), - ) - - token_str_trans_key_seq = { - token_str: trans_key_seq - for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys) - } - - # verify initial state valid only for "ab" and "ac" using transition key seq - token_acceptance = {"ab": True, "ac": True, "az": False} - for token, should_accept in token_acceptance.items(): - token_trans_key_seq = token_str_trans_key_seq[token] - state_seq = _walk_fsm( - regex_fsm.fsm_info.transitions, - regex_fsm.fsm_info.initial, - regex_fsm.fsm_info.finals, - token_trans_key_seq, - regex_fsm.initial, - False, - ) - is_accepted = len(state_seq) >= len(token_trans_key_seq) - assert should_accept == is_accepted - - -@pytest.mark.parametrize( - "rare_token", - [ - "οΏ½", - "οΏ½οΏ½", - "οΏ½.", - "οΏ½..", - "▁�", - "▁▁�", - "▁�.", - "▁�.", - "▁▁�..", - ], -) -def test_reduced_vocabulary_with_rare_tokens(rare_token): - """Assert reduced_vocabulary works with rare tokens. - - See [1] and [2] for context. 
- - [1]: https://github.com/outlines-dev/outlines/pull/763 - [2]: https://github.com/outlines-dev/outlines/pull/948 - """ - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - tokenizer = adapt_tokenizer(tokenizer=tokenizer) - tokenizer.vocabulary[rare_token] = max(tokenizer.vocabulary.values()) + 1 - reduced_vocabulary(tokenizer) diff --git a/tests/fsm/test_types.py b/tests/fsm/test_types.py deleted file mode 100644 index fc66bd3f..00000000 --- a/tests/fsm/test_types.py +++ /dev/null @@ -1,28 +0,0 @@ -import datetime - -import pytest -from outlines_core.fsm.types import ( - BOOLEAN, - DATE, - DATETIME, - FLOAT, - INTEGER, - TIME, - python_types_to_regex, -) - - -@pytest.mark.parametrize( - "python_type,regex", - [ - (int, INTEGER), - (float, FLOAT), - (bool, BOOLEAN), - (datetime.date, DATE), - (datetime.time, TIME), - (datetime.datetime, DATETIME), - ], -) -def test_python_types(python_type, regex): - test_regex, _ = python_types_to_regex(python_type) - assert regex == test_regex diff --git a/tests/models/test_tokenizer.py b/tests/models/test_tokenizer.py deleted file mode 100644 index 9457bda5..00000000 --- a/tests/models/test_tokenizer.py +++ /dev/null @@ -1,7 +0,0 @@ -import pytest -from outlines_core.models.tokenizer import Tokenizer - - -def test_tokenizer(): - with pytest.raises(TypeError, match="instantiate abstract"): - Tokenizer() diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py deleted file mode 100644 index 799f7a5b..00000000 --- a/tests/models/test_transformers.py +++ /dev/null @@ -1,116 +0,0 @@ -import pytest -import torch -from outlines_core.models.transformers import TransformerTokenizer, transformers -from transformers import AutoTokenizer -from transformers.models.gpt2 import GPT2TokenizerFast - -TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM" - - -def test_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL, padding_side="left") - tokenizer = TransformerTokenizer(tokenizer) - assert tokenizer.eos_token_id == 0 - assert tokenizer.pad_token_id == 0 - assert isinstance(tokenizer.tokenizer, GPT2TokenizerFast) - - token_ids, attention_mask = tokenizer.encode("Test") - assert token_ids.ndim == 2 - assert token_ids.shape[0] == 1 - assert isinstance(token_ids, torch.LongTensor) - assert token_ids.shape == attention_mask.shape - - token_ids, attention_mask = tokenizer.encode(["Test", "Test"]) - assert token_ids.ndim == 2 - assert token_ids.shape[0] == 2 - assert isinstance(token_ids, torch.LongTensor) - assert token_ids.shape == attention_mask.shape - - token_ids, attention_mask = tokenizer.encode(["Test", "A long sentence"]) - assert token_ids.shape == attention_mask.shape - assert attention_mask[0][0] == tokenizer.pad_token_id - - text = tokenizer.decode(torch.tensor([[0, 1, 2]])) - isinstance(text, str) - - text = tokenizer.decode(torch.tensor([[0, 1, 2], [3, 4, 5]])) - isinstance(text, list) - isinstance(text[0], str) - isinstance(text[1], str) - - tokenizer = AutoTokenizer.from_pretrained( - TEST_MODEL, additional_special_tokens=["", ""] - ) - tokenizer = TransformerTokenizer(tokenizer) - assert "" in tokenizer.special_tokens - assert "" in tokenizer.special_tokens - - -def test_llama_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - tokenizer = TransformerTokenizer(tokenizer) - - # Broken - assert tokenizer.tokenizer.convert_tokens_to_string(["▁baz"]) == "baz" - assert tokenizer.tokenizer.convert_tokens_to_string(["<0x20>"]) == "" - assert 
tokenizer.tokenizer.convert_tokens_to_string(["▁▁▁"]) == " " - - # Not broken - assert tokenizer.convert_token_to_string("▁baz") == " baz" - assert tokenizer.convert_token_to_string("<0x20>") == " " - assert tokenizer.convert_token_to_string("▁▁▁") == " " - - -def test_model(): - model = transformers(TEST_MODEL, device="cpu") - assert isinstance(model.tokenizer, TransformerTokenizer) - assert model.model.device.type == "cpu" - - model = transformers(TEST_MODEL, model_kwargs={"device_map": "cpu"}) - assert isinstance(model.tokenizer, TransformerTokenizer) - assert model.model.device.type == "cpu" - - model = transformers(TEST_MODEL, device="cpu", model_kwargs={"device_map": "cuda"}) - assert isinstance(model.tokenizer, TransformerTokenizer) - assert model.model.device.type == "cpu" - - input_ids = torch.tensor([[0, 1, 2]]) - logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) - assert logits.type() == "torch.FloatTensor" - assert logits.ndim == 2 - assert logits.shape[0] == 1 - assert len(kv_cache) == model.model.config.n_layer - assert len(kv_cache[0]) == 2 - assert kv_cache[0][0].shape[1] == model.model.config.n_head - assert kv_cache[0][0].shape[2] == 3 # number of tokens - - input_ids = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) - assert logits.type() == "torch.FloatTensor" - assert logits.ndim == 2 - assert logits.shape[0] == 3 - assert len(kv_cache) == model.model.config.n_layer - assert len(kv_cache[0]) == 2 - assert kv_cache[0][0].shape[1] == model.model.config.n_head - assert kv_cache[0][0].shape[2] == 3 # number of tokens - - with pytest.raises(AssertionError): - input_ids = torch.tensor([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [0, 1, 2]]]) - logits = model(input_ids, torch.ones_like(input_ids)) - - -def test_tokenizer_eq_hash(): - tokenizer_hf = AutoTokenizer.from_pretrained("gpt2") - - tokenizer = TransformerTokenizer(tokenizer_hf) - tokenizer_2 = TransformerTokenizer(tokenizer_hf) - - assert tokenizer == tokenizer_2 - assert hash(tokenizer) == hash(tokenizer_2) - - tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2") - tokenizer_hf_2.add_tokens(["test_token"]) - - tokenizer_3 = TransformerTokenizer(tokenizer_hf_2) - assert tokenizer != tokenizer_3 - assert hash(tokenizer) != hash(tokenizer_3)
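
Editor's note: the deleted test_json_schema.py cases earlier in this patch all follow one pattern: derive a regex from a JSON Schema with build_regex_from_schema, then check candidate documents with re.fullmatch. Below is a minimal, self-contained sketch of that usage for reference; the import path and the single-argument call are assumptions inferred from the deleted tests, not a prescribed API.

# Hedged sketch: replays one of the removed "format: date" cases.
# Assumes build_regex_from_schema is importable from outlines_core.fsm.json_schema
# and accepts a JSON-encoded schema string (optional whitespace_pattern omitted).
import json
import re

from outlines_core.fsm.json_schema import build_regex_from_schema

schema = json.dumps(
    {
        "title": "Foo",
        "type": "object",
        "properties": {"date": {"type": "string", "format": "date"}},
    }
)

pattern = build_regex_from_schema(schema)

# Mirrors the expectations in the deleted parametrized cases:
assert re.fullmatch(pattern, '{"date": "2018-11-13"}') is not None  # valid ISO date
assert re.fullmatch(pattern, '{"date":"2022/12/01"}') is None       # wrong separator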