Skip to content
This repository has been archived by the owner on Jul 19, 2022. It is now read-only.

Commit

Permalink
Fix benchmark computations
Browse files Browse the repository at this point in the history
  • Loading branch information
SnowyCoder committed Feb 16, 2022
1 parent c79abe7 commit 16a6fb9
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 36 deletions.
8 changes: 4 additions & 4 deletions gamecompendium/app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import copy
import os
from typing import Dict, Optional, Iterable, MutableSequence
from typing import Dict

from tqdm import tqdm

from whoosh_bugs import run as dont_delete_me_im_fixing_whoosh_bugs
from whoosh.filedb.filestore import Storage, FileStorage
from whoosh.index import Index
from whoosh.qparser import MultifieldParser, syntax, Plugin, QueryParser, MultifieldPlugin
from whoosh.qparser import syntax, Plugin, QueryParser, MultifieldPlugin
import re

from whoosh.searching import Searcher
Expand All @@ -30,6 +29,7 @@

dont_delete_me_im_fixing_whoosh_bugs()


class App:
sources: Dict[str, Source]
indexes: Dict[str, Index]
Expand Down Expand Up @@ -123,7 +123,7 @@ def evaluate(self, suite: BenchmarkSuite) -> list[BenchmarkResult]:
for row in topk:
relevance = next((d for hit, source in row.hits if (d := data.get((source, hit['id']))) is not None), 0)
entries.append(relevance)
res.append(BenchmarkResult(bench.query, entries))
res.append(BenchmarkResult(bench, entries))
return res

def prompt(self):
Expand Down
2 changes: 1 addition & 1 deletion gamecompendium/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BenchmarkSuite:

@dataclass
class BenchmarkResult:
query: str
query: Benchmark
raw: list[int]


Expand Down
70 changes: 39 additions & 31 deletions gamecompendium/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@
from benchmark import parse_suite
import argparse
import math
import collections

INDEX_DIR = 'indexes'
# Everything >= 2 is "relevant" (used for everything except DCG-related stuff)
RELEVANCE_THRESHOLD = 2


def compute_discounted_cumulative_gain(data: list[int]) -> float:
    """Return the Discounted Cumulative Gain (DCG) of a ranked relevance list.

    The first result's relevance is counted at full weight; every subsequent
    result at rank i (0-based) is discounted by log2(i + 1), per the standard
    DCG formulation.

    Args:
        data: relevance scores in ranked order (highest rank first).

    Returns:
        The DCG as a float (0.0 for an empty list).  Note: the result is a
        float even for integer relevances, since the discounted terms are
        ratios — the original ``-> int`` annotation was incorrect.
    """
    if not data:
        return 0.0
    # data[0] is undiscounted; generator avoids materializing a throwaway list.
    return data[0] + sum(data[i] / math.log2(i + 1) for i in range(1, len(data)))


async def main():
Expand Down Expand Up @@ -49,52 +57,52 @@ async def main():
suite = parse_suite(fd)
res = app.evaluate(suite)
avg_precisions = []
interp_precisions = [0] * 10
for el in res:

print(f"{el}")

print(f"{el.query.query} : {[x.relevance for x in el.query.scores]} {el.raw}")

# DCG
val = (el.raw[0] + sum([(el.raw[i]/math.log(i+1,2)) for i in range(1,len(el.raw))]) )
val = compute_discounted_cumulative_gain(el.raw)
print(f"DCG: {val}")

# IDEAL DCG
ideal_list = sorted(el.raw,reverse=True)
val_ideal = (ideal_list[0] + sum([(ideal_list[i]/math.log(i+1,2)) for i in range(1,len(ideal_list))]) )
ideal_list = sorted([x.relevance for x in el.query.scores], reverse=True)
val_ideal = compute_discounted_cumulative_gain(ideal_list)

print(f"IDEAL DCG: {val_ideal}")
print(f"NDCG: {val/val_ideal}")

# NATURAL PRECISION
natural_pr = {}
doc_count = 0
tot_rel = sum([1 for elem in el.raw if elem != 0])
for i in range(len(el.raw)):
if el.raw[i] != 0:
doc_count += 1
precision = doc_count/(i+1)
natural_pr[doc_count/tot_rel] = precision
natural_pr = []
tot_rel = sum([x.relevance >= RELEVANCE_THRESHOLD for x in el.query.scores])
for i, entry in enumerate(el.raw):
if entry >= RELEVANCE_THRESHOLD:
precision = (len(natural_pr) + 1)/(i + 1)
natural_pr.append(precision)
print("Natural precision: ")
print(" | ".join([f"{key}:{value}" for key,value in natural_pr.items()]))
print(" | ".join([f"{(i + 1) / tot_rel}:{value}" for i, value in enumerate(natural_pr)]))

# STANDARD PRECISION
precisions = {}
for i in range(len(el.raw)-1,-1,-1):
maxval = max([value for key,value in natural_pr.items() if key >= (i/10)])
precisions[(i+1)/10] = maxval
ord_prec = collections.OrderedDict(sorted(precisions.items()))
precisions = [0.0] * 10
for i in range(10):
maxval = max([value for j, value in enumerate(natural_pr) if (j + 1) / tot_rel >= (i + 1) / 10], default=0)
precisions[i] = maxval
interp_precisions[i] += maxval
print("Standard precision: ")
print(" | ".join([f"{key}:{value}" for key,value in ord_prec.items()]))
avg_prc = sum(list(natural_pr.values()))/tot_rel
print(" | ".join([f"{(i+1)/10}:{value}" for i, value in enumerate(precisions)]))

avg_prc = sum(natural_pr) / tot_rel
print(f"Average non-interpolated precision: {avg_prc}")
avg_int_prc = sum(list(precisions.values()))/10

avg_int_prc = sum(precisions) / 10
avg_precisions.append(avg_int_prc)
print(f"Average interpolated precision: {avg_int_prc}")
print("\n")

mean_avg = sum(avg_precisions)/len(res)
print(f"Mean average precision: {mean_avg}")
print("Average Standard precision: ")
print(" | ".join([f"{(key + 1) / 10}:{value / len(interp_precisions)}" for key, value in enumerate(interp_precisions)]))

else:
print("Unknown action: " + args.action)
Expand Down

0 comments on commit 16a6fb9

Please sign in to comment.