Merge pull request #245 from WorksApplications/feature/arseny/mode-literals

Allow string literals as segmentation modes
mh-northlander authored Mar 26, 2024
2 parents 533c3ac + e850158 commit 693a32c
Showing 10 changed files with 170 additions and 57 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,6 +13,7 @@ README*.html
python/dist/
__pycache__/
.env
.venv
*.egg-info
*.so
python/py_src/sudachipy/*.pyd
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions python/Cargo.toml
@@ -15,6 +15,7 @@ crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.20", features = ["extension-module"] }
thread_local = "1.1" # Apache 2.0/MIT
scopeguard = "1" # Apache 2.0/MIT

[dependencies.sudachi]
path = "../sudachi"
25 changes: 19 additions & 6 deletions python/py_src/sudachipy/sudachipy.pyi
@@ -1,5 +1,5 @@
from typing import ClassVar, Iterator, List, Tuple, Union, Callable, Iterable, Optional, Literal, Set
from sudachipy.config import Config
from .config import Config

POS = Tuple[str, str, str, str, str, str]
# POS element
@@ -32,7 +32,12 @@ class SplitMode:
B: ClassVar[SplitMode] = ...
C: ClassVar[SplitMode] = ...
@classmethod
def __init__(cls) -> None: ...
def __init__(cls, mode: str = "C") -> None:
"""
Creates a split mode from a string value
:param mode: string representation of the split mode
"""
...


class Dictionary:
@@ -65,7 +70,7 @@ class Dictionary:
...

def create(self,
mode: SplitMode = SplitMode.C,
mode: Union[SplitMode, Literal["A", "B", "C"]] = SplitMode.C,
fields: FieldSet = None,
*,
projection: str = None) -> Tokenizer:
@@ -96,7 +101,7 @@ class Dictionary:
...

def pre_tokenizer(self,
mode: SplitMode = SplitMode.C,
mode: Union[SplitMode, Literal["A", "B", "C"]] = "C",
fields: FieldSet = None,
handler: Optional[Callable[[int, object, MorphemeList], list]] = None,
*,
@@ -191,7 +196,7 @@ class Morpheme:
"""
...

def split(self, mode: SplitMode, out: Optional[MorphemeList] = None, add_single: bool = True) -> MorphemeList:
def split(self, mode: Union[SplitMode, Literal["A", "B", "C"]], out: Optional[MorphemeList] = None, add_single: bool = True) -> MorphemeList:
"""
Returns sub-morphemes in the provided split mode.
@@ -278,7 +283,7 @@ class Tokenizer:
def __init__(cls) -> None: ...

def tokenize(self, text: str,
mode: SplitMode = ...,
mode: Union[SplitMode, Literal["A", "B", "C"]] = ...,
out: Optional[MorphemeList] = None) -> MorphemeList:
"""
Break text into morphemes.
@@ -295,6 +300,14 @@
"""
...

@property
def mode(self) -> SplitMode:
"""
Get the current analysis mode
:return: current analysis mode
"""
...


class WordInfo:
a_unit_split: ClassVar[List[int]] = ...
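
Taken together, the stub changes above widen every `mode` parameter from `SplitMode` to `Union[SplitMode, Literal["A", "B", "C"]]`, and give `SplitMode` a string-accepting constructor. A minimal usage sketch of the new surface (assumes a system dictionary such as sudachidict_core is installed; the sample code is illustrative, not part of this diff):

    from sudachipy import Dictionary, SplitMode

    dic = Dictionary()            # loads the configured system dictionary
    tok = dic.create(mode="A")    # string literal, equivalent to passing SplitMode.A
    print(tok.mode)               # the new getter reports the tokenizer's default mode

    # SplitMode itself can now be constructed from a string:
    mode = SplitMode("B")
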
39 changes: 30 additions & 9 deletions python/src/dictionary.rs
@@ -20,7 +20,9 @@ use std::convert::TryFrom;
use std::fmt::Write;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use sudachi::analysis::Mode;

use crate::errors::{wrap, wrap_ctx, SudachiError as SudachiErr};
use sudachi::analysis::stateless_tokenizer::DictionaryAccess;
@@ -218,16 +220,20 @@ impl PyDictionary {
/// :param fields: load only a subset of fields.
/// See https://worksapplications.github.io/sudachi.rs/python/topics/subsetting.html
#[pyo3(
text_signature = "($self, mode: sudachipy.SplitMode = sudachipy.SplitMode.C) -> sudachipy.Tokenizer",
text_signature = "($self, mode = 'C') -> sudachipy.Tokenizer",
signature = (mode = None, fields = None, *, projection = None)
)]
fn create(
&self,
mode: Option<PySplitMode>,
fields: Option<&PySet>,
projection: Option<&PyString>,
fn create<'py>(
&'py self,
py: Python<'py>,
mode: Option<&'py PyAny>,
fields: Option<&'py PySet>,
projection: Option<&'py PyString>,
) -> PyResult<PyTokenizer> {
let mode = mode.unwrap_or(PySplitMode::C).into();
let mode = match mode {
Some(m) => extract_mode(py, m)?,
None => Mode::C,
};
let fields = parse_field_subset(fields)?;
let mut required_fields = self.config.projection.required_subset();
let dict = self.dictionary.as_ref().unwrap().clone();
@@ -283,12 +289,15 @@ impl PyDictionary {
fn pre_tokenizer<'p>(
&'p self,
py: Python<'p>,
mode: Option<PySplitMode>,
mode: Option<&PyAny>,
fields: Option<&PySet>,
handler: Option<Py<PyAny>>,
projection: Option<&PyString>,
) -> PyResult<&'p PyAny> {
let mode = mode.unwrap_or(PySplitMode::C).into();
let mode = match mode {
Some(m) => extract_mode(py, m)?,
None => Mode::C,
};
let subset = parse_field_subset(fields)?;
if let Some(h) = handler.as_ref() {
if !h.as_ref(py).is_callable() {
@@ -401,6 +410,18 @@ fn config_repr(cfg: &Config) -> Result<String, std::fmt::Error> {
Ok(result)
}

pub(crate) fn extract_mode<'py>(py: Python<'py>, mode: &'py PyAny) -> PyResult<Mode> {
if mode.is_instance_of::<PyString>() {
let mode = mode.str()?.to_str()?;
Mode::from_str(mode).map_err(|e| SudachiErr::new_err(e).into())
} else if mode.is_instance_of::<PySplitMode>() {
let mode = mode.extract::<PySplitMode>()?;
Ok(Mode::from(mode))
} else {
Err(SudachiErr::new_err(("unknown mode", mode.into_py(py))))
}
}

fn read_config_from_fs(path: Option<&Path>) -> PyResult<ConfigBuilder> {
wrap(ConfigBuilder::from_opt_file(path))
}
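
`extract_mode` is now the single conversion point shared by `create`, `pre_tokenizer`, `Morpheme.split`, and `Tokenizer.tokenize`: a Python `str` is parsed with `Mode::from_str`, a `SplitMode` instance is converted directly, and any other type raises a `SudachiError`. From the Python side, roughly (continuing the sketch above; the error payload shown is illustrative):

    dic.create(mode=SplitMode.B)   # accepted: SplitMode instance
    dic.create(mode="B")           # accepted: parsed via Mode::from_str
    dic.create(mode=1)             # rejected: raises SudachiError ("unknown mode", 1)
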
9 changes: 5 additions & 4 deletions python/src/morpheme.rs
@@ -24,9 +24,8 @@ use pyo3::types::{PyList, PyString, PyTuple, PyType};

use sudachi::prelude::{Morpheme, MorphemeList};

use crate::dictionary::{PyDicData, PyDictionary};
use crate::dictionary::{extract_mode, PyDicData, PyDictionary};
use crate::projection::MorphemeProjection;
use crate::tokenizer::PySplitMode;
use crate::word_info::PyWordInfo;

pub(crate) type PyMorphemeList = MorphemeList<Arc<PyDicData>>;
@@ -362,12 +361,14 @@ impl PyMorpheme {
fn split<'py>(
&'py self,
py: Python<'py>,
mode: PySplitMode,
mode: &PyAny,
out: Option<&'py PyCell<PyMorphemeListWrapper>>,
add_single: Option<bool>,
) -> PyResult<&'py PyCell<PyMorphemeListWrapper>> {
let list = self.list(py);

let mode = extract_mode(py, mode)?;

let out_cell = match out {
None => {
let list = list.empty_clone(py);
@@ -385,7 +386,7 @@
out_ref.clear();
let splitted = list
.internal(py)
.split_into(mode.into(), self.index, out_ref)
.split_into(mode, self.index, out_ref)
.map_err(|e| {
PyException::new_err(format!("Error while splitting morpheme: {}", e.to_string()))
})?;
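
Since `split` now takes `&PyAny` and routes it through `extract_mode`, sub-morpheme splitting accepts the same string literals. A sketch (reusing the objects from the earlier snippets; the text is arbitrary):

    for m in tok.tokenize("外国人参政権", mode="C"):
        for sub in m.split("A"):   # string literal instead of SplitMode.A
            print(sub.surface())
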
100 changes: 64 additions & 36 deletions python/src/tokenizer.rs
@@ -14,17 +14,19 @@
* limitations under the License.
*/

use std::ops::DerefMut;
use std::str::FromStr;
use std::sync::Arc;

use pyo3::exceptions::PyException;
use pyo3::prelude::*;

use sudachi::analysis::stateful_tokenizer::StatefulTokenizer;

use sudachi::dic::subset::InfoSubset;
use sudachi::prelude::*;

use crate::dictionary::PyDicData;
use crate::dictionary::{extract_mode, PyDicData};
use crate::errors::SudachiError as SudachiPyErr;
use crate::morpheme::{PyMorphemeListWrapper, PyProjector};

/// Unit to split text
@@ -35,33 +37,47 @@ use crate::morpheme::{PyMorphemeListWrapper, PyProjector};
///
/// C == long mode
//
// This implementation is a workaround. Waiting for the pyo3 enum feature.
// ref: [PyO3 issue #834](https://github.com/PyO3/pyo3/issues/834).
#[pyclass(module = "sudachipy.tokenizer", name = "SplitMode")]
#[derive(Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct PySplitMode {
mode: u8,
}

#[pymethods]
impl PySplitMode {
#[classattr]
pub const A: Self = Self { mode: 0 };

#[classattr]
pub const B: Self = Self { mode: 1 };

#[classattr]
pub const C: Self = Self { mode: 2 };
#[pyclass(module = "sudachipy.tokenizer", name = "SplitMode", frozen)]
#[derive(Clone, PartialEq, Eq, Copy, Debug)]
#[repr(u8)]
pub enum PySplitMode {
A,
B,
C,
}

impl From<PySplitMode> for Mode {
fn from(mode: PySplitMode) -> Self {
match mode {
PySplitMode::A => Mode::A,
PySplitMode::B => Mode::B,
_ => Mode::C,
PySplitMode::C => Mode::C,
}
}
}

impl From<Mode> for PySplitMode {
fn from(value: Mode) -> Self {
match value {
Mode::A => PySplitMode::A,
Mode::B => PySplitMode::B,
Mode::C => PySplitMode::C,
}
}
}

#[pymethods]
impl PySplitMode {
#[new]
fn new(mode: Option<&str>) -> PyResult<PySplitMode> {
let mode = match mode {
Some(m) => m,
None => return Ok(PySplitMode::C),
};

match Mode::from_str(mode) {
Ok(m) => Ok(m.into()),
Err(e) => Err(SudachiPyErr::new_err(e.to_string())),
}
}
}
@@ -112,29 +128,39 @@ impl PyTokenizer {
/// :type mode: sudachipy.SplitMode
/// :type out: sudachipy.MorphemeList
#[pyo3(
text_signature = "($self, text: str, mode: SplitMode = None, logger = None, out = None) -> sudachipy.MorphemeList",
text_signature = "($self, text: str, mode = None, logger = None, out = None) -> sudachipy.MorphemeList",
signature = (text, mode = None, logger = None, out = None)
)]
#[allow(unused_variables)]
fn tokenize<'py>(
&'py mut self,
py: Python<'py>,
text: &'py str,
mode: Option<PySplitMode>,
mode: Option<&PyAny>,
logger: Option<PyObject>,
out: Option<&'py PyCell<PyMorphemeListWrapper>>,
) -> PyResult<&'py PyCell<PyMorphemeListWrapper>> {
// keep default mode to restore later
// restore default mode on scope exit
let mode = match mode {
None => None,
Some(m) => Some(extract_mode(py, m)?),
};
let default_mode = mode.map(|m| self.tokenizer.set_mode(m.into()));
let mut tokenizer = scopeguard::guard(&mut self.tokenizer, |t| {
default_mode.map(|m| t.set_mode(m));
});

// analysis can be done without GIL
let err = py.allow_threads(|| {
tokenizer.reset().push_str(text);
tokenizer.do_tokenize()
});

self.tokenizer.reset().push_str(text);
self.tokenizer
.do_tokenize()
.map_err(|e| PyException::new_err(format!("Tokenization error: {}", e.to_string())))?;
err.map_err(|e| SudachiPyErr::new_err(format!("Tokenization error: {}", e.to_string())))?;

let out_list = match out {
None => {
let dict = self.tokenizer.dict_clone();
let dict = tokenizer.dict_clone();
let morphemes = MorphemeList::empty(dict);
let wrapper =
PyMorphemeListWrapper::from_components(morphemes, self.projection.clone());
@@ -146,16 +172,18 @@
let mut borrow = out_list.try_borrow_mut();
let morphemes = match borrow {
Ok(ref mut ms) => ms.internal_mut(py),
Err(e) => return Err(PyException::new_err("out was used twice at the same time")),
Err(e) => return Err(SudachiPyErr::new_err("out was used twice at the same time")),
};

morphemes
.collect_results(&mut self.tokenizer)
.map_err(|e| PyException::new_err(format!("Tokenization error: {}", e.to_string())))?;

// restore default mode
default_mode.map(|m| self.tokenizer.set_mode(m));
.collect_results(tokenizer.deref_mut())
.map_err(|e| SudachiPyErr::new_err(format!("Tokenization error: {}", e.to_string())))?;

Ok(out_list)
}

#[getter]
fn mode(&self) -> PySplitMode {
self.tokenizer.mode().into()
}
}
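
Two behavioral changes in the rewritten `tokenize` are easy to miss: the per-call `mode` override is now restored by a `scopeguard` even if tokenization fails partway, and the analysis itself runs inside `py.allow_threads`, i.e. without holding the GIL. The restore is observable from Python (a sketch, reusing the objects from the earlier snippets):

    tok = dic.create(mode="C")
    tok.tokenize("外国人参政権", mode="A")   # one-off override for this call
    print(tok.mode)                          # SplitMode.C: the default mode was restored
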
15 changes: 15 additions & 0 deletions python/tests/test_pretokenizers.py
@@ -58,6 +58,21 @@ def test_works_with_different_split_mode(self):
res = tok.encode("外国人参政権")
self.assertEqual(res.ids, [1, 5, 2, 3])

def test_works_with_different_split_mode_str(self):
pretok = self.dict.pre_tokenizer(mode='A')
vocab = {
"[UNK]": 0,
"外国": 1,
"参政": 2,
"権": 3,
"人": 5,
"外国人参政権": 4
}
tok = tokenizers.Tokenizer(WordLevel(vocab, unk_token="[UNK]"))
tok.pre_tokenizer = pretok
res = tok.encode("外国人参政権")
self.assertEqual(res.ids, [1, 5, 2, 3])

def test_with_handler(self):
def _handler(index, sentence: tokenizers.NormalizedString, ml: MorphemeList):
return [tokenizers.NormalizedString(ml[0].part_of_speech()[0]), tokenizers.NormalizedString(str(len(ml)))]
(Diffs for the remaining changed files are not shown.)
