From 448a1e4edc4e874aca41afa01f454f04f0b92971 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Tue, 6 Oct 2020 08:10:33 +0900 Subject: [PATCH] Add type hints --- seqeval/scheme.py | 56 ++++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/seqeval/scheme.py b/seqeval/scheme.py index 830ce79..571395e 100644 --- a/seqeval/scheme.py +++ b/seqeval/scheme.py @@ -1,4 +1,5 @@ import enum +from typing import List, Set, Tuple, Type class Entity: @@ -11,7 +12,7 @@ def __init__(self, start: int, end: int, tag: str): def __repr__(self): return '({}, {}, {})'.format(self.tag, self.start, self.end) - def __eq__(self, other): + def __eq__(self, other: 'Entity'): return self.start == other.start and self.end == other.end and self.tag == other.tag def __hash__(self): @@ -72,18 +73,18 @@ def is_valid(self): return True def is_start(self, prev: 'Token'): - """The current token is the start of chunk.""" + """Check whether the current token is the start of chunk.""" return self.check_patterns(prev, self.start_patterns) def is_inside(self, prev: 'Token'): - """The current token is inside of chunk.""" + """Check whether the current token is inside of chunk.""" return self.check_patterns(prev, self.inside_patterns) def is_end(self, prev: 'Token'): - """The previous token is the end of chunk.""" + """Check whether the previous token is the end of chunk.""" return self.check_patterns(prev, self.end_patterns) - def check_tag(self, prev, cond): + def check_tag(self, prev: 'Token', cond: Tag): """Check whether the tag pattern is matched.""" if cond == Tag.ANY: return True @@ -93,7 +94,7 @@ def check_tag(self, prev, cond): return True return False - def check_patterns(self, prev, patterns): + def check_patterns(self, prev: 'Token', patterns: Set[Tuple[Prefix, Prefix, Tag]]): """Check whether the prefix patterns are matched.""" for prev_prefix, current_prefix, tag_cond in patterns: if prev.prefix in prev_prefix and self.prefix in current_prefix and self.check_tag(prev, tag_cond): @@ -125,6 +126,7 @@ class IOB1(Token): class IOE1(Token): + # Todo: IOE1 hasn't yet been able to handle some cases. See unit testing. allowed_prefix = Prefix.I | Prefix.O | Prefix.E start_patterns = { (Prefix.O, Prefix.I, Tag.ANY), @@ -203,9 +205,10 @@ class IOBES(Token): class Tokens: - def __init__(self, tokens, token_class): - self.tokens = [token_class(token) for token in tokens] + [token_class('O')] - self.token_class = token_class + def __init__(self, tokens: List[str], scheme: Type[Token], suffix: bool = False, delimiter: str = '-'): + self.tokens = [scheme(token, suffix=suffix, delimiter=delimiter) for token in tokens] + self.scheme = scheme + self.outside_token = scheme('O', suffix=suffix, delimiter=delimiter) @property def entities(self): @@ -220,10 +223,10 @@ def entities(self): [('PER', 0, 2), ('LOC', 3, 4)] """ i = 0 - prev = self.token_class('O') entities = [] - while i < len(self.tokens): - token = self.tokens[i] + prev = self.outside_token + while i < len(self.extended_tokens): + token = self.extended_tokens[i] token.is_valid() if token.is_start(prev): end = self._forward(start=i + 1, prev=token) @@ -233,24 +236,37 @@ def entities(self): i = end else: i += 1 - prev = self.tokens[i - 1] + prev = self.extended_tokens[i - 1] return entities - def _forward(self, start, prev): - for i, token in enumerate(self.tokens[start:], start): + def _forward(self, start: int, prev: Token): + for i, token in enumerate(self.extended_tokens[start:], start): if token.is_inside(prev): prev = token else: return i - return len(self.tokens) - 2 + return len(self.tokens) - 1 - def _is_end(self, i): - token = self.tokens[i] - prev = self.tokens[i - 1] + def _is_end(self, i: int): + token = self.extended_tokens[i] + prev = self.extended_tokens[i - 1] return token.is_end(prev) + @property + def extended_tokens(self): + # append a sentinel. + tokens = self.tokens + [self.outside_token] + return tokens + + +def auto_detect(sequences: List[List[str]], suffix: bool = False, delimiter: str = '-'): + """Detects scheme automatically. -def auto_detect(sequences, suffix=False, delimiter='-'): + auto_detect supports the following schemes: + - IOB2 + - IOE2 + - IOBES + """ prefixes = set() error_message = 'This scheme is not supported: {}' for tokens in sequences: