Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoon Kim committed Apr 7, 2019
1 parent b1eff12 commit 7ec40d4
Showing 1 changed file with 0 additions and 102 deletions.
102 changes: 0 additions & 102 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,6 @@ def get_actions(tree, SHIFT = 0, REDUCE = 1, OPEN='(', CLOSE=')'):
assert(num_shift == num_reduce + 1)
return actions

def get_tag(word):
    """Return the POS tag to emit for ``word`` in an evalb tree.

    Punctuation tokens are manually mapped to punctuation POS tags so that
    evalb can correctly ignore them; any other token gets the tag 'T'.
    """
    # Each entry pairs a tag with the surface tokens that map onto it.
    punctuation_tags = (
        (',', (',',)),
        (':', (':', '--', '...', ';')),
        ('``', ('``', '`')),
        ("''", ("''", "'")),
        ('.', ('.', '?', '!')),
    )
    for tag, tokens in punctuation_tags:
        if word in tokens:
            return tag
    return 'T'

def get_tree_evalb(actions, sent, SHIFT = 0, REDUCE = 1):
    """Build a bracketed tree string from shift/reduce actions.

    Every word of ``sent`` is first wrapped as a preterminal '(TAG word)',
    with TAG supplied by get_tag. A SHIFT pushes the next preterminal onto
    the stack; a REDUCE pops the top two items and pushes them back joined
    under an 'NT' node. Actions equal to neither SHIFT nor REDUCE are
    ignored. Returns the item on top of the stack once all actions have
    been applied.
    """
    leaves = [''.join(('(', get_tag(tok), ' ', tok, ')')) for tok in sent]
    pending = []
    next_leaf = 0
    for act in actions:
        if act == SHIFT:
            pending.append(leaves[next_leaf])
            next_leaf += 1
        elif act == REDUCE:
            # The most recently pushed item is the right child.
            right_child = pending.pop()
            left_child = pending.pop()
            pending.append(' '.join(('(NT', left_child, right_child)) + ')')
    return pending[-1]

def get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1):
#input action and sent (lists), e.g. S S R S S R R, A B C D
Expand All @@ -78,21 +49,6 @@ def get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1):
assert(len(stack) == 1)
return stack[-1]

def get_depth(tree, SHIFT = 0, REDUCE = 1):
    """Return the maximum bracket nesting depth of ``tree``.

    ``tree`` is iterated element by element (a string is walked character
    by character); each '(' opens a level and each ')' closes one. All other
    elements are ignored.

    SHIFT and REDUCE are not used by this function; they are kept only so
    existing callers that pass them keep working.

    Raises:
        AssertionError: if the brackets are not balanced by the end.
    """
    deepest = 0
    open_count = 0
    for c in tree:
        if c == '(':
            open_count += 1
            # The builtin max is usable here because no local shadows it.
            deepest = max(deepest, open_count)
        elif c == ')':
            open_count -= 1
    assert open_count == 0
    return deepest

def get_spans(actions, SHIFT = 0, REDUCE = 1):
sent = list(range((len(actions)+1) // 2))
spans = []
Expand Down Expand Up @@ -172,24 +128,6 @@ def get_tree_from_binary_matrix(matrix, length):
tree[s] = span
tree[t] = span
return tree[0]


def get_child_idx(b, row, col):
    """Find the child cells of cell (row, col) in the matrix ``b``.

    The left child is the first cell equal to 1 found by scanning leftwards
    along ``row`` starting one column before ``col``. The right child is the
    first cell equal to 1 found by scanning downwards along ``col`` starting
    one row below ``row``.

    Returns:
        (left_child_idx, right_child_idx), each a (row, col) tuple.

    NOTE(review): neither scan is bounded, so this presumably expects a
    matrix in which both children exist -- confirm against callers.
    """
    # Scan leftwards along the row for the left child.
    offset = 1
    while b[row][col - offset] != 1:
        offset += 1
    left_child_idx = (row, col - offset)

    # Scan downwards along the column for the right child.
    offset = 1
    while b[row + offset][col] != 1:
        offset += 1
    right_child_idx = (row + offset, col)

    return left_child_idx, right_child_idx

def get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1):
spans = []
Expand Down Expand Up @@ -231,43 +169,3 @@ def get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1):
assert(len(stack) == 1)
assert(num_shift == num_reduce + 1)
return spans, binary_actions, nonbinary_actions

def get_nonbinary_spans_label(actions, SHIFT = 0, REDUCE = 1):
    """Convert labelled parser actions into labelled spans and binary actions.

    ``actions`` is a sequence of strings, each one of:
      * "SHIFT"       -- consume the next token,
      * "NT(<label>)" -- open a constituent with the given label,
      * "REDUCE"      -- close the most recently opened constituent.

    Args:
        actions: the action strings described above.
        SHIFT: value appended to ``binary_actions`` for every shift.
        REDUCE: value appended to ``binary_actions`` for every binary reduce.

    Returns:
        (spans, binary_actions). ``spans`` lists a (left, right, label) tuple
        for each closed constituent whose left and right token indices
        differ, in the order the constituents are closed. ``binary_actions``
        holds one SHIFT per token and, for a constituent with n children,
        n - 1 REDUCE entries.

    Raises:
        AssertionError: on an unrecognized action, or if the actions do not
            reduce to exactly one item with shifts == reduces + 1.
    """
    spans = []
    stack = []
    pointer = 0
    binary_actions = []
    num_shift = 0
    num_reduce = 0
    for action in actions:
        if action == "SHIFT":
            stack.append((pointer, pointer))
            pointer += 1
            binary_actions.append(SHIFT)
            num_shift += 1
        elif action[:3] == 'NT(':
            # Keep a leading "(" so open labels are distinguishable from the
            # span tuples that share the stack.
            label = "(" + action.split("(")[1][:-1]
            stack.append(label)
        elif action == "REDUCE":
            right = stack.pop()
            left = right
            n = 1
            # Pop children until the open-label marker is on top. Compare by
            # value: an identity test against a string literal only works by
            # accident of interning.
            while stack[-1][0] != '(':
                left = stack.pop()
                n += 1
            span = (left[0], right[1], stack[-1][1:])
            if left[0] != right[1]:
                spans.append(span)
            stack.pop()
            stack.append(span)
            # A constituent with n children needs n - 1 binary reduces.
            binary_actions.extend([REDUCE] * (n - 1))
            num_reduce += n - 1
        else:
            # Raised explicitly so the check also runs under `python -O`.
            raise AssertionError('unrecognized action: %r' % (action,))
    assert(len(stack) == 1)
    assert(num_shift == num_reduce + 1)
    return spans, binary_actions

0 comments on commit 7ec40d4

Please sign in to comment.