Fix ASR WER and MER metrics #3296

Merged · 3 commits · Jan 28, 2025
46 changes: 23 additions & 23 deletions src/helm/benchmark/metrics/evaluate_reference_metrics.py
@@ -216,10 +216,10 @@ def cider(gold: str, pred: str) -> float:
     return average_score


-def wa_score(gold: str, pred: str) -> float:
-    # Word Accuracy (WA) equals to 1 - word error rate (WER), which is a common
+def wer_score(gold: str, pred: str) -> float:
+    # Word Error Rate (WER) is a common
     # metric used to evaluate the accuracy of speech recognition systems.
-    # Note that this metric could be negative because the WER might be greater than 1.
+    # The lower the better. Note that the WER might be greater than 1.
     # https://huggingface.co/learn/audio-course/en/chapter5/evaluation#word-error-rate
     try:
         from jiwer import wer
@@ -230,13 +230,13 @@ def wa_score(gold: str, pred: str) -> float:
         return 0
     gold = normalize_text(gold, should_remove_articles=False)
     pred = normalize_text(pred, should_remove_articles=False)
-    wer_ret = 1 - wer(gold, pred)
+    wer_ret = wer(gold, pred)
     return wer_ret


-def ma_score(gold: str, pred: str) -> float:
-    # Match Accuracy (MA) equals to 1 - match error rate (MER), which is for evaluating the accuracy of
-    # speech recognition systems.
+def mer_score(gold: str, pred: str) -> float:
+    # Match Error Rate (MER) evaluates the error rate of
+    # speech recognition systems. The lower the better.
     try:
         from jiwer import mer
     except ModuleNotFoundError as e:
@@ -253,7 +253,7 @@ def ma_score(gold: str, pred: str) -> float:

 def wip_score(gold: str, pred: str) -> float:
     # Word information preservation (WIP) for evaluating the preserved information of speech
-    # recognition systems.
+    # recognition systems. The higher the better.
     try:
         from jiwer import wip
     except ModuleNotFoundError as e:
@@ -268,9 +268,9 @@ def wip_score(gold: str, pred: str) -> float:
     return wip_ret


-def ca_score(gold: str, pred: str) -> float:
-    # Character accuracy (CA) equals to character error rate (CER) for evaluating the accuracy
-    # of speech recognition systems.
+def cer_score(gold: str, pred: str) -> float:
+    # Character Error Rate (CER) evaluates the accuracy
+    # of speech recognition systems. The lower the better.
     try:
         from jiwer import cer
     except ModuleNotFoundError as e:
@@ -285,22 +285,22 @@ def ca_score(gold: str, pred: str) -> float:
     return cer_ret


-def chinese_wa_score(gold: str, pred: str) -> float:
+def chinese_wer_score(gold: str, pred: str) -> float:
     try:
         import jieba
     except ModuleNotFoundError as e:
         handle_module_not_found_error(e, ["audiolm"])

-    return wa_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred)))
+    return wer_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred)))


-def chinese_ma_score(gold: str, pred: str) -> float:
+def chinese_mer_score(gold: str, pred: str) -> float:
     try:
         import jieba
     except ModuleNotFoundError as e:
         handle_module_not_found_error(e, ["audiolm"])

-    return ma_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred)))
+    return mer_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred)))


 def chinese_wip_score(gold: str, pred: str) -> float:
@@ -312,13 +312,13 @@ def chinese_wip_score(gold: str, pred: str) -> float:
     return wip_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred)))


-def chinese_ca_score(gold: str, pred: str) -> float:
+def chinese_cer_score(gold: str, pred: str) -> float:
     try:
         import jieba
     except ModuleNotFoundError as e:
         handle_module_not_found_error(e, ["audiolm"])

-    return ca_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred)))
+    return cer_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred)))


 def extract_set_from_text(
@@ -471,14 +471,14 @@ def compute_metrics_helper(
         "chinese_rouge_2": get_chinese_rouge_function("rouge2"),
         "cleva_math_result_match": cleva_math_result_match,
         "absolute_value_difference": absolute_value_difference,
-        "wa_score": wa_score,
-        "ma_score": ma_score,
+        "wer_score": wer_score,
+        "mer_score": mer_score,
         "wip_score": wip_score,
-        "ca_score": ca_score,
-        "chinese_wa_score": chinese_wa_score,
-        "chinese_ma_score": chinese_ma_score,
+        "cer_score": cer_score,
+        "chinese_wer_score": chinese_wer_score,
+        "chinese_mer_score": chinese_mer_score,
         "chinese_wip_score": chinese_wip_score,
-        "chinese_ca_score": chinese_ca_score,
+        "chinese_cer_score": chinese_cer_score,
     }

     stats: List[Stat] = []
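For context on the renamed metrics: wer_score, mer_score, wip_score, and cer_score all delegate to jiwer. Below is a minimal sketch (not part of the PR) of how those jiwer calls behave on a made-up gold/pred pair, assuming jiwer is installed via pip install jiwer; unlike the HELM functions above, it skips the normalize_text step.

# Minimal sketch of the four jiwer metrics used above; the strings are hypothetical.
from jiwer import cer, mer, wer, wip

gold = "the quick brown fox"  # hypothetical reference transcript
pred = "the quik brown fox"   # hypothetical ASR hypothesis with one misrecognized word

print(wer(gold, pred))  # 0.25: 1 substitution / 4 reference words; lower is better
print(mer(gold, pred))  # 0.25: errors / (hits + errors); lower is better
print(wip(gold, pred))  # 0.5625: (hits / ref words) * (hits / hyp words); higher is better
print(cer(gold, pred))  # ~0.05: 1 character edit / 19 reference characters; lower is better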
2 changes: 1 addition & 1 deletion src/helm/benchmark/presentation/run_entries_speech.conf
@@ -51,4 +51,4 @@ entries: [
     {description: "common_voice_15:language=German,model=audiolm", priority: 1}
     {description: "common_voice_15:language=French,model=audiolm", priority: 1}

-]
\ No newline at end of file
+]
4 changes: 2 additions & 2 deletions src/helm/benchmark/run_specs/audio_run_specs.py
@@ -78,7 +78,7 @@ def get_machine_translation_metric_specs() -> List[MetricSpec]:


 def _get_audio_recognition_metric_specs() -> List[MetricSpec]:
-    return get_basic_metric_specs(["wa_score", "ma_score", "wip_score", "ca_score"])
+    return get_basic_metric_specs(["wer_score", "mer_score", "wip_score", "cer_score"])


 def _get_open_ended_generation_metric_specs() -> List[MetricSpec]:
@@ -88,7 +88,7 @@ def _get_open_ended_generation_metric_specs() -> List[MetricSpec]:


 def _get_chinese_audio_recognition_metric_specs() -> List[MetricSpec]:
-    return get_basic_metric_specs(["chinese_wa_score", "chinese_ma_score", "chinese_wip_score", "chinese_ca_score"])
+    return get_basic_metric_specs(["chinese_wer_score", "chinese_mer_score", "chinese_wip_score", "chinese_cer_score"])


 ########################################################################################################################
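The chinese_*_score variants wired up here (defined in evaluate_reference_metrics.py above) segment text with jieba before scoring, because jiwer's word-level metrics expect whitespace-delimited tokens and Chinese is written without spaces. A minimal sketch under the same assumptions (jieba and jiwer installed; the example strings are made up):

# Minimal sketch of the Chinese pipeline: segment with jieba, then score with jiwer.
import jieba
from jiwer import wer

gold = "今天天气很好"  # hypothetical reference transcript
pred = "今天天气不好"  # hypothetical ASR hypothesis

# jieba.cut returns a generator of tokens; join on spaces so jiwer can split words.
gold_seg = " ".join(jieba.cut(gold))  # e.g. "今天 天气 很 好" (segmentation may vary)
pred_seg = " ".join(jieba.cut(pred))  # e.g. "今天 天气 不好"

print(wer(gold_seg, pred_seg))  # word-level WER over jieba tokens; lower is better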
52 changes: 44 additions & 8 deletions src/helm/benchmark/static/schema_speech.yaml
@@ -124,16 +124,52 @@ metrics:
     lower_is_better: false

   # Speech Recognition metrics
-  - name: wa_score
-    display_name: Word Accuracy Score
-    short_display_name: Word Accuracy Score
-    description: Word Accuracy based on the Word Error Rate.
+  - name: wer_score
+    display_name: Word Error Rate
+    short_display_name: WER
+    description: Word error rate between model predictions and ground truth answers for ASR tasks.
+    lower_is_better: true
+
+  - name: mer_score
+    display_name: Match Error Rate
+    short_display_name: MER
+    description: Word match error rate between model predictions and ground truth answers.
+    lower_is_better: true
+
+  - name: wip_score
+    display_name: Word Information Preservation
+    short_display_name: WIP
+    description: Word information preservation (WIP) for evaluating the preserved information of ASR.
+    lower_is_better: false
+
+  - name: cer_score
+    display_name: Character Error Rate
+    short_display_name: CER
+    description: Character error rate (CER) for evaluating the accuracy of ASR.
+    lower_is_better: true
+
+  - name: chinese_wer_score
+    display_name: Chinese Word Error Rate
+    short_display_name: Chinese WER
+    description: Chinese word error rate between model predictions and ground truth answers for ASR tasks.
+    lower_is_better: true
+
+  - name: chinese_mer_score
+    display_name: Chinese Match Error Rate
+    short_display_name: Chinese MER
+    description: Chinese word match error rate between model predictions and ground truth answers.
+    lower_is_better: true
+
+  - name: chinese_wip_score
+    display_name: Chinese Word Information Preservation
+    short_display_name: Chinese WIP
+    description: Chinese word information preservation (WIP) for evaluating the preserved information of ASR.
+    lower_is_better: false

-  - name: CA
-    display_name: WA
-    short_display_name: WA
-    description: Word Accuracy based on the Word Error Rate.
+  - name: chinese_cer_score
+    display_name: Chinese Character Error Rate
+    short_display_name: Chinese CER
+    description: Chinese character error rate (CER) for evaluating the accuracy of Chinese ASR.
+    lower_is_better: true

 ############################################################
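A note on the lower_is_better: true entries and the comment that the WER "might be greater than 1": error rates are normalized by the reference length, so a hypothesis with many insertions can push WER past 1.0, which is why the old 1 - wer(...) accuracy formulation could go negative. A quick illustration, again assuming jiwer is installed and using made-up strings:

# WER exceeds 1 when the hypothesis has more errors than the reference has words.
from jiwer import wer

print(wer("hi", "oh hi there everyone"))      # 3 insertions / 1 reference word = 3.0
print(1 - wer("hi", "oh hi there everyone"))  # -2.0: the old wa_score would go negative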