Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for raw bytes input #22

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 75 additions & 20 deletions suffix_trees/STree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

class STree():
"""Class representing the suffix tree."""

Expand All @@ -7,6 +9,9 @@ def __init__(self, input=''):
self.root.idx = 0
self.root.parent = self.root
self.root._add_suffix_link(self.root)
self.input_type = None
self.terminal_symbol_length = 1
self.terminal_symbols = set()

if not input == '':
self.build(input)
Expand All @@ -17,12 +22,16 @@ def _check_input(self, input):
In case of an invalid input throws ValueError.
"""
if isinstance(input, str):
return 'st'
return (str, 'st')
elif isinstance(input, bytes):
return (bytes, 'st')
elif isinstance(input, list):
if all(isinstance(item, str) for item in input):
return 'gst'
return (str, 'gst')
elif all(isinstance(item, bytes) for item in input):
return (bytes, 'gst')

raise ValueError("String argument should be of type String or a list of strings")
raise ValueError("String argument should be of type 'Bytes' or 'String' or a list of those types")

def build(self, x):
"""Builds the Suffix tree on the given input.
Expand All @@ -31,12 +40,17 @@ def build(self, x):

:param x: String or Bytes, or a List of Strings or a List of Bytes
"""
type = self._check_input(x)

if type == 'st':
x += next(self._terminalSymbolsGenerator())
(input_type, tree_type) = self._check_input(x)
self.input_type = input_type
if tree_type == 'st':
terminal_symbol = next(self._terminalSymbolsGenerator())
if self.input_type == bytes and isinstance(terminal_symbol, str):
terminal_symbol = bytes(terminal_symbol, 'UTF8')
self.terminal_symbol_length = len(terminal_symbol)
self.terminal_symbols.add(terminal_symbol)
x += terminal_symbol
self._build(x)
if type == 'gst':
elif tree_type == 'gst':
self._build_generalized(x)

def _build(self, x):
Expand All @@ -54,7 +68,7 @@ def _build_McCreight(self, x):
"""
u = self.root
d = 0
for i in range(len(x)):
for i in range(len(x) - (self.terminal_symbol_length - 1)):
while u.depth == d and u._has_transition(x[d + i]):
u = u._get_transition_link(x[d + i])
d = d + 1
Expand All @@ -73,15 +87,15 @@ def _build_McCreight(self, x):
def _create_node(self, x, u, d):
i = u.idx
p = u.parent
v = _SNode(idx=i, depth=d)
v = _SNode(idx=i, depth=d, input_type=self.input_type)
v._add_transition_link(u, x[i + d])
u.parent = v
p._add_transition_link(v, x[i + p.depth])
v.parent = p
return v

def _create_leaf(self, x, i, u, d):
w = _SNode()
w = _SNode(input_type=self.input_type)
w.idx = i
w.depth = len(x) - i
u._add_transition_link(w, x[i + d])
Expand Down Expand Up @@ -110,8 +124,17 @@ def _build_generalized(self, xs):
"""Builds a Generalized Suffix Tree (GST) from the array of strings provided.
"""
terminal_gen = self._terminalSymbolsGenerator()

_xs = ''.join([x + next(terminal_gen) for x in xs])
_xs = None
for x in xs:
terminal_symbol = next(terminal_gen)
if self.input_type == bytes and isinstance(terminal_symbol, str):
terminal_symbol = bytes(terminal_symbol, 'UTF8')
self.terminal_symbol_length = len(terminal_symbol)
self.terminal_symbols.add(terminal_symbol)
if not _xs:
_xs = x + terminal_symbol
else:
_xs += x + terminal_symbol
self.word = _xs
self._generalized_word_starts(xs)
self._build(_xs)
Expand All @@ -138,6 +161,19 @@ def _get_word_start_index(self, idx):
i += 1
return i

def _suffix_contains_terminal_symbol(self, start, end):
"""Check whether a slice of ``self.word`` around ``[start, end)`` is a terminal symbol.

For each offset ``i`` in ``range(self.terminal_symbol_length)`` with
``end + i <= len(self.word)``, a candidate slice of ``self.word`` is taken
and looked up in ``self.terminal_symbols``.

:param start: start index into ``self.word``
:param end: end index into ``self.word``
:return: True as soon as a candidate slice is found in
    ``self.terminal_symbols``; False otherwise
"""
# NOTE(review): indentation is not visible in this view, so the nesting
# described by the comments below is inferred -- confirm against the file.
for i in range(0, self.terminal_symbol_length):
# Only consider offsets whose shifted end stays within the word.
if end+i <= len(self.word):
candidate_substring = None
if end-start == self.terminal_symbol_length:
# Span is exactly one terminal symbol long: take a window of the same
# length shifted right by i + 1.
candidate_substring = self.word[start+i+1:end+i+1]
else:
# Otherwise the window starts at start + i and extends one position
# past end + i.
candidate_substring = self.word[start+i:end+i+1]
if candidate_substring in self.terminal_symbols:
return True
return False

def lcs(self, stringIdxs=-1):
"""Returns the Longest Common Substring of Strings provided in stringIdxs.
If stringIdxs is not provided, the LCS of all strings is returned.
Expand All @@ -163,8 +199,15 @@ def _find_lcs(self, node, stringIdxs):
if nodes == []:
return node

deepestNode = max(nodes, key=lambda n: n.depth)
return deepestNode
candidates = sorted(nodes, key=lambda x: x.depth, reverse=True)
for deepestNode in candidates:
start = deepestNode.idx
end = deepestNode.idx + deepestNode.depth
if self._suffix_contains_terminal_symbol(start, end):
continue
else:
return deepestNode
return node

def _generalized_word_starts(self, xs):
"""Helper method returns the starting indexes of strings in GST"""
Expand All @@ -174,6 +217,15 @@ def _generalized_word_starts(self, xs):
self.word_starts.append(i)
i += len(xs[n]) + 1

def _startswith(self, edge, prefix):
    """Return True if ``edge`` begins with the literal ``prefix``.

    When the tree was built from bytes (``self.input_type == bytes``) and
    ``prefix`` is a ``str``, the prefix is first encoded as UTF-8 so that it
    can be compared against a bytes edge label.

    The comparison is a plain prefix test. The previous implementation built
    a regular expression from ``prefix`` without escaping it, so characters
    such as ``.``, ``+`` or ``(`` in the search string were interpreted as
    regex syntax (false matches, or ``re.error`` on invalid patterns).

    :param edge: edge label (str or bytes) to test
    :param prefix: prefix to look for (str or bytes)
    :return: bool -- truthy exactly when ``edge`` starts with ``prefix``
    """
    if self.input_type == bytes and isinstance(prefix, str):
        prefix = bytes(prefix, 'UTF8')
    # str.startswith / bytes.startswith compare literally; no escaping needed.
    return edge.startswith(prefix)

def find(self, y):
"""Returns starting position of the substring y in the string used for
building the Suffix tree.
Expand All @@ -185,7 +237,7 @@ def find(self, y):
node = self.root
while True:
edge = self._edgeLabel(node, node.parent)
if edge.startswith(y):
if self._startswith(edge, y):
return node.idx

i = 0
Expand All @@ -207,7 +259,7 @@ def find_all(self, y):
node = self.root
while True:
edge = self._edgeLabel(node, node.parent)
if edge.startswith(y):
if self._startswith(edge, y):
break

i = 0
Expand Down Expand Up @@ -242,15 +294,15 @@ def _terminalSymbolsGenerator(self):
for i in UPPAs:
yield (chr(i))

raise ValueError("To many input strings.")
raise ValueError("Too many input strings.")


class _SNode():
__slots__ = ['_suffix_link', 'transition_links', 'idx', 'depth', 'parent', 'generalized_idxs']
__slots__ = ['_suffix_link', 'transition_links', 'idx', 'depth', 'parent', 'generalized_idxs', 'input_type']

"""Class representing a Node in the Suffix tree."""

def __init__(self, idx=-1, parentNode=None, depth=-1):
def __init__(self, idx=-1, parentNode=None, depth=-1, input_type=None):
# Links
self._suffix_link = None
self.transition_links = {}
Expand All @@ -259,6 +311,7 @@ def __init__(self, idx=-1, parentNode=None, depth=-1):
self.depth = depth
self.parent = parentNode
self.generalized_idxs = {}
self.input_type = input_type

def __str__(self):
return ("SNode: idx:" + str(self.idx) + " depth:" + str(self.depth) +
Expand All @@ -274,6 +327,8 @@ def _get_suffix_link(self):
return False

def _get_transition_link(self, suffix):
if self.input_type == bytes and isinstance(suffix, str):
suffix = bytes(suffix, 'UTF8')
return False if suffix not in self.transition_links else self.transition_links[suffix]

def _add_transition_link(self, snode, suffix):
Expand Down