diff --git a/utils.py b/utils.py index 99928cb..95e07aa 100644 --- a/utils.py +++ b/utils.py @@ -29,35 +29,6 @@ def get_actions(tree, SHIFT = 0, REDUCE = 1, OPEN='(', CLOSE=')'): assert(num_shift == num_reduce + 1) return actions -def get_tag(word): - # need to manually replace punctuation with POS tags so evalb can correctly ignore them - if word == ',': - return ',' - elif word == ':' or word == '--' or word == "..." or word == ';': - return ':' - elif word == '``' or word == '`': - return '``' - elif word == "''" or word == "'": - return "''" - elif word == '.' or word == "?" or word == "!": - return '.' - else: - return 'T' - -def get_tree_evalb(actions, sent, SHIFT = 0, REDUCE = 1): - stack = [] - pointer = 0 - sent = ['(' + get_tag(s) + ' ' + s + ')' for s in sent] - for action in actions: - if action == SHIFT: - word = sent[pointer] - stack.append(word) - pointer += 1 - elif action == REDUCE: - right = stack.pop() - left = stack.pop() - stack.append('(NT ' + left + ' ' + right + ')') - return stack[-1] def get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1): #input action and sent (lists), e.g. S S R S S R R, A B C D @@ -78,21 +49,6 @@ def get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1): assert(len(stack) == 1) return stack[-1] -def get_depth(tree, SHIFT = 0, REDUCE = 1): - stack = [] - depth = 0 - max = 0 - curr_max = 0 - for c in tree: - if c == '(': - curr_max += 1 - if curr_max > max: - max = curr_max - elif c == ')': - curr_max -= 1 - assert(curr_max == 0) - return max - def get_spans(actions, SHIFT = 0, REDUCE = 1): sent = list(range((len(actions)+1) // 2)) spans = [] @@ -172,24 +128,6 @@ def get_tree_from_binary_matrix(matrix, length): tree[s] = span tree[t] = span return tree[0] - - -def get_child_idx(b, row, col): - found_left = False - k = 0 - while not found_left: - k += 1 - if b[row][col-k] == 1: - left_child_idx = (row, col-k) - found_left = True - found_right = False - k = 0 - while not found_right: - k += 1 - if b[row+k][col] == 1: - right_child_idx = (row+k, col) - found_right = True - return left_child_idx, right_child_idx def get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1): spans = [] @@ -231,43 +169,3 @@ def get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1): assert(len(stack) == 1) assert(num_shift == num_reduce + 1) return spans, binary_actions, nonbinary_actions - -def get_nonbinary_spans_label(actions, SHIFT = 0, REDUCE = 1): - spans = [] - stack = [] - pointer = 0 - binary_actions = [] - num_shift = 0 - num_reduce = 0 - for action in actions: - # print(action, stack) - if action == "SHIFT": - stack.append((pointer, pointer)) - pointer += 1 - binary_actions.append(SHIFT) - num_shift += 1 - elif action[:3] == 'NT(': - label = "(" + action.split("(")[1][:-1] - stack.append(label) - elif action == "REDUCE": - right = stack.pop() - left = right - n = 1 - while stack[-1][0] is not '(': - left = stack.pop() - n += 1 - span = (left[0], right[1], stack[-1][1:]) - if left[0] != right[1]: - spans.append(span) - stack.pop() - stack.append(span) - while n > 1: - n -= 1 - binary_actions.append(REDUCE) - num_reduce += 1 - else: - assert False - # print('after', stack) - assert(len(stack) == 1) - assert(num_shift == num_reduce + 1) - return spans, binary_actions