-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #65 from zafstojano/env/group-anagrams
Group Anagrams together
- Loading branch information
Showing
7 changed files
with
30,644 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 21, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import json\n", | ||
"from collections import defaultdict" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 22, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"letters = [chr(letter) for letter in range(ord(\"a\"), ord(\"z\") + 1)]\n", | ||
"print(letters)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 23, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"370105\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# The file `words_alpha.txt` has been obtained from https://github.com/dwyl/english-words \n", | ||
"with open(\"./reasoning_gym/data/words_alpha.txt\") as f:\n", | ||
" words = f.read().splitlines()\n", | ||
"print(len(words))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 24, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"30177\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"def group_anagrams(words: list[str]) -> dict[tuple[int], list[str]]:\n", | ||
" \n", | ||
" def _codify(word):\n", | ||
" code = [0] * 26\n", | ||
" for c in word:\n", | ||
" code[ord(c)-ord('a')] += 1\n", | ||
" return tuple(code)\n", | ||
"\n", | ||
" res = defaultdict(list)\n", | ||
"\n", | ||
" for word in words:\n", | ||
" code = _codify(word)\n", | ||
" res[code].append(word)\n", | ||
" return res\n", | ||
"\n", | ||
"anagrams = group_anagrams(words)\n", | ||
"anagrams = {k: v for k, v in anagrams.items() if len(v) > 1} # only keep anagrams with more than 1 word\n", | ||
"print(len(anagrams))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 25, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"with open(\"./reasoning_gym/data/anagrams.jsonl\", \"w\") as f:\n", | ||
" for counts, words in anagrams.items():\n", | ||
" letter_counts = {letter: count for letter, count in zip(letters, counts)}\n", | ||
" f.write(json.dumps({\"letter_counts\": letter_counts, \"words\": words}) + \"\\n\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "reasoning_gym", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
"""Group all anagrams together in a list. | ||
Anagram is a word formed by rearranging the letters of a different word, using all the original letters exactly once. | ||
A popular Leetcode problem: | ||
https://leetcode.com/problems/group-anagrams/description/ | ||
""" | ||
|
||
import json | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
from random import Random | ||
from typing import Dict, Optional | ||
|
||
from ..data import get_data_file_path | ||
from ..factory import ProceduralDataset, register_dataset | ||
|
||
MAX_ANAGRAM_GROUPS = 500 | ||
|
||
QUESTION_TEMPLATE = """An anagram is a word formed by rearranging the letters of a different word, using all the original letters exactly once. | ||
Your job is to group the anagrams together. You can return the answer in any order. | ||
Example: | ||
Input: ["eat", "tea", "tan", "ate", "nat", "bat"] | ||
Output: [["bat"], ["nat", "tan"], ["ate", "eat", "tea"]] | ||
Explanation: | ||
- There is no string in the input that can be rearranged to form "bat". | ||
- The strings "nat" and "tan" are anagrams as they can be rearranged to form each other. | ||
Group the following list of words into anagrams: | ||
{words} | ||
""" | ||
|
||
|
||
@dataclass | ||
class GroupAnagramsConfig: | ||
"""Configuration for Group Anagrams dataset generation""" | ||
|
||
anagram_groups: int = 10 # Groups of anagrams present in the input | ||
max_words_per_group: int = 5 # Maximum number of words in a single anagram group | ||
|
||
size: int = 500 # Virtual dataset size | ||
seed: Optional[int] = None | ||
|
||
def validate(self): | ||
"""Validate configuration parameters""" | ||
assert ( | ||
1 <= self.anagram_groups <= MAX_ANAGRAM_GROUPS | ||
), f"anagram_groups must be between 1 and {MAX_ANAGRAM_GROUPS}" | ||
assert 1 <= self.max_words_per_group, "max_words_per_group must be at least 1" | ||
|
||
|
||
class GroupAnagramsDataset(ProceduralDataset): | ||
"""Generates Group Anagrams exercises with configurable difficulty""" | ||
|
||
def __init__(self, config: GroupAnagramsConfig): | ||
super().__init__(config=config, seed=config.seed, size=config.size) | ||
with get_data_file_path("anagrams.jsonl").open() as f: | ||
self.anagrams = [json.loads(line)["words"] for line in f] | ||
|
||
def __len__(self) -> int: | ||
return self.config.size | ||
|
||
def __iter__(self): | ||
self._current_idx = 0 | ||
return self | ||
|
||
def __next__(self): | ||
if self._current_idx >= self.config.size: | ||
raise StopIteration | ||
item = self[self._current_idx] | ||
self._current_idx += 1 | ||
return item | ||
|
||
def _get_anagram_words(self, rng: Random) -> list[str]: | ||
"""Generate a list of words with anagrams""" | ||
words = [] | ||
for sample in rng.sample(self.anagrams, self.config.anagram_groups): | ||
anagrams = rng.sample(sample, rng.randint(1, min(len(sample), self.config.max_words_per_group))) | ||
words.extend(anagrams) | ||
return words | ||
|
||
def _sort_nested_list(self, lst: list[list[str]]) -> list[list[str]]: | ||
"""Sort a nested list of strings""" | ||
return sorted([sorted(sublist) for sublist in lst], key=lambda x: x[0] if x else "") | ||
|
||
def _group_anagrams(self, words: list[str]) -> list[list[str]]: | ||
"""Group anagrams together""" | ||
|
||
def _codify(word): | ||
code = [0] * 26 | ||
for c in word: | ||
code[ord(c) - ord("a")] += 1 | ||
return tuple(code) | ||
|
||
res = defaultdict(list) | ||
for word in words: | ||
code = _codify(word) | ||
res[code].append(word) | ||
|
||
anagrams = list(res.values()) | ||
return self._sort_nested_list(anagrams) | ||
|
||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: | ||
"""Score a single Group Anagrams question""" | ||
reward = 0 | ||
if answer is not None: | ||
try: | ||
answer = json.loads(answer) | ||
oracle = entry["metadata"]["solution"] | ||
answer_str = json.dumps(self._sort_nested_list(answer)) | ||
oracle_str = json.dumps(self._sort_nested_list(oracle)) | ||
if answer_str == oracle_str: | ||
reward = 1 | ||
else: | ||
reward = 0.01 | ||
except Exception: | ||
reward = 0 | ||
return reward | ||
|
||
def __getitem__(self, idx: int) -> dict: | ||
"""Generate a single Group Anagrams question""" | ||
rng = Random(self.seed + idx) | ||
words = self._get_anagram_words(rng) | ||
answer = self._group_anagrams(words) | ||
answer_str = json.dumps(answer) | ||
|
||
return { | ||
"question": QUESTION_TEMPLATE.format(words=json.dumps(words)), | ||
"answer": answer_str, | ||
"metadata": {"words": words, "solution": answer}, | ||
} | ||
|
||
|
||
register_dataset("group_anagrams", GroupAnagramsDataset, GroupAnagramsConfig) |
Oops, something went wrong.