diff --git a/gamecompendium/app.py b/gamecompendium/app.py index f742264..5b89fce 100644 --- a/gamecompendium/app.py +++ b/gamecompendium/app.py @@ -1,13 +1,12 @@ -import copy import os -from typing import Dict, Optional, Iterable, MutableSequence +from typing import Dict from tqdm import tqdm from whoosh_bugs import run as dont_delete_me_im_fixing_whoosh_bugs from whoosh.filedb.filestore import Storage, FileStorage from whoosh.index import Index -from whoosh.qparser import MultifieldParser, syntax, Plugin, QueryParser, MultifieldPlugin +from whoosh.qparser import syntax, Plugin, QueryParser, MultifieldPlugin import re from whoosh.searching import Searcher @@ -30,6 +29,7 @@ dont_delete_me_im_fixing_whoosh_bugs() + class App: sources: Dict[str, Source] indexes: Dict[str, Index] @@ -123,7 +123,7 @@ def evaluate(self, suite: BenchmarkSuite) -> list[BenchmarkResult]: for row in topk: relevance = next((d for hit, source in row.hits if (d := data.get((source, hit['id']))) is not None), 0) entries.append(relevance) - res.append(BenchmarkResult(bench.query, entries)) + res.append(BenchmarkResult(bench, entries)) return res def prompt(self): diff --git a/gamecompendium/benchmark.py b/gamecompendium/benchmark.py index e164e6e..2303879 100644 --- a/gamecompendium/benchmark.py +++ b/gamecompendium/benchmark.py @@ -24,7 +24,7 @@ class BenchmarkSuite: @dataclass class BenchmarkResult: - query: str + query: Benchmark raw: list[int] diff --git a/gamecompendium/main.py b/gamecompendium/main.py index 5f5f408..726137f 100644 --- a/gamecompendium/main.py +++ b/gamecompendium/main.py @@ -3,8 +3,16 @@ from benchmark import parse_suite import argparse import math -import collections + INDEX_DIR = 'indexes' +# Everything >= 2 is "relevant" (used for everything except DCG-related stuff) +RELEVANCE_THRESHOLD = 2 + + +def compute_discounted_cumulative_gain(data: list[int]) -> int: + if len(data) == 0: + return 0 + return data[0] + sum([(data[i] / math.log(i + 1, 2)) for i in range(1, len(data))]) async def main(): @@ -49,52 +57,52 @@ async def main(): suite = parse_suite(fd) res = app.evaluate(suite) avg_precisions = [] + interp_precisions = [0] * 10 for el in res: - - print(f"{el}") - + print(f"{el.query.query} : {[x.relevance for x in el.query.scores]} {el.raw}") + # DCG - val = (el.raw[0] + sum([(el.raw[i]/math.log(i+1,2)) for i in range(1,len(el.raw))]) ) + val = compute_discounted_cumulative_gain(el.raw) print(f"DCG: {val}") - + # IDEAL DCG - ideal_list = sorted(el.raw,reverse=True) - val_ideal = (ideal_list[0] + sum([(ideal_list[i]/math.log(i+1,2)) for i in range(1,len(ideal_list))]) ) - + ideal_list = sorted([x.relevance for x in el.query.scores], reverse=True) + val_ideal = compute_discounted_cumulative_gain(ideal_list) + print(f"IDEAL DCG: {val_ideal}") print(f"NDCG: {val/val_ideal}") - + # NATURAL PRECISION - natural_pr = {} - doc_count = 0 - tot_rel = sum([1 for elem in el.raw if elem != 0]) - for i in range(len(el.raw)): - if el.raw[i] != 0: - doc_count += 1 - precision = doc_count/(i+1) - natural_pr[doc_count/tot_rel] = precision + natural_pr = [] + tot_rel = sum([x.relevance >= RELEVANCE_THRESHOLD for x in el.query.scores]) + for i, entry in enumerate(el.raw): + if entry >= RELEVANCE_THRESHOLD: + precision = (len(natural_pr) + 1)/(i + 1) + natural_pr.append(precision) print("Natural precision: ") - print(" | ".join([f"{key}:{value}" for key,value in natural_pr.items()])) - + print(" | ".join([f"{(i + 1) / tot_rel}:{value}" for i, value in enumerate(natural_pr)])) + # STANDARD PRECISION - precisions = {} - for i in range(len(el.raw)-1,-1,-1): - maxval = max([value for key,value in natural_pr.items() if key >= (i/10)]) - precisions[(i+1)/10] = maxval - ord_prec = collections.OrderedDict(sorted(precisions.items())) + precisions = [0.0] * 10 + for i in range(10): + maxval = max([value for j, value in enumerate(natural_pr) if (j + 1) / tot_rel >= (i + 1) / 10], default=0) + precisions[i] = maxval + interp_precisions[i] += maxval print("Standard precision: ") - print(" | ".join([f"{key}:{value}" for key,value in ord_prec.items()])) - - avg_prc = sum(list(natural_pr.values()))/tot_rel + print(" | ".join([f"{(i+1)/10}:{value}" for i, value in enumerate(precisions)])) + + avg_prc = sum(natural_pr) / tot_rel print(f"Average non-interpolated precision: {avg_prc}") - - avg_int_prc = sum(list(precisions.values()))/10 + + avg_int_prc = sum(precisions) / 10 avg_precisions.append(avg_int_prc) print(f"Average interpolated precision: {avg_int_prc}") print("\n") - + mean_avg = sum(avg_precisions)/len(res) print(f"Mean average precision: {mean_avg}") + print("Average Standard precision: ") + print(" | ".join([f"{(key + 1) / 10}:{value / len(interp_precisions)}" for key, value in enumerate(interp_precisions)])) else: print("Unknown action: " + args.action)