
Commit: correlation
j-luo93 committed Apr 12, 2021
1 parent 9267108 commit 0527ebd
Showing 4 changed files with 100 additions and 34 deletions.
82 changes: 59 additions & 23 deletions scripts/dashboard.py
@@ -874,33 +874,69 @@ def get_group_chart(grid_df, min_value: float, max_value: float, title: str = ''
if average_col:
st.write(to_inspect.pivot_table(index=average_col, values='value', aggfunc='mean'))
if st.checkbox('show matching'):
path = st.selectbox('which run', [None] + selected_runs)
if path is not None:
auc_score, match_df = read_matching_metrics(path)

chart = alt.Chart(match_df).mark_rect().encode(
x='max_power_set_size:O',
y='k_matches:O',
color='score:Q',
tooltip=[
alt.Tooltip('max_power_set_size:O', title='max_power_set_size'),
alt.Tooltip('k_matches:O', title='k_matches'),
alt.Tooltip('score:Q', title='score'),
]
).facet(column='match_proportion')
st.write(chart)
# Show correlation.
# path = st.selectbox('which run', [None] + selected_runs)
# if path is not None:
# auc_score, match_df = read_matching_metrics(path)

# chart = alt.Chart(match_df).mark_rect().encode(
# x='max_power_set_size:O',
# y='k_matches:O',
# color='score:Q',
# tooltip=[
# alt.Tooltip('max_power_set_size:O', title='max_power_set_size'),
# alt.Tooltip('k_matches:O', title='k_matches'),
# alt.Tooltip('score:Q', title='score'),
# ]
# ).facet(column='match_proportion')
# st.write(chart)
# corr_data = list()
# for i, run in enumerate(selected_runs, 1):
# auc_score, match_df = read_matching_metrics(run)
# lang = re.search(r'OPRLPgmc(\w\w\w)', str(run)).group(1).lower()

# is_complete = not bool((match_df['score'] == -1).sum())
# assert is_complete
# event_df = load_event(run)
# best_score = event_df[event_df['tag'] == 'best_score']['value'].max()
# corr_data.append((best_score, auc_score, lang))
# corr_df = pd.DataFrame(corr_data, columns=['best_score', 'auc_score', 'lang'])

# Use more correlation data by adding truncated paths.
corr_data = list()
lang2length = {'got': 20, 'non': 40, 'ang': 60}
for i, run in enumerate(selected_runs, 1):
auc_score, match_df = read_matching_metrics(run)
lang = re.search(r'OPRLPgmc(\w\w\w)', str(run)).group(1).lower()
best_run = int((Path(run) / 'best_run').open('r').read(-1).strip())
length = lang2length[lang]
records = list()
score_path = f'{run}/eval/{best_run}.path.scores'
with open(score_path) as fin:
truncated_dists = list()
for line in fin:
truncated_dists.append(float(line.strip()))
start_dist = truncated_dists[0]
# last_record = None
for l in range(5, length + 1, 5):
if l >= len(truncated_dists):
break
record = {'truncate_length': l,
'best_score': 1.0 - truncated_dists[l] / start_dist, 'lang': lang, 'run': run}
scores = [1.0]
for m in [0.2, 0.4, 0.6, 0.8, 1.0]:
match_score = read_matching_score(f'{run}/eval/{m}-100-10-{l}.pkl')
scores.append(match_score)
record[f'match_{m}'] = match_score
assert all(score > -1 for score in scores)
auc_score = auc([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], scores)
record['auc_score'] = auc_score
# last_record = record
corr_data.append(record)
# corr_data.append(last_record)
corr_df = pd.DataFrame(corr_data)
corr_df.to_csv('corr_df.tsv', sep='\t', index=False)
st.write(corr_df)

is_complete = not bool((match_df['score'] == -1).sum())
assert is_complete
event_df = load_event(run)
best_score = event_df[event_df['tag'] == 'best_score']['value'].max()
corr_data.append((best_score, auc_score, lang))

corr_df = pd.DataFrame(corr_data, columns=['best_score', 'auc_score', 'lang'])
chart = alt.Chart(corr_df).mark_point().encode(
x='best_score:Q',
y='auc_score:Q',
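The correlation loop above builds one record per (run, truncate_length): a relative-improvement score (1.0 - truncated_dists[l] / start_dist) plus matching scores at proportions 0.2 through 1.0, collapsed into a single auc_score. A minimal sketch of that aggregation, assuming auc is sklearn.metrics.auc (trapezoidal area under the curve); the per-proportion scores below are made up:

# Sketch of the per-record AUC aggregation, with hypothetical matching scores.
from sklearn.metrics import auc

proportions = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
scores = [1.0, 0.92, 0.85, 0.81, 0.78, 0.74]  # score at proportion 0.0 is pinned to 1.0

auc_score = auc(proportions, scores)  # trapezoidal area under the proportion-vs-score curve
print(round(auc_score, 3))  # 0.846 for these numbers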
8 changes: 8 additions & 0 deletions scripts/evaluate.py
@@ -1,11 +1,19 @@
import sound_law.rl.rule as rule
from pathlib import Path
from dev_misc import add_argument, g

if __name__ == "__main__":
add_argument("calc_metric", dtype=bool, default=False, msg="Whether to calculate the metrics.")
add_argument("out_path", dtype=str, msg="Path to the output file.")

manager, gold, states, refs = rule.simulate()
initial_state = states[0]
if g.in_path:
assert len(gold) == len(states) - 1
if g.out_path:
with Path(g.out_path).open('w', encoding='utf8') as fout:
for state in states:
fout.write(f'{state.dist}\n')

if g.calc_metric:
# compute the similarity between the candidate ruleset and the gold standard ruleset
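The out_path block above writes one per-state distance per line; the dashboard change in this commit reads such a file back (as {run}/eval/{best_run}.path.scores) and turns it into a relative-improvement score. A small sketch of that round trip, with a hypothetical path and truncation length:

# Read a scores file (one distance per line, as written above) and compute the
# relative improvement after truncating the rule path. Path and length are hypothetical.
score_path = 'eval/42.path.scores'
with open(score_path) as fin:
    dists = [float(line.strip()) for line in fin]

start_dist = dists[0]
truncate_length = 10
if truncate_length < len(dists):
    best_score = 1.0 - dists[truncate_length] / start_dist
    print(f'score after {truncate_length} rules: {best_score:.3f}')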
42 changes: 32 additions & 10 deletions sound_law/evaluate/ilp.py
@@ -73,7 +73,8 @@ def match_rulesets(gold: List[List[SoundChangeAction]],
max_power_set_size: int = 3,
use_greedy_growth: bool = False,
interpret_matching: bool = False,
silent: bool = False) -> Tuple[List[Tuple[int, List[int]]], int, float, float, Dict[int, int]]:
silent: bool = False,
null_only: bool = False) -> Tuple[List[Tuple[int, List[int]]], int, float, float, Dict[int, int]]:
'''Finds the optimal matching of rule blocks in the gold ruleset to 0, 1, or 2 rules in the candidate ruleset. Frames the problem as an integer linear program. Returns a list of tuples with the matching.'''

solver = pywraplp.Solver.CreateSolver('SCIP') # TODO investigate other solvers
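As a toy illustration of the ILP framing named in the docstring (not the actual variable layout used in this function), the sketch below sets up one binary variable per (gold block, candidate rule) pair plus a null-match option, an exactly-one-choice constraint per gold block, and a cost-minimising objective; all costs are made up:

# Toy ILP: each gold block picks exactly one option (a candidate rule or the null
# match) and the total matching cost is minimised. Costs are hypothetical.
from ortools.linear_solver import pywraplp

solver = pywraplp.Solver.CreateSolver('SCIP')
objective = solver.Objective()
objective.SetMinimization()

null_cost = 10.0
costs = {(0, 0): 2.0, (0, 1): 7.0, (1, 0): 8.0, (1, 1): 3.0}  # (gold block, cand rule) -> dist

for i in range(2):  # two gold blocks
    pick_one = solver.Constraint(1, 1, f'gold_{i}')  # exactly one option per block
    for j in range(2):  # two candidate rules
        var = solver.IntVar(0, 1, f'match_{i},{j}')
        pick_one.SetCoefficient(var, 1)
        objective.SetCoefficient(var, costs[i, j])
    null_var = solver.IntVar(0, 1, f'match_{i},(-1)')
    pick_one.SetCoefficient(null_var, 1)
    objective.SetCoefficient(null_var, null_cost)

status = solver.Solve()
print(status == pywraplp.Solver.OPTIMAL, solver.Objective().Value())  # True 5.0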
Expand Down Expand Up @@ -120,6 +121,18 @@ def match_rulesets(gold: List[List[SoundChangeAction]],
pass
else:
number_active_gold_blocks += 1
# Match gold with a null candidate rule. Note that there is no constraint on the candidate side.
var_name = f'pss0_{i},(-1)' # -1 stands for the null candidate rule.
v[var_name] = null_match = solver.IntVar(0, 1, var_name)
c[f'gold_{i}'].SetCoefficient(null_match, 1)
c['min_match'].SetCoefficient(null_match, null_cost)
objective.SetCoefficient(null_match, null_cost)

if null_only:
# update the state and continue onto the next block in gold
curr_state = gold_state
continue

# actually loop over the variables and create variables for this block

def generate_match_candidates(var_name_prefix: str, power_set_size: int,
Expand Down Expand Up @@ -191,13 +204,6 @@ def update_top_candidates(highest_cost: float, new_candidates: List[MatchCandida
objective.SetCoefficient(v[var_name], cost)
size_cnt[len(cand_rules)] += 1

# Match gold with a null candidate rule. Note that there is no constraint on the candidate side.
var_name = f'pss0_{i},(-1)' # -1 stands for the null candidate rule.
v[var_name] = null_match = solver.IntVar(0, 1, var_name)
c[f'gold_{i}'].SetCoefficient(null_match, 1)
c['min_match'].SetCoefficient(null_match, null_cost)
objective.SetCoefficient(null_match, null_cost)

# update the state and continue onto the next block in gold
curr_state = gold_state

@@ -218,9 +224,22 @@ def update_top_candidates(highest_cost: float, new_candidates: List[MatchCandida

# reconstruct the solution and return it
final_value = solver.Objective().Value()
max_cost = total_null_costs
if not null_only:
_, _, max_cost, _, _ = match_rulesets(gold,
cand,
env,
match_proportion,
k_matches,
max_power_set_size,
use_greedy_growth,
interpret_matching,
silent,
null_only=True)

if not silent:
print('Minimum objective function value = %f' % final_value)
print('Minimum objective function value percentage = %f' % (1.0 - final_value / total_null_costs))
print('Minimum objective function value percentage = %f' % (1.0 - final_value / max_cost))

# interpret solution as a matching, returning a list pairing indices of blocks in gold to a list of indices of matched rules in cand
matching = []
@@ -249,7 +268,7 @@ def update_top_candidates(highest_cost: float, new_candidates: List[MatchCandida
print('matched to rules:', cand_rules)
print('with dist', str(cost))

return matching, status, final_value, total_null_costs, size_cnt
return matching, status, final_value, max_cost, size_cnt
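Note on the normalisation above: when null_only is False, the function re-runs itself with null_only=True so that max_cost is the objective of matching every gold block to the null rule, and the reported percentage is measured against that worst case. Schematically (a sketch, not the real call chain):

# max_cost = objective of the all-null matching; the percentage is the share of
# that worst-case cost avoided by the actual matching.
def objective_percentage(final_value: float, max_cost: float) -> float:
    return 1.0 - final_value / max_cost

print(objective_percentage(12.5, 50.0))  # 0.75, i.e. 75% of the all-null cost avoided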


if __name__ == "__main__":
@@ -261,11 +280,14 @@ def update_top_candidates(highest_cost: float, new_candidates: List[MatchCandida
add_argument('max_power_set_size', dtype=int, default=3, msg='Maximum power set size.')
add_argument("use_greedy_growth", dtype=bool, default=False, msg="Flag to grow the kept candidates greedily.")
add_argument("silent", dtype=bool, default=False, msg="Flag to suppress printing.")
add_argument('cand_length', dtype=int, default=0, msg='Only take the first n candidate rules if positive.')

manager, gold, states, refs = rule.simulate()
initial_state = states[0]

cand = read_rules_from_txt(g.cand_path)
if g.cand_length > 0:
cand = cand[:g.cand_length]
# gold = read_rules_from_txt('data/toy_gold_rules.txt')

# turn gold rules into singleton lists since we expect gold to be in the form of blocks
2 changes: 1 addition & 1 deletion sound_law/rl/rule.py
@@ -567,7 +567,7 @@ def simulate(raw_inputs: Optional[List[Tuple[List[str], List[str], List[str]]]]
elif g.in_path:
with open(g.in_path, 'r', encoding='utf8') as fin:
lines = [line.strip() for line in fin.readlines()]
gold = get_actions(lines, range(len(lines)))
gold = get_actions(lines)
else:
df = pd.read_csv('data/test_annotations.csv')
df = df.dropna(subset=['ref no.'])
