From 4f561606d5ab2d4e56fc76bfb7aa0fe176a173e9 Mon Sep 17 00:00:00 2001 From: nevesnunes <> Date: Wed, 5 Aug 2020 20:02:26 +0100 Subject: [PATCH] Add support for bytes input --- suffix_trees/STree.py | 95 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 20 deletions(-) diff --git a/suffix_trees/STree.py b/suffix_trees/STree.py index 7f5a67f..27591d6 100644 --- a/suffix_trees/STree.py +++ b/suffix_trees/STree.py @@ -1,3 +1,5 @@ +import re + class STree(): """Class representing the suffix tree.""" @@ -7,6 +9,9 @@ def __init__(self, input=''): self.root.idx = 0 self.root.parent = self.root self.root._add_suffix_link(self.root) + self.input_type = None + self.terminal_symbol_length = 1 + self.terminal_symbols = set() if not input == '': self.build(input) @@ -17,12 +22,16 @@ def _check_input(self, input): In case of an invalid input throws ValueError. """ if isinstance(input, str): - return 'st' + return (str, 'st') + elif isinstance(input, bytes): + return (bytes, 'st') elif isinstance(input, list): if all(isinstance(item, str) for item in input): - return 'gst' + return (str, 'gst') + elif all(isinstance(item, bytes) for item in input): + return (bytes, 'gst') - raise ValueError("String argument should be of type String or a list of strings") + raise ValueError("String argument should be of type 'Bytes' or 'String' or a list of those types") def build(self, x): """Builds the Suffix tree on the given input. @@ -31,12 +40,17 @@ def build(self, x): :param x: String or List of Strings """ - type = self._check_input(x) - - if type == 'st': - x += next(self._terminalSymbolsGenerator()) + (input_type, tree_type) = self._check_input(x) + self.input_type = input_type + if tree_type == 'st': + terminal_symbol = next(self._terminalSymbolsGenerator()) + if self.input_type == bytes and isinstance(terminal_symbol, str): + terminal_symbol = bytes(terminal_symbol, 'UTF8') + self.terminal_symbol_length = len(terminal_symbol) + self.terminal_symbols.add(terminal_symbol) + x += terminal_symbol self._build(x) - if type == 'gst': + elif tree_type == 'gst': self._build_generalized(x) def _build(self, x): @@ -54,7 +68,7 @@ def _build_McCreight(self, x): """ u = self.root d = 0 - for i in range(len(x)): + for i in range(len(x) - (self.terminal_symbol_length - 1)): while u.depth == d and u._has_transition(x[d + i]): u = u._get_transition_link(x[d + i]) d = d + 1 @@ -73,7 +87,7 @@ def _build_McCreight(self, x): def _create_node(self, x, u, d): i = u.idx p = u.parent - v = _SNode(idx=i, depth=d) + v = _SNode(idx=i, depth=d, input_type=self.input_type) v._add_transition_link(u, x[i + d]) u.parent = v p._add_transition_link(v, x[i + p.depth]) @@ -81,7 +95,7 @@ def _create_node(self, x, u, d): return v def _create_leaf(self, x, i, u, d): - w = _SNode() + w = _SNode(input_type=self.input_type) w.idx = i w.depth = len(x) - i u._add_transition_link(w, x[i + d]) @@ -110,8 +124,17 @@ def _build_generalized(self, xs): """Builds a Generalized Suffix Tree (GST) from the array of strings provided. """ terminal_gen = self._terminalSymbolsGenerator() - - _xs = ''.join([x + next(terminal_gen) for x in xs]) + _xs = None + for x in xs: + terminal_symbol = next(terminal_gen) + if self.input_type == bytes and isinstance(terminal_symbol, str): + terminal_symbol = bytes(terminal_symbol, 'UTF8') + self.terminal_symbol_length = len(terminal_symbol) + self.terminal_symbols.add(terminal_symbol) + if not _xs: + _xs = x + terminal_symbol + else: + _xs += x + terminal_symbol self.word = _xs self._generalized_word_starts(xs) self._build(_xs) @@ -138,6 +161,19 @@ def _get_word_start_index(self, idx): i += 1 return i + def _suffix_contains_terminal_symbol(self, start, end): + """Validates if suffix was composed with multi-byte terminal symbol""" + for i in range(0, self.terminal_symbol_length): + if end+i <= len(self.word): + candidate_substring = None + if end-start == self.terminal_symbol_length: + candidate_substring = self.word[start+i+1:end+i+1] + else: + candidate_substring = self.word[start+i:end+i+1] + if candidate_substring in self.terminal_symbols: + return True + return False + def lcs(self, stringIdxs=-1): """Returns the Largest Common Substring of Strings provided in stringIdxs. If stringIdxs is not provided, the LCS of all strings is returned. @@ -163,8 +199,15 @@ def _find_lcs(self, node, stringIdxs): if nodes == []: return node - deepestNode = max(nodes, key=lambda n: n.depth) - return deepestNode + candidates = sorted(nodes, key=lambda x: x.depth, reverse=True) + for deepestNode in candidates: + start = deepestNode.idx + end = deepestNode.idx + deepestNode.depth + if self._suffix_contains_terminal_symbol(start, end): + continue + else: + return deepestNode + return node def _generalized_word_starts(self, xs): """Helper method returns the starting indexes of strings in GST""" @@ -174,6 +217,15 @@ def _generalized_word_starts(self, xs): self.word_starts.append(i) i += len(xs[n]) + 1 + def _startswith(self, edge, prefix): + if self.input_type == bytes and isinstance(prefix, str): + prefix = bytes(prefix, 'UTF8') + if isinstance(prefix, str): + regex = r'^' + prefix + else: + regex = bytes(r'^', 'UTF8') + prefix + return re.match(regex, edge) + def find(self, y): """Returns starting position of the substring y in the string used for building the Suffix tree. @@ -185,7 +237,7 @@ def find(self, y): node = self.root while True: edge = self._edgeLabel(node, node.parent) - if edge.startswith(y): + if self._startswith(edge, y): return node.idx i = 0 @@ -207,7 +259,7 @@ def find_all(self, y): node = self.root while True: edge = self._edgeLabel(node, node.parent) - if edge.startswith(y): + if self._startswith(edge, y): break i = 0 @@ -242,15 +294,15 @@ def _terminalSymbolsGenerator(self): for i in UPPAs: yield (chr(i)) - raise ValueError("To many input strings.") + raise ValueError("Too many input strings.") class _SNode(): - __slots__ = ['_suffix_link', 'transition_links', 'idx', 'depth', 'parent', 'generalized_idxs'] + __slots__ = ['_suffix_link', 'transition_links', 'idx', 'depth', 'parent', 'generalized_idxs', 'input_type'] """Class representing a Node in the Suffix tree.""" - def __init__(self, idx=-1, parentNode=None, depth=-1): + def __init__(self, idx=-1, parentNode=None, depth=-1, input_type=None): # Links self._suffix_link = None self.transition_links = {} @@ -259,6 +311,7 @@ def __init__(self, idx=-1, parentNode=None, depth=-1): self.depth = depth self.parent = parentNode self.generalized_idxs = {} + self.input_type = input_type def __str__(self): return ("SNode: idx:" + str(self.idx) + " depth:" + str(self.depth) + @@ -274,6 +327,8 @@ def _get_suffix_link(self): return False def _get_transition_link(self, suffix): + if self.input_type == bytes and isinstance(suffix, str): + suffix = bytes(suffix, 'UTF8') return False if suffix not in self.transition_links else self.transition_links[suffix] def _add_transition_link(self, snode, suffix):