Skip to content

Commit

Permalink
Add type hints
Browse files (browse the repository at this point in the history)
  • Loading branch information
Hironsan committed Oct 5, 2020
1 parent 638209d commit 448a1e4
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions seqeval/scheme.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import enum
from typing import List, Set, Tuple, Type


class Entity:
@@ -11,7 +12,7 @@ def __init__(self, start: int, end: int, tag: str):
def __repr__(self):
return '({}, {}, {})'.format(self.tag, self.start, self.end)

def __eq__(self, other):
def __eq__(self, other: 'Entity'):
return self.start == other.start and self.end == other.end and self.tag == other.tag

def __hash__(self):
@@ -72,18 +73,18 @@ def is_valid(self):
return True

def is_start(self, prev: 'Token'):
"""The current token is the start of chunk."""
"""Check whether the current token is the start of chunk."""
return self.check_patterns(prev, self.start_patterns)

def is_inside(self, prev: 'Token'):
"""The current token is inside of chunk."""
"""Check whether the current token is inside of chunk."""
return self.check_patterns(prev, self.inside_patterns)

def is_end(self, prev: 'Token'):
"""The previous token is the end of chunk."""
"""Check whether the previous token is the end of chunk."""
return self.check_patterns(prev, self.end_patterns)

def check_tag(self, prev, cond):
def check_tag(self, prev: 'Token', cond: Tag):
"""Check whether the tag pattern is matched."""
if cond == Tag.ANY:
return True
@@ -93,7 +94,7 @@ def check_tag(self, prev, cond):
return True
return False

def check_patterns(self, prev, patterns):
def check_patterns(self, prev: 'Token', patterns: Set[Tuple[Prefix, Prefix, Tag]]):
"""Check whether the prefix patterns are matched."""
for prev_prefix, current_prefix, tag_cond in patterns:
if prev.prefix in prev_prefix and self.prefix in current_prefix and self.check_tag(prev, tag_cond):
@@ -125,6 +126,7 @@ class IOB1(Token):


class IOE1(Token):
# TODO: IOE1 does not yet handle some cases; see the unit tests.
allowed_prefix = Prefix.I | Prefix.O | Prefix.E
start_patterns = {
(Prefix.O, Prefix.I, Tag.ANY),
@@ -203,9 +205,10 @@ class IOBES(Token):

class Tokens:

def __init__(self, tokens, token_class):
self.tokens = [token_class(token) for token in tokens] + [token_class('O')]
self.token_class = token_class
def __init__(self, tokens: List[str], scheme: Type[Token], suffix: bool = False, delimiter: str = '-'):
self.tokens = [scheme(token, suffix=suffix, delimiter=delimiter) for token in tokens]
self.scheme = scheme
self.outside_token = scheme('O', suffix=suffix, delimiter=delimiter)

@property
def entities(self):
@@ -220,10 +223,10 @@ def entities(self):
[('PER', 0, 2), ('LOC', 3, 4)]
"""
i = 0
prev = self.token_class('O')
entities = []
while i < len(self.tokens):
token = self.tokens[i]
prev = self.outside_token
while i < len(self.extended_tokens):
token = self.extended_tokens[i]
token.is_valid()
if token.is_start(prev):
end = self._forward(start=i + 1, prev=token)
@@ -233,24 +236,37 @@ def entities(self):
i = end
else:
i += 1
prev = self.tokens[i - 1]
prev = self.extended_tokens[i - 1]
return entities

def _forward(self, start, prev):
for i, token in enumerate(self.tokens[start:], start):
def _forward(self, start: int, prev: Token):
for i, token in enumerate(self.extended_tokens[start:], start):
if token.is_inside(prev):
prev = token
else:
return i
return len(self.tokens) - 2
return len(self.tokens) - 1

def _is_end(self, i):
token = self.tokens[i]
prev = self.tokens[i - 1]
def _is_end(self, i: int):
token = self.extended_tokens[i]
prev = self.extended_tokens[i - 1]
return token.is_end(prev)

@property
def extended_tokens(self):
# Append the outside ('O') token as a sentinel after the real tokens.
tokens = self.tokens + [self.outside_token]
return tokens


def auto_detect(sequences: List[List[str]], suffix: bool = False, delimiter: str = '-'):
"""Detects scheme automatically.
def auto_detect(sequences, suffix=False, delimiter='-'):
auto_detect supports the following schemes:
- IOB2
- IOE2
- IOBES
"""
prefixes = set()
error_message = 'This scheme is not supported: {}'
for tokens in sequences:

0 comments on commit 448a1e4

Please sign in to comment.