Skip to content

Commit

Permalink
Return ans_votes
Browse files Browse the repository at this point in the history
  • Loading branch information
tongyx361 committed Sep 17, 2024
1 parent a119c6a commit 4b037ef
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions symeval/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,20 +324,23 @@ def batch_eq(

def batch_get_maj_answers(
self, answers_list: List[List[str]], accurate: bool = True
) -> List[List[str]]:
) -> T_Tuple[List[List[str]], dict[str, int]]:
"""Get the majority answers for a batch of answers."""
maj_answers_list: List[List[str]] = []
# Gather all pairs to evaluate
all_ans_pairs: List[T_Tuple[str, str]] = []

# Gather all unique pairs of answers to evaluate
for answers in answers_list:
for i_ans, answer in enumerate(answers):
for j_ans in range(i_ans):
all_ans_pairs.append((answer, answers[j_ans]))
all_ans_pairs.extend(
(answer, answers[j])
for j, answer in enumerate(answers)
if j < len(answers) - 1
)

all_ans_is: List[str]
all_ans_js: List[str]
# Unzip pairs for batch evaluation
all_ans_is, all_ans_js = zip(*all_ans_pairs)

# Evaluate equality of answer pairs
all_eqs: List[bool] = (
self.batch_eq(all_ans_is, all_ans_js)
if accurate
Expand All @@ -347,26 +350,31 @@ def batch_get_maj_answers(
all_pairs2eq: T_Dict[T_Tuple[str, str], bool] = dict(
zip(all_ans_pairs, all_eqs)
)
# Get the majority answers

# Get the majority answers for each set of answers
for answers in answers_list:
maj_answers: List[str] = []
ans_votes: T_Counter[str] = Counter()
for i_ans, answer in enumerate(answers):
exist: bool = False
for j_ans in range(i_ans):
exist_answer: str = answers[j_ans]
exist = all_pairs2eq[(answer, exist_answer)]
if exist:
ans_votes[exist_answer] += 1
break
if not exist:

for answer in answers:
exist_ans = next(
(
exist_answer
for exist_answer in ans_votes
if all_pairs2eq.get((answer, exist_answer), False)
),
None,
)
if exist_ans:
ans_votes[exist_ans] += 1
else:
ans_votes[answer] += 1

maj_ans = self.get_maj_ans_from_votes(ans_votes)
maj_answers.append(maj_ans)
maj_answers.append(self.get_maj_ans_from_votes(ans_votes))

maj_answers_list.append(maj_answers)

return maj_answers_list
return maj_answers_list, dict(ans_votes)


def batch_exec(
Expand Down

0 comments on commit 4b037ef

Please sign in to comment.