Skip to content

Commit

Permalink
get_errors script; edit_dist file
Browse files Browse the repository at this point in the history
  • Loading branch information
j-luo93 committed Sep 3, 2020
1 parent cb7282b commit 592762d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
14 changes: 14 additions & 0 deletions scripts/get_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import sys

import pandas as pd

if __name__ == "__main__":
    # Usage: get_errors.py IN_PATH OUT_PATH NUMBER
    #   IN_PATH  -- tab-separated file containing a 'correct@5' column
    #   OUT_PATH -- where the sampled incorrect rows are written (tab-separated)
    #   NUMBER   -- how many incorrect rows to sample
    if len(sys.argv) < 4:
        sys.exit(f'Usage: {sys.argv[0]} IN_PATH OUT_PATH NUMBER')

    in_path, out_path = sys.argv[1], sys.argv[2]
    number = int(sys.argv[3])

    df = pd.read_csv(in_path, sep='\t', keep_default_na=True)
    # Keep only the rows whose 'correct@5' flag is false.
    # NOTE(review): `~` assumes the column was parsed as boolean -- confirm
    # that the input never has missing values in this column.
    errors = df[~df['correct@5']]

    # DataFrame.sample raises when asked for more rows than are available
    # (sampling without replacement), so cap the request at the population size.
    n_samples = min(number, len(errors))
    if n_samples < number:
        print(f'Only {len(errors)} incorrect rows available; sampling {n_samples}.',
              file=sys.stderr)

    errors.sample(n=n_samples).to_csv(out_path, sep='\t', index=False)
32 changes: 32 additions & 0 deletions sound_law/evaluate/edit_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Sequence, Union, overload

import numpy as np
from lingpy.align.pairwise import pw_align

from dev_misc.devlib import NDA
from editdistance import eval as ed_eval
from editdistance import eval_all as ed_eval_all

# Scalar type used as the return annotation of `edit_dist`: an int or a float.
Number = Union[int, float]


def edit_dist(seq_0: str, seq_1: str, mode: str) -> Number:
    """Dispatch to one of the supported ways of computing an edit distance.

    Args:
        seq_0: The first sequence.
        seq_1: The second sequence.
        mode: 'ed' delegates to `ed_eval`; 'global' derives the value from a
            global pairwise alignment produced by `pw_align`.

    Returns:
        For 'ed', whatever `ed_eval` returns. For 'global', the length of the
        longer sequence minus the last element of the `pw_align` result
        (presumably the alignment score -- confirm against the lingpy docs).

    Raises:
        ValueError: If `mode` is neither 'ed' nor 'global'.
    """
    if mode == 'ed':
        return ed_eval(seq_0, seq_1)

    if mode == 'global':
        longest = max(len(seq_0), len(seq_1))
        alignment = pw_align(seq_0, seq_1, mode='global')
        return longest - alignment[-1]

    raise ValueError(f'Unrecognized value "{mode}" for mode.')


def edit_dist_all(seqs_0: Sequence[str], seqs_1: Sequence[str], mode: str) -> NDA:
    """Compute the edit distance for every pair drawn from two collections.

    Args:
        seqs_0: Sequences indexing the rows of the result.
        seqs_1: Sequences indexing the columns of the result.
        mode: Passed through to `edit_dist`; 'ed' is instead handled in one
            call by `ed_eval_all`.

    Returns:
        For 'ed', whatever `ed_eval_all` returns. Otherwise an array of shape
        `(len(seqs_0), len(seqs_1))` whose `[i, j]` entry is
        `edit_dist(seqs_0[i], seqs_1[j], mode)`.

    Raises:
        ValueError: Propagated from `edit_dist` for an unrecognized `mode`
            (only when there is at least one pair to evaluate).
    """
    if mode == 'ed':
        return ed_eval_all(seqs_0, seqs_1)

    rows = [[edit_dist(seq_0, seq_1, mode) for seq_1 in seqs_1] for seq_0 in seqs_0]
    # An empty `seqs_0` would otherwise yield a 1-D array of shape (0,); the
    # explicit reshape keeps the result a 2-D matrix in every case.
    return np.asarray(rows).reshape(len(seqs_0), len(seqs_1))

0 comments on commit 592762d

Please sign in to comment.