Skip to content

Commit

Permalink
Return ans_vote
Browse files Browse the repository at this point in the history
  • Loading branch information
tongyx361 committed Sep 17, 2024
1 parent a119c6a commit 7bc436b
Showing 1 changed file with 46 additions and 31 deletions.
77 changes: 46 additions & 31 deletions symeval/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,29 +206,33 @@ def get_maj_answers(self, answers: List[str]) -> List[str]:
"""Get the majority answers."""
maj_answers: List[str] = []

ans_votes: T_Counter[str] = Counter()
ans_vote: T_Counter[str] = Counter()
# Normalize all the answers
for answer in answers:
for exist_ans in ans_votes:
for exist_ans in ans_vote:
correct: bool
try:
correct = self.eq(answer, exist_ans)
except Exception:
correct = False
if correct:
ans_votes[exist_ans] += 1
ans_vote[exist_ans] += 1
break
else:
ans_votes[answer] += 1
maj_ans = self.get_maj_ans_from_votes(ans_votes)
ans_vote[answer] += 1
maj_ans = self.get_maj_ans_from_votes(ans_vote)
maj_answers.append(maj_ans)

return maj_answers

def get_maj_ans_from_votes(self, ans_votes: T_Counter[str]) -> str:
maj_ans = ans_votes.most_common(1)[0][0]
if maj_ans == "" and len(ans_votes) > 1:
maj_ans = ans_votes.most_common(2)[1][0]
def get_maj_ans_from_votes(
self, ans_vote: T_Union[T_Counter[str], dict[str, int]]
) -> str:
if isinstance(ans_vote, dict):
ans_vote = Counter(ans_vote)
maj_ans = ans_vote.most_common(1)[0][0]
if maj_ans == "" and len(ans_vote) > 1:
maj_ans = ans_vote.most_common(2)[1][0]
return maj_ans


Expand Down Expand Up @@ -324,20 +328,25 @@ def batch_eq(

def batch_get_maj_answers(
self, answers_list: List[List[str]], accurate: bool = True
) -> List[List[str]]:
) -> T_Tuple[List[List[str]], List[T_Dict[str, int]]]:
"""Get the majority answers for a batch of answers."""
maj_answers_list: List[List[str]] = []
# Gather all pairs to evaluate
ans_vote_list: List[T_Dict[str, int]] = []
# Gather all unique pairs of answers to evaluate

all_ans_pairs: List[T_Tuple[str, str]] = []

for answers in answers_list:
for i_ans, answer in enumerate(answers):
for j_ans in range(i_ans):
all_ans_pairs.append((answer, answers[j_ans]))
all_ans_pairs.extend(
(answer, answers[j])
for j, answer in enumerate(answers)
if j < len(answers) - 1
)

all_ans_is: List[str]
all_ans_js: List[str]
# Unzip pairs for batch evaluation
all_ans_is, all_ans_js = zip(*all_ans_pairs)

# Evaluate equality of answer pairs
all_eqs: List[bool] = (
self.batch_eq(all_ans_is, all_ans_js)
if accurate
Expand All @@ -347,26 +356,32 @@ def batch_get_maj_answers(
all_pairs2eq: T_Dict[T_Tuple[str, str], bool] = dict(
zip(all_ans_pairs, all_eqs)
)
# Get the majority answers

# Get the majority answers for each set of answers
for answers in answers_list:
maj_answers: List[str] = []
ans_votes: T_Counter[str] = Counter()
for i_ans, answer in enumerate(answers):
exist: bool = False
for j_ans in range(i_ans):
exist_answer: str = answers[j_ans]
exist = all_pairs2eq[(answer, exist_answer)]
if exist:
ans_votes[exist_answer] += 1
break
if not exist:
ans_votes[answer] += 1
ans_vote: T_Counter[str] = Counter()

for answer in answers:
exist_ans = next(
(
exist_answer
for exist_answer in ans_vote
if all_pairs2eq.get((answer, exist_answer), False)
),
None,
)
if exist_ans:
ans_vote[exist_ans] += 1
else:
ans_vote[answer] += 1

maj_answers.append(self.get_maj_ans_from_votes(ans_vote))

maj_ans = self.get_maj_ans_from_votes(ans_votes)
maj_answers.append(maj_ans)
maj_answers_list.append(maj_answers)
ans_vote_list.append(dict(ans_vote))

return maj_answers_list
return maj_answers_list, ans_vote_list


def batch_exec(
Expand Down

0 comments on commit 7bc436b

Please sign in to comment.