-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove strip from ProceduralDataset::core score_answer() (#250)
* remove strip from ProceduralDataset::core score_answer(), strip in extract answer (optional, default=True) * test: Move test_extract_answer() from test_dataset.py to test_utils.py * refactor: Improve decimal reward computation with more flexible comparison * fix: Implement rounding for format_number when round_if_needed is True * test: Add test case for compute_decimal_reward with sign and zeros
- Loading branch information
1 parent
a66a7e7
commit 24828e1
Showing
6 changed files
with
80 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
|
||
from reasoning_gym.utils import compute_decimal_reward, extract_answer, format_number | ||
|
||
|
||
def test_extract_answer(): | ||
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234" | ||
|
||
# ignore whitespaces | ||
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer", strip=True) == "1234" | ||
|
||
|
||
def test_format_number(): | ||
# Test integers | ||
assert format_number(42) == "42" | ||
assert format_number(42.0) == "42" | ||
|
||
# Test decimals | ||
assert format_number(3.14) == "3.14" | ||
assert format_number(3.10) == "3.1" | ||
assert format_number(3.00) == "3" | ||
|
||
# Test with max_decimals (rounding) | ||
assert format_number(3.14159, max_decimals=4, round_if_needed=True) == "3.1416" | ||
|
||
# Test with trailing zeros | ||
assert format_number(5.5000) == "5.5" | ||
|
||
# Test error cases | ||
with pytest.raises(ValueError): | ||
format_number(3.14159, max_decimals=2) | ||
|
||
|
||
def test_compute_decimal_reward(): | ||
# Test exact matches | ||
assert compute_decimal_reward("42", "42") == 1.0 | ||
assert compute_decimal_reward("3.14", "3.14") == 1.0 | ||
|
||
# Test with commas | ||
assert compute_decimal_reward("1,000", "1000") == 1.0 | ||
assert compute_decimal_reward("1,000", "1000", strip_commas=False) < 1.0 | ||
|
||
# Test with sign, leading zeros, and trailing decimals | ||
assert compute_decimal_reward("+0001,000.00", "1000") == 1.0 | ||
|
||
# Test partial matches | ||
assert compute_decimal_reward("The answer is 42", "42") < 1.0 | ||
assert compute_decimal_reward("The answer is 42", "42") > 0.01 | ||
|
||
# Test invalid answers | ||
assert compute_decimal_reward(None, "42") == 0.0 | ||
assert compute_decimal_reward("", "42") == 0.0 | ||
assert compute_decimal_reward("not a number", "42") == 0.01 |