From dfeaac4d178cecead70f4fe63ebbca0126dc715a Mon Sep 17 00:00:00 2001 From: James Aung <129281094+james-aung@users.noreply.github.com> Date: Tue, 19 Mar 2024 07:24:57 -0700 Subject: [PATCH] Add Function Deduction eval (#1492) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Thank you for contributing an eval! ♥️ 🚨 Please make sure your PR follows these guidelines, **failure to follow the guidelines below will result in the PR being closed automatically**. Note that even if the criteria are met, that does not guarantee the PR will be merged nor GPT-4 access be granted. 🚨 **PLEASE READ THIS**: In order for a PR to be merged, it must fail on GPT-4. We are aware that right now, users do not have access, so you will not be able to tell if the eval fails or not. Please run your eval with GPT-3.5-Turbo, but keep in mind as we run the eval, if GPT-4 gets higher than 90% on the eval, we will likely reject it since GPT-4 is already capable of completing the task. We plan to roll out a way for users submitting evals to see the eval performance on GPT-4 soon. Stay tuned! Until then, you will not be able to see the eval performance on GPT-4. **Starting April 10, the minimum eval count is 15 samples, we hope this makes it easier to create and contribute evals.** Also, please note that we're using **Git LFS** for storing the JSON files, so please make sure that you move the JSON file to Git LFS before submitting a PR. Details on how to use Git LFS are available [here](https://git-lfs.com). ## Eval details 📑 ### Eval name Function Deduction ### Eval description We evaluate whether models can effectively employ the scientific method to iterate upon hypotheses until determining one that is correct. In particular, the model attempts to deduce a black-box mathematical function that connects (input, output) it selects in order to gain information. To score highly, the model must ultimately determine the correct result for target inputs, balancing between information-gain and attempting guesses. ### What makes this a useful eval? AI R&D ## Criteria for a good eval ✅ Below are some of the criteria we look for in a good eval. In general, we are seeking cases where the model does not do a good job despite being capable of generating a good response (note that there are some things large language models cannot do, so those would not make good evals). Your eval should be: - [x] Thematically consistent: The eval should be thematically consistent. We'd like to see a number of prompts all demonstrating some particular failure mode. For example, we can create an eval on cases where the model fails to reason about the physical world. - [x] Contains failures where a human can do the task, but either GPT-4 or GPT-3.5-Turbo could not. - [x] Includes good signal around what is the right behavior. This means either a correct answer for `Basic` evals or the `Fact` Model-graded eval, or an exhaustive rubric for evaluating answers for the `Criteria` Model-graded eval. - [x] **Include at least 15 high-quality examples.** If there is anything else that makes your eval worth including, please document it below. ### Unique eval value > Insert what makes your eval high quality that was not mentioned above. (Not required) ## Eval structure 🏗️ Your eval should - [x] Check that your data is in `evals/registry/data/{name}` - [x] Check that your YAML is registered at `evals/registry/evals/{name}.yaml` - [x] Ensure you have the right to use the data you submit via this eval (For now, we will only be approving evals that use one of the existing eval classes. You may still write custom eval classes for your own cases, and we may consider merging them in the future.) ## Final checklist 👀 ### Submission agreement By contributing to Evals, you are agreeing to make your evaluation logic and data under the same MIT license as this repository. You must have adequate rights to upload any data used in an Eval. OpenAI reserves the right to use this data in future service improvements to our product. Contributions to OpenAI Evals will be subject to our usual Usage Policies (). - [x] I agree that my submission will be made available under an MIT license and complies with OpenAI's usage policies. ### Email address validation If your submission is accepted, we will be granting GPT-4 access to a limited number of contributors. Access will be given to the email address associated with the commits on the merged pull request. - [x] I acknowledge that GPT-4 access will only be granted, if applicable, to the email address used for my merged pull request. ### Limited availability acknowledgment We know that you might be excited to contribute to OpenAI's mission, help improve our models, and gain access to GPT-4. However, due to the requirements mentioned above and the high volume of submissions, we will not be able to accept all submissions and thus not grant everyone who opens a PR GPT-4 access. We know this is disappointing, but we hope to set the right expectation before you open this PR. - [x] I understand that opening a PR, even if it meets the requirements above, does not guarantee the PR will be merged nor GPT-4 access be granted. ### Submit eval - [x] I have filled out all required fields of this form - [x] I have used **Git LFS** for the Eval JSON data - [x] (Ignore if not submitting code) I have run `pip install pre-commit; pre-commit install` and have verified that `mypy`, `black`, `isort`, `autoflake` and `ruff` are running when I commit and push Failure to fill out all required fields will result in the PR being closed. ### Eval JSON data Since we are using Git LFS, we are asking eval submitters to add in as many Eval Samples (at least 5) from their contribution here:
View evals in JSON ### Eval ``` # Examples of functions to guess math.floor(x + math.sqrt(x)) math.floor(math.sqrt(x)) math.floor(math.sqrt(x)) - 1 math.floor(math.sqrt(x)) * 2 math.floor(math.sqrt(x) * 2) math.floor(round(x ** (1/3), 8)) x / 2 if not x % 2 else x * 3 x / 2 if not x % 2 else x * 3 + 1 x ** 2 if x % 2 else x ** 3 x / 3 if not x % 3 else x x / 3 if not x % 3 else x * 2 (x + 1) / 3 if x % 3 == 2 else x ```
Co-authored-by: johny-b <33967107+johny-b@users.noreply.github.com> --- evals/elsuite/function_deduction/README.md | 91 ++++++ evals/elsuite/function_deduction/baselines.py | 133 ++++++++ evals/elsuite/function_deduction/eval.py | 302 ++++++++++++++++++ evals/elsuite/function_deduction/prompts.py | 43 +++ .../scripts/dataset/create_dataset.py | 62 ++++ .../scripts/dataset/raw_code.txt | 141 ++++++++ .../function_deduction/scripts/make_plots.py | 256 +++++++++++++++ .../scripts/run_experiments.sh | 27 ++ evals/elsuite/function_deduction/solvers.py | 173 ++++++++++ .../function_deduction/solvers_test.py | 149 +++++++++ .../data/function_deduction/data.jsonl | 3 + evals/registry/evals/function-deduction.yaml | 37 +++ .../registry/solvers/function_deduction.yaml | 192 +++++++++++ 13 files changed, 1609 insertions(+) create mode 100644 evals/elsuite/function_deduction/README.md create mode 100644 evals/elsuite/function_deduction/baselines.py create mode 100644 evals/elsuite/function_deduction/eval.py create mode 100644 evals/elsuite/function_deduction/prompts.py create mode 100644 evals/elsuite/function_deduction/scripts/dataset/create_dataset.py create mode 100644 evals/elsuite/function_deduction/scripts/dataset/raw_code.txt create mode 100644 evals/elsuite/function_deduction/scripts/make_plots.py create mode 100755 evals/elsuite/function_deduction/scripts/run_experiments.sh create mode 100644 evals/elsuite/function_deduction/solvers.py create mode 100644 evals/elsuite/function_deduction/solvers_test.py create mode 100644 evals/registry/data/function_deduction/data.jsonl create mode 100644 evals/registry/evals/function-deduction.yaml create mode 100644 evals/registry/solvers/function_deduction.yaml diff --git a/evals/elsuite/function_deduction/README.md b/evals/elsuite/function_deduction/README.md new file mode 100644 index 0000000000..924b4e47fb --- /dev/null +++ b/evals/elsuite/function_deduction/README.md @@ -0,0 +1,91 @@ +# Function Deduction + +This eval evaluates how well a model can refine a hypothesis according to new evidence and how well it chooses to gather new information. + +In Function Deduction: + +- There is a secret mathematical function that maps an integer to another integer. +- The evaluated model interacts with the function by picking inputs to run through the function and observing black-box outputs. +- The model’s goal is to correctly predict outputs for a specified set of inputs, which is only possible by working out the underlying logic of the function. + +![fd](https://github.com/openai/policy-research-evals/assets/129281094/6c41be74-8237-4bb3-b0fc-13454c20389c) + +## Usage + +Run with: + +``` +oaieval function_deduction +``` + +We suggest using `function_deduction/cot/gpt-4-32k` or `function_deduction/cot/gpt-3.5-turbo-16k` as default choices for `` + +See `evals/registry/solvers/function_deduction.yaml` for a full list of recommended Solvers. + +For more examples of running this eval, see `scripts/run_experiments.sh` + +## Dataset + +The dataset are 100 int -> int functions of varying difficulty. Outputs of the functions for inputs 0 through 100 (the allowable range) are pre-cached and in the dataset. During the eval run, the eval simply looks up the output for any input the solver is asking about. + +## Evaluation Process + +Samples from the dataset are evaluated one-by-one. + +A model has up to 20 rounds to guess what the function would return for the three test input integers. + +In each round the model can either: + +- Ask what the function returns for a given integer (other than the test inputs) *or* +- Make a guess for the values that would be returned by the function for the test inputs. + +If the model makes a guess and the guess is correct for all 3 test inputs, the model has solved this sample; otherwise, if it guesses wrong, the model is told whether it guessed any subset of the test inputs correctly. + +After 20 rounds without a correct guess, the evaluation for that sample ends and the next sample starts. Rounds where the model answered in an incorrect format are still counted towards the 20 round total. + +## Task State + +| Key | Value | +| --- | --- | +| `task_description` | I have an important task for you. Instructions below.

I have a deterministic function `def calc(x: int) -> int:` that accepts an integer in range [0, 100] as an argument and returns an integer.

I want you to correctly determine the result of this function for each of the following inputs: {inputs}

At each turn you must select one of the following 2 options:
Option (A): Gather useful information. Ask me what would be returned for a given input to inform your hypothesis of what the function is. \
This must be an integer in range [0, 100], other than {inputs}.
Option (B): State your guess for the values returned by the function for inputs {inputs}. Your guess must include all three numbers, in this order. \
If your guess is correct, you win. If not, the task will continue.

You have {n_rounds} turns to determine the correct answer, else you lose.

Your goal is to determine the correct answer in as few turns as possible, so choose your actions carefully!
Your final score will be:
* If you make the correct guess: {n_rounds} - number_of_rounds_played
* If you fail to make the correct guess: -{n_rounds}
You should try to score as high as possible.

You should respond either with a single integer in range 0-100 if you decide to take option A or with three integers if you make the final guess (i.e. use option B).
Don't say anything more except for either an integer or 3 integers. | +| `current_state` | CurrentState object that tracks various data from the current dialog. | + +## Metrics + +The below are the key metrics of this eval: + +| Metric | Interpretation | +| --- | --- | +| `adjusted_avg_score` | Combination metric of the below 2 metrics. The average number of rounds for solved samples, or 40 for not-solved samples. (lower is better) | +| `solved_ratio` | The percentage of solved samples (higher is better) | +| `avg_success_rounds` | The average number of rounds for solved samples (lower is better) | + +## Variants + +| Variant | Notes | +| --- | --- | +| Default: `function_deduction.easy` | Default setting as described above. 1 trial per sample | +| `function_deduction.easy.long` | 10 trials per sample | +| `function_deduction.easy.dev5` | Dev set with only 5 samples | +| `function_deduction.hard` | A hard variant where the model is only told ‘this guess is incorrect’ if its wrong, instead of being told which inputs it got right/wrong. | +| `function_deduction.hard.dev5` | Dev set with only 5 samples | + +## Token Usage Estimates + +Below is a rough estimate of the total number of tokens consumed by the default variant: + +| Solver | Tokens | +| --- | --- | +| function_deduction/gpt-4-base | 3 840 000 | +| gpt-4-32k | 880 000 | +| gpt-3.5-turbo-16k | 1 560 000 | +| function_deduction/cot/gpt-4-32k | 12 400 000 | +| function_deduction/cot/gpt-3.5-turbo-16k | 13 230 000 | + +## Version History + +- v0: Initial version released + +## Contribution statement + +Eval design, implementation, and results evaluation were primarily conducted by Jan Betley with contributions from Andrei Alexandru. Report by James Aung. Work done under the guidance of (alphabetically by last-name) Steven Adler, and Chan Jun Shern, who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. diff --git a/evals/elsuite/function_deduction/baselines.py b/evals/elsuite/function_deduction/baselines.py new file mode 100644 index 0000000000..3a81624e03 --- /dev/null +++ b/evals/elsuite/function_deduction/baselines.py @@ -0,0 +1,133 @@ +import logging +import math +from collections import Counter +from pathlib import Path + +import numpy as np +from scipy.stats import entropy + +from evals.data import get_jsonl +from evals.elsuite.function_deduction.eval import CurrentState, Sample +from evals.registry import Registry +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import TaskState + + +class AverageBaseline(Solver): + """ + For given test inputs (x, y, z): + * Ask about values of (x-1, x+1, y-1, y+1, z-1, z+1) + * Make three guesses: round/floor/ceil of average values for neighboring numbers + If didn't succeed in 9 rounds (6x ask 3x guess) - surrender. + + Note: This algorithm fails on the edge cases where, for any of the inputs i: + - i-1 or i+1 is out of range + - i-1 or i+1 are part of the test inputs + In this scenario, the algorithm will fail at the _get_guess stage and skip the guessing. + """ + + def __init__(self, registry=None): + pass + + def _solve(self, task_state: TaskState): + cs: CurrentState = task_state.current_state + + assert len(cs.test_inputs) == 3, "AverageBaseline assumes 3 test inputs" + + if cs.round_ix < 6: + response = self._get_ask(cs.test_inputs, cs.round_ix) + elif 6 <= cs.round_ix < 9: + response = self._get_guess(cs.test_inputs, cs.known_values, cs.round_ix - 6) + else: + response = "I've run out of ideas sorry :(" + return SolverResult(response) + + def _get_guess(self, test_inputs, known_values: dict[int, int], guess_round_ix) -> str: + known_values = { + x: y for x, y in known_values.items() if x - 1 in test_inputs or x + 1 in test_inputs + } + + pairs = [[], [], []] + for i, test_input in enumerate(test_inputs): + try: + lower = known_values[test_input - 1] + higher = known_values[test_input + 1] + except KeyError: + return "Unfortunately I don't have enough data to make a guess, will pass." + pairs[i] = [lower, higher] + + funcs = [round, math.floor, math.ceil] + func = funcs[guess_round_ix] + vals = [func((pair[0] + pair[1]) / 2) for pair in pairs] + return " ".join([str(x) for x in vals]) + + def _get_ask(self, test_inputs, round_ix) -> str: + queries = [] + for x in test_inputs: + queries.append(x - 1) + queries.append(x + 1) + + ask = queries[round_ix] + if ask in test_inputs or ask < 0 or ask > 100: + logging.warning( + f"Invalid query on inputs {test_inputs}: {ask}. AverageBaseline algorithm will fail." + ) + return str(ask) + + +class FullKnowledge(Solver): + """Assuming solver knows all the samples, how well would it perform? + + Two modes - "random", where it selects random integer when asking, + and "best" where it selects the best integer. + + The "best" mode should be close to unbeatable (except for lucky guesses). + """ + + def __init__(self, mode: str, samples_jsonl: str, registry: Registry): + assert mode in ("random", "best"), "mode must be either random or best" + self.mode = mode + self._all_samples = self._get_samples(samples_jsonl, registry._registry_paths[0]) + self._rng = np.random.default_rng() + + def _solve(self, task_state: TaskState): + cs: CurrentState = task_state.current_state + + matching_samples = self._get_matching_samples(cs.known_values) + if len(matching_samples) > 1: + if self.mode == "random": + response = self._get_ask_random(cs.known_values) + else: + response = self._get_ask_best(matching_samples) + else: + sample_values = matching_samples[0].values + result = [sample_values[test_input] for test_input in cs.test_inputs] + response = " ".join([str(x) for x in result]) + return SolverResult(str(response)) + + def _get_matching_samples(self, known_values): + def matches(sample: Sample) -> bool: + for key, val in known_values.items(): + if sample.values[key] != val: + return False + return True + + return [sample for sample in self._all_samples if matches(sample)] + + def _get_ask_best(self, samples): + def get_entropy(x: int) -> float: + values = [sample.values[x] for sample in samples] + counter = Counter(values) + return entropy([val for val in counter.values()]) + + return max(range(0, 101), key=get_entropy) + + def _get_ask_random(self, known_values): + while True: + x = self._rng.integers(0, 100) + if x not in known_values: + return x + + def _get_samples(self, samples_jsonl: str, registry_path: Path): + path = registry_path / "data" / samples_jsonl + return [Sample(**x) for x in get_jsonl(path.as_posix())] diff --git a/evals/elsuite/function_deduction/eval.py b/evals/elsuite/function_deduction/eval.py new file mode 100644 index 0000000000..6542852153 --- /dev/null +++ b/evals/elsuite/function_deduction/eval.py @@ -0,0 +1,302 @@ +import logging +import random +import re +from dataclasses import dataclass, field +from typing import List, Literal, Optional, Tuple, Union + +import numpy as np +import scipy + +import evals +from evals.api import CompletionFn +from evals.elsuite.function_deduction import prompts +from evals.eval import SolverEval +from evals.solvers.solver import Solver +from evals.task_state import Message, TaskState + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class Sample: + sample_ix: int + code: str + complexity: int + range: List[int] + values: List[int] + + +@dataclass +class CurrentState: + """This class tracks all the information from the dialog with the model. + + Some things are tracked to make writing solvers easier. + Other are tracked for metrics. + """ + + n_rounds: int + mode: str + test_inputs: tuple[int, int, int] + success: bool = False + known_values: dict[int, int] = field(default_factory=dict) + negative_known_values: dict[int, int] = field(default_factory=dict) + ask_rounds: int = 0 + guess_rounds: int = 0 + incorrect_format_rounds: int = 0 + parsed_responses: list[tuple[int]] = field(default_factory=list) + + @property + def round_ix(self): + return self.ask_rounds + self.guess_rounds + self.incorrect_format_rounds + + def ask_update(self, input_: int, value: Optional[int]) -> None: + self.ask_rounds += 1 + self.parsed_responses.append((input_,)) + if value is not None: + self.known_values[input_] = value + + def guess_update( + self, guessed_ints: tuple[int, int, int], expected_ints: tuple[int, int, int] + ) -> None: + self.guess_rounds += 1 + self.parsed_responses.append(guessed_ints) + if guessed_ints == expected_ints: + self.success = True + + if self.mode == "easy": + for test, guess, correct in zip(self.test_inputs, guessed_ints, expected_ints): + if guess == correct: + self.known_values[test] = guess + else: + self.negative_known_values[test] = guess + + +class FunctionDeductionEval(SolverEval): + def __init__( + self, + completion_fns: list[CompletionFn], + mode: Literal["easy", "hard"], + n_rounds: int, + n_samples: Optional[int] = None, + n_repeat: int = 3, + failed_sample_rounds: Optional[int] = None, + seed: Optional[int] = None, + samples_jsonl: str = "function_deduction/data.jsonl", + *args, + **kwargs, + ): + super().__init__(completion_fns, seed=seed, samples_jsonl=samples_jsonl, *args, **kwargs) + + self.mode = mode + self.n_rounds = n_rounds + self.n_samples = n_samples + self.n_repeat = n_repeat + + # This is used for the main metric - "how many rounds for a sample that was not solved?" + self.failed_sample_rounds = ( + failed_sample_rounds if failed_sample_rounds is not None else n_rounds * 2 + ) + + def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random): + test_inputs = rng.sample(range(101), 3) + values = sample.values + expected = tuple(sample.values[test_input] for test_input in test_inputs) + + cs = CurrentState(self.n_rounds, self.mode, test_inputs) + task_state = TaskState( + prompts.task_description.format(inputs=test_inputs, n_rounds=self.n_rounds), + current_state=cs, + ) + + for round_ix in range(self.n_rounds): + raw_response = solver(task_state).output + try: + ints = self._parse_raw_response(raw_response) + except ValueError: + cs.incorrect_format_rounds += 1 + answer = prompts.incorrect_format + else: + if len(ints) == 1: + ask = ints[0] + result = values[ask] if ask not in test_inputs else None + cs.ask_update(ask, result) + if result is None: + answer = prompts.test_input_not_allowed.format(inputs=test_inputs) + else: + answer = prompts.new_value.format(in_=ask, out=result) + else: + cs.guess_update(ints, expected) + if cs.success: + break + else: + answer = self._bad_guess_answer(test_inputs, ints, expected) + + task_state.messages += [ + Message("assistant", raw_response), + Message("system", answer), + ] + + evals.record.record_metrics( + sample_ix=sample.sample_ix, + success=cs.success, + num_rounds=cs.round_ix if cs.success else None, + ask_rounds=cs.ask_rounds, + guess_rounds=cs.guess_rounds, + incorrect_format_rounds=cs.incorrect_format_rounds, + repeated_rounds=len(cs.parsed_responses) - len(set(cs.parsed_responses)), + code="lambda x: " + sample.code, + complexity=sample.complexity, + ) + + def run(self, recorder: evals.record.Recorder): + samples = self.get_samples() + + # Add copies according to self.n_repeat + # NOTE: we have copies next to each other -> more convenient when reading in logviz + copied_samples = [sample for sample in samples for _ in range(self.n_repeat)] + logger.info( + f"{len(samples)} unique samples, {self.n_repeat} attempts for each sample, {len(copied_samples)} total samples" + ) + self.eval_all_samples(recorder, copied_samples) + metrics = recorder.get_metrics() + + adjusted_rounds = [x["num_rounds"] or self.failed_sample_rounds for x in metrics] + main_metric = sum(adjusted_rounds) / len(metrics) + result = { + "adjusted_avg_score": main_metric, + "sem_adjusted_avg_score": self._calculate_sem(adjusted_rounds), + } + + result.update(self._get_success_metrics(metrics)) + result.update(self._get_sample_std(metrics)) + for name in ("ask_rounds", "guess_rounds", "incorrect_format_rounds"): + result[f"avg_{name}"] = sum(x[name] for x in metrics) / len(metrics) + result[f"sem_avg_{name}"] = self._calculate_sem([x[name] for x in metrics]) + result.update(self._get_complexity_tests(metrics)) + result.update(self._get_per_complexity_metrics(metrics)) + + return result + + def _calculate_sem(self, values: list) -> float: + return np.std(values) / np.sqrt(len(values)) + + def _get_success_metrics(self, metrics): + success = [x for x in metrics if x["success"]] + return { + "solved_ratio": round(len(success) / len(metrics), 2), + "sem_solved_ratio": self._calculate_sem([x["success"] for x in metrics]), + "solved": len(success), + "samples": len(metrics), + "avg_success_rounds": round(sum(x["num_rounds"] for x in success) / len(success), 2) + if success + else None, + "sem_avg_success_rounds": self._calculate_sem([x["num_rounds"] for x in success]) + if success + else None, + } + + def _get_sample_std(self, metrics): + adjusted = [] + no_failed = [] + solved_ratio_if_any_solved = [] + sample_ixs = set(metric["sample_ix"] for metric in metrics) + for sample_ix in sample_ixs: + sample_metrics = [metric for metric in metrics if metric["sample_ix"] == sample_ix] + sample_adjusted = [ + metric["num_rounds"] or self.failed_sample_rounds for metric in sample_metrics + ] + sample_no_failed = [ + metric["num_rounds"] for metric in sample_metrics if metric["success"] + ] + solved_ratio = sum(1 for metric in sample_metrics if metric["success"]) / len( + sample_metrics + ) + + if len(sample_adjusted) > 1: + adjusted.append(np.std(sample_adjusted)) + if len(sample_no_failed) > 1: + no_failed.append(np.std(sample_no_failed)) + if solved_ratio: + solved_ratio_if_any_solved.append(solved_ratio) + + return { + "avg_sample_rounds_std_adjusted": sum(adjusted) / len(adjusted) if adjusted else None, + "avg_sample_rounds_std_no_failed": sum(no_failed) / len(no_failed) + if no_failed + else None, + # This is just solved_ratio but excluding samples that had no succesful attempt. + # So 1 is full stability (i.e. if sample was solved once, it will be solved always), + # and (1/self.n_repeat) is "no sample was solved more than once" + "solved_ratio_if_any_solved": sum(solved_ratio_if_any_solved) + / len(solved_ratio_if_any_solved) + if solved_ratio_if_any_solved + else None, + } + + def _get_complexity_tests(self, metrics): + solved = [x["complexity"] for x in metrics if x["success"]] + not_solved = [x["complexity"] for x in metrics if not x["success"]] + result = { + "solved_avg_complexity": sum(solved) / len(solved) if solved else None, + "not_solved_avg_complexity": sum(not_solved) / len(not_solved) if not_solved else None, + } + + # This tests if solved have lower complexity than non-solved + if solved and not_solved: + _, p_value = scipy.stats.mannwhitneyu(solved, not_solved, alternative="less") + else: + p_value = None + result["solved_or_not_mann_whitney_u_p_value"] = p_value + + # TODO: add more complexity-related metrics, such as correlation or linear regression coefficient. + # Leaving this for the future because we might want to change how the complexity is calculated, + # or generally improve the concept somehow. + + return result + + def _get_per_complexity_metrics(self, all_metrics): + complexity_values = sorted(x["complexity"] for x in all_metrics) + result = {} + for complexity in complexity_values: + metrics = [x for x in all_metrics if x["complexity"] == complexity] + result[f"complexity_{complexity}"] = self._get_success_metrics(metrics) + return result + + def _parse_raw_response(self, response: str) -> Union[Tuple[int], Tuple[int, int, int]]: + # Remove all non-numbers first. This way we accept also e.g. "1, 2, 3", "[1, 2, 3]", '"1", "2", "3"' etc. + response = re.sub(r"[^0-9\s-]", "", response) + + vals = tuple(int(x) for x in response.split()) + if len(vals) not in (1, 3): + raise ValueError("Expected 1 or 3 integers") + if len(vals) == 1 and not 0 <= vals[0] <= 100: + raise ValueError("Single int should be between 0 and 100") + return vals + + def _bad_guess_answer(self, test_inputs, guessed, expected) -> str: + correct = [test_inputs[i] for i in range(0, 3) if guessed[i] == expected[i]] + incorrect = [x for x in test_inputs if x not in correct] + assert incorrect, "This is not a bad answer" + + if self.mode == "hard": + return "This is not the correct answer. At least one of the values is wrong." + elif self.mode == "easy": + if len(correct) == 0: + return "All numbers are wrong." + elif len(correct) == 1: + return f"Your guess is correct for {correct[0]} and incorrect for {incorrect[0]} and {incorrect[1]}" + elif len(correct) == 2: + return f"Your guess is correct for {correct[0]} and {correct[1]} and incorrect for {incorrect[0]}" + else: + raise ValueError(f"Unknown mode {self.mode}") + + def get_samples(self) -> List[Sample]: + samples = super().get_samples() + + if self.n_samples is not None: + assert ( + len(samples) >= self.n_samples + ), f"Can't get {self.n_samples} samples from a dataset with {len(samples)} samples" + np.random.default_rng(seed=self.seed).shuffle(samples) + samples = samples[: self.n_samples] + return [Sample(**sample_dict) for sample_dict in samples] diff --git a/evals/elsuite/function_deduction/prompts.py b/evals/elsuite/function_deduction/prompts.py new file mode 100644 index 0000000000..948e14079f --- /dev/null +++ b/evals/elsuite/function_deduction/prompts.py @@ -0,0 +1,43 @@ +task_description = """\ +I have an important task for you. Instructions below. + +I have a deterministic function `def calc(x: int) -> int:` that accepts an integer in range [0, 100] as an argument and returns an integer. + +I want you to correctly determine the result of this function for each of the following inputs: {inputs} + +At each turn you must select one of the following 2 options: +Option (A): Gather useful information. Ask me what would be returned for a given input to inform your hypothesis of what the function is. \ +This must be an integer in range [0, 100], other than {inputs}. +Option (B): State your guess for the values returned by the function for inputs {inputs}. Your guess must include all three numbers, in this order. \ +If your guess is correct, you win. If not, the task will continue. + +You have {n_rounds} turns to determine the correct answer, else you lose. + +Your goal is to determine the correct answer in as few turns as possible, so choose your actions carefully! +Your final score will be: +* If you make the correct guess: {n_rounds} - number_of_rounds_played +* If you fail to make the correct guess: -{n_rounds} +You should try to score as high as possible. + +You should respond either with a single integer in range 0-100 if you decide to take option A or with three integers if you make the final guess (i.e. use option B). +Don't say anything more except for either an integer or 3 integers. +""" +# TODO: Include in the task description whether we're in the easy or hard mode? +# (i.e. whether we'll be told about correctly guessed numbers (if any) or not). + +incorrect_format = "You responded in an incorrect format. Respond with either a single integer in range 0-100 or 3 integers. Just one number or three numbers, nothing more." +new_value = "calc({in_}) = {out}" +test_input_not_allowed = "Numbers {inputs} are not allowed" + +cot_template_first_round = """This is the first round. Think out loud about a general strategy of solving tasks like this.""" +cot_template_later_rounds = """\ +This is round {round_ix} out of {num_rounds}. + +Summary of the information you have already gathered: +{summary} + +Think out loud about the following questions: +* Do you have any hypothesis on what this function might be doing? +* If yes, should you try to test it (how?), or just use it to calculate the answer? +* If not, what additional information should you gather to be able to formulate a hypothesis? +""" diff --git a/evals/elsuite/function_deduction/scripts/dataset/create_dataset.py b/evals/elsuite/function_deduction/scripts/dataset/create_dataset.py new file mode 100644 index 0000000000..931e1cc27a --- /dev/null +++ b/evals/elsuite/function_deduction/scripts/dataset/create_dataset.py @@ -0,0 +1,62 @@ +import argparse +import dis +import json +import math + +DEFAULT_RANGE = [0, 100] # inclusive + + +def get_func_from_code(code): + return lambda x: eval(code, {"math": math, "x": x}) + + +def get_complexity(code: str) -> int: + # NOTE: this is quite ugly, but should be good enough for dataset-creating code + code = "global func_name\ndef func_name(x): return " + code + exec(code) + return len(list(dis.get_instructions(func_name))) + + +def create_dataset(out_file, in_file): + samples = [] + + for line in in_file: + line = line.strip() + + if not line or line.startswith("#"): + continue + + func = get_func_from_code(line) + values = list(int(func(x)) for x in range(DEFAULT_RANGE[0], DEFAULT_RANGE[1] + 1)) + samples.append( + { + "code": line, + "complexity": get_complexity(line), + "range": DEFAULT_RANGE, + "values": values, + } + ) + + # Ensure we don't have duplicates - they might be different functions, but if they return the same + # value for every input in the DEFAULT_RANGE then they are in fact the same sample. + for sample_ix, sample in enumerate(samples): + for other_sample in samples[sample_ix + 1 :]: + if sample["values"] == other_sample["values"]: + raise ValueError( + f"Samples {sample['code']} and {other_sample['code']} are indistinguishable" + ) + + samples.sort(key=lambda x: x["complexity"]) + + for i, sample in enumerate(samples): + sample = dict(sample_ix=i, **sample) + json.dump(sample, out_file) + out_file.write("\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--out", type=argparse.FileType("w"), required=True) + parser.add_argument("--in", dest="in_", type=argparse.FileType("r"), default="raw_code.txt") + args = parser.parse_args() + create_dataset(out_file=args.out, in_file=args.in_) diff --git a/evals/elsuite/function_deduction/scripts/dataset/raw_code.txt b/evals/elsuite/function_deduction/scripts/dataset/raw_code.txt new file mode 100644 index 0000000000..ff03a0c76e --- /dev/null +++ b/evals/elsuite/function_deduction/scripts/dataset/raw_code.txt @@ -0,0 +1,141 @@ +# Lines starting with '#' or empty are ignored. +# Every other line is code for a single sample. +# This file is parsed by create_datset.py script +# (-> is not accessed when the eval is running). + +# Single operation +x +x * 2 +x * 27 +-x +x * -2 +x * -19 +math.floor(x * 1.5) +math.floor(x * 8.5) +math.floor(x / 2) +math.floor(x / 10) +math.ceil(x / 2) +round(x / 10) +math.ceil(x / 10) +x + 1 +x + 17 +x - 1 +x - 29 +7 - x +x ** 2 +x ** 3 + +# Small set of values +7 +7 if x % 2 else 17 +x % 3 +x % 7 +x % 10 +int(x % 3 == 1) +int(x % 3 == 2) +int(x % 3 == 1) * 7 +int(x % 3 == 2) * 18 +int(x < 32) +int(x % 8 < 4) + +# Threshold +min(x, 30) +max(x, 30) +min(x * 2, 70) +max(x * 2, 70) +x * 2 if x < 50 else x +x + 7 if x < 50 else x - 7 +x + 50 if x < 50 else 100 - x +x * 2 if x > 40 else x * 3 +3 if 30 < x < 70 else 4 +min(1000000, 2 ** x) + +# Multiple operations +math.floor(x + math.sqrt(x)) +math.floor(math.sqrt(x)) +math.floor(math.sqrt(x)) - 1 +math.floor(math.sqrt(x)) * 2 +math.floor(math.sqrt(x) * 2) +math.floor(round(x ** (1/3), 8)) +x / 2 if not x % 2 else x * 3 +x / 2 if not x % 2 else x * 3 + 1 +x ** 2 if x % 2 else x ** 3 +x / 3 if not x % 3 else x +x / 3 if not x % 3 else x * 2 +(x + 1) / 3 if x % 3 == 2 else x +x ** 2 - 10 +x ** 3 - x ** 2 +x ** 2 * 2 +x * (x - 1) +x * (x - 1) * (x - 2) +x * (x + 1) / 2 +5 - (x % 5) +10 - (x % 10) +16 - (x % 16) +x - x % 6 +x - x % 15 +x - x % 10 +x + x % 10 +x + x % 4 +x + x // 10 +x + x // 8 +x // 10 + x % 2 +(x + 5) * 3 +(x + 2) * 7 +(2 * x) ** 2 + + +# Math, sin, cos etc +round(math.sin(x)) +round(math.sin(x * 0.5 * math.pi)) +round(math.sin(x * 0.25 * math.pi) * 10) +round(math.sin(x * 0.1 * math.pi) * 10) +round(math.cos(x)) +round(math.cos(x * 0.5 * math.pi)) +round(math.cos(x * 0.25 * math.pi) * 10) +round(math.cos(x * 0.1 * math.pi) * 10) + +# Is prime number? +int(x > 1 and all(x % i for i in range(2, x))) +x if x > 1 and all(x % i for i in range(2, x)) else x + 1 + +# Is perfect square? +int(int(x**0.5)**2 == x) + +# Divisors - number / sum +sum(1 for i in range(1, x + 1) if not x % i) +sum(i for i in range(1, x + 1) if not x % i) + +# Reverse digits +int(str(x)[::-1]) +abs(x - int(str(x)[::-1])) +x + int(str(x)[::-1]) + +# Sum of digits +sum(int(d) for d in str(x)) +x + sum(int(d) for d in str(x)) +int(sum(int(d) for d in str(x)) % 10) + +# Count odd/even digits +sum(1 for d in str(x) if int(d) % 2) +sum(1 for d in str(x) if not int(d) % 2) + +# Multiple digits +0 if x < 10 else (x % 10) * (x // 10) + +# Higher vs lower digit +0 if x < 10 else max(int(d) for d in str(x)) - min(int(d) for d in str(x)) + +# Other +bin(x).count("1") +x | 1 +int(str(x) == str(x)[::-1]) +x * int(str(x)[-1]) + +# More ideas: convert to binary +# int(bin(x)[2:]) +# int(bin(~x)[3:]) +# int(bin(x * 2)[2:]) + +# More ideas: highest divisor lower than x? +# 0 if x == 0 else max(1 for i in range(1, x) if not x % i) diff --git a/evals/elsuite/function_deduction/scripts/make_plots.py b/evals/elsuite/function_deduction/scripts/make_plots.py new file mode 100644 index 0000000000..4c8f5f5e78 --- /dev/null +++ b/evals/elsuite/function_deduction/scripts/make_plots.py @@ -0,0 +1,256 @@ +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from evals.utils import log_utils + +palette = { + "Average Baseline": "blue", + "Full Knowledge Best": "blue", + "Full Knowledge Random": "blue", + + "Human": "steelblue", + + "gpt-4-32k": "purple", + "gpt-4-32k w CoT": "purple", + + "gpt-4-base w Few-shot": "orange", + "gpt-4-base w CoT and Few-shot": "orange", + + "gpt-3.5-turbo-16k": "green", + "gpt-3.5-turbo-16k w CoT": "green", + + "gemini-pro": "peru", + "gemini-pro w CoT": "peru", + + "llama-2-13b-chat": "brown", + "llama-2-13b-chat w CoT": "brown", + + "llama-2-70b-chat": "maroon", + "llama-2-70b-chat w CoT": "maroon", + + "mixtral-8x7b-instruct": "grey", + "mixtral-8x7b-instruct w CoT": "grey", +} + +solver_to_name = { + "function_deduction/full_knowledge_best": "Full Knowledge Best", + "function_deduction/full_knowledge_random": "Full Knowledge Random", + "function_deduction/average_baseline": "Average Baseline", + + "human_cli": "Human", + + "gpt-4-32k": "gpt-4-32k", + "function_deduction/cot/gpt-4-32k": "gpt-4-32k w CoT", + + "function_deduction/gpt-4-base": "gpt-4-base w Few-shot", + "function_deduction/cot/gpt-4-base": "gpt-4-base w CoT and Few-shot", + + "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k", + "function_deduction/cot/gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k w CoT", + + "generation/direct/gemini-pro": "gemini-pro", + "function_deduction/cot/gemini-pro": "gemini-pro w CoT", + + "generation/direct/llama-2-13b-chat": "llama-2-13b-chat", + "function_deduction/cot/llama-2-13b-chat": "llama-2-13b-chat w CoT", + + "generation/direct/llama-2-70b-chat": "llama-2-70b-chat", + "function_deduction/cot/llama-2-70b-chat": "llama-2-70b-chat w CoT", + + "generation/direct/mixtral-8x7b-instruct": "mixtral-8x7b-instruct", + "function_deduction/cot/mixtral-8x7b-instruct": "mixtral-8x7b-instruct w CoT", +} + +rename_columns = { + "adjusted_avg_rounds": "adjusted_avg_score", + "sem_adjusted_avg_rounds": "sem_adjusted_avg_score", +} + + +def extract_final_reports( + datadir: Path, rename_solvers: dict, rename_columns: dict +) -> pd.DataFrame: + df_rows = [] + for path, results in sorted(list(log_utils.get_final_results_from_dir(datadir).items())): + spec = log_utils.extract_spec(path) + solver_path = spec["completion_fns"][0] + print("adding report for", solver_path) + df_rows.append( + { + "solver": rename_solvers.get(solver_path, solver_path), + **{rename_columns.get(k, k): v for k, v in results.items()}, + } + ) + df = pd.DataFrame(df_rows) + return df + + +def make_plot( + df, + x_column: str, + y_column: str, + x_err_column: str, + title: str, + xlabel: str, + ylabel: str, + out_path: Path, +): + # Avg rounds until success (failure counts as 40) + plt.figure(figsize=(10, 6)) + ax = sns.barplot( + x=x_column, + y=y_column, + data=df, + xerr=df[x_err_column] * 1.96, + palette=palette, + ) + + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) + plt.grid(axis="x") + plt.tight_layout() + + # Expanding the x-axis limit + x_lim = ax.get_xlim() + ax.set_xlim([x_lim[0], x_lim[1] * 1.05]) # Increase the upper limit by 5% + + # Annotating each bar with its value + for p in ax.patches: + width = p.get_width() + ax.text( + width + x_lim[1] * 0.02, # x position of text + p.get_y() + p.get_height() / 2, # y position of text + "{:.1f}".format(width), # text to be shown + va="center", + ) # vertical alignment + + plt.savefig(out_path) + return + + +def make_ask_guess_incorrect_plot(df, out_path: Path): + # Ask/Guess/Incorrect + + ask_guess_incorrect_data = { + "solver": df["solver"], + "Ask": df["avg_ask_rounds"], + "SEM Average Ask Rounds": df["sem_avg_ask_rounds"], + "Guess": df["avg_guess_rounds"], + "SEM Average Guess Rounds": df["sem_avg_guess_rounds"], + "Incorrect Format": df["avg_incorrect_format_rounds"], + "SEM Average Incorrect Format Rounds": df["sem_avg_incorrect_format_rounds"], + } + + agi_palette = { + "Ask": "blue", + "Guess": "pink", + "Incorrect Format": "red", + } + + ask_guess_incorrect_df = pd.DataFrame(ask_guess_incorrect_data) + + # Melting the DataFrame to make it suitable for seaborn's factorplot + melted_df = pd.melt( + ask_guess_incorrect_df, + id_vars="solver", + value_vars=["Ask", "Guess", "Incorrect Format"], + var_name="Round Type", + value_name="Average Rounds", + ) + + # Generating the plot for Average Ask/Guess/Incorrect Format Rounds + plt.figure(figsize=(14, 14)) + ax = sns.barplot( + x="Average Rounds", y="solver", hue="Round Type", data=melted_df, palette=agi_palette + ) + + plt.xlabel("Average Number of Rounds") + plt.ylabel("Solver") + plt.title("Distribution of Type of Responses by Model") + plt.grid(axis="x") + plt.legend(title="Response Type") + plt.tight_layout() + + # Expanding the x-axis limit + x_lim = ax.get_xlim() + ax.set_xlim([x_lim[0], x_lim[1] * 1.05]) # Increase the upper limit by 5% + + # Annotating each bar with its value + for p in ax.patches: + width = p.get_width() + ax.text( + width + 0.1, # x position of text + p.get_y() + p.get_height() / 2, # y position of text + "{:.1f}".format(width), # text to be shown + va="center", + ) # vertical alignment + + plt.savefig(out_path) + return + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--log-dir", "-d", type=str, required=True) + parser.add_argument("--out-dir", "-o", type=str, default="./outputs") + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + + df = extract_final_reports(log_dir, solver_to_name, rename_columns) + + # Drop all columns named "complexity*" + df = df[df.columns.drop(list(df.filter(regex="complexity")))] + + # Creating a new DataFrame with the desired order + ordered_df = df.set_index("solver").loc[list(solver_to_name.values())].reset_index() + print(ordered_df) + + make_plot( + df=ordered_df, + x_column="adjusted_avg_score", + y_column="solver", + x_err_column="sem_adjusted_avg_score", + title="Adjusted Average Score (Lower is Better)", + xlabel="Adjusted Average Score", + ylabel="Solver", + out_path=out_dir / "avg_adjusted_score.png", + ) + + ordered_df["solved_ratio"] = 100 * ordered_df["solved_ratio"] + ordered_df["sem_solved_ratio"] = 100 * ordered_df["sem_solved_ratio"] + make_plot( + df=ordered_df, + x_column="solved_ratio", + y_column="solver", + x_err_column="sem_solved_ratio", + title="Solved Samples Ratio (Higher is Better)", + xlabel="Solved Ratio (%)", + ylabel="Solver", + out_path=out_dir / "solved_ratio.png", + ) + + make_plot( + df=ordered_df, + x_column="avg_success_rounds", + y_column="solver", + x_err_column="sem_avg_success_rounds", + title="Average Number of Rounds for Solved Samples (Lower is Better)", + xlabel="No. of Rounds", + ylabel="Solver", + out_path=out_dir / "avg_success_rounds.png", + ) + + make_ask_guess_incorrect_plot( + df=ordered_df, + out_path=out_dir / "ask_guess_incorrect.png", + ) + + +if __name__ == "__main__": + main() diff --git a/evals/elsuite/function_deduction/scripts/run_experiments.sh b/evals/elsuite/function_deduction/scripts/run_experiments.sh new file mode 100755 index 0000000000..4e67f5c7be --- /dev/null +++ b/evals/elsuite/function_deduction/scripts/run_experiments.sh @@ -0,0 +1,27 @@ + +logdir=./logs +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase="$logdir/$timestamp" + +echo Running experiments and logging to $logpathbase + +# Baselines +oaieval function_deduction/average_baseline function_deduction.easy --record_path "$logpathbase/average_baseline.log" +oaieval function_deduction/full_knowledge_best function_deduction.easy --record_path "$logpathbase/full_knowledge_best.log" +oaieval function_deduction/full_knowledge_random function_deduction.easy --record_path "$logpathbase/full_knowledge_random.log" --extra_eval_params n_repeat=100 + +declare -a SOLVERS=( + gpt-3.5-turbo-16k + gpt-4-32k + function_deduction/cot/gpt-3.5-turbo-16k + function_deduction/cot/gpt-4-32k + function_deduction/gpt-4-base + function_deduction/cot/gpt-4-base +) + +# Models +for solver in "${SOLVERS[@]}" +do + log_name=${solver//\//-} + oaieval $solver function_deduction.easy --record_path "$logpathbase/$log_name.log" +done diff --git a/evals/elsuite/function_deduction/solvers.py b/evals/elsuite/function_deduction/solvers.py new file mode 100644 index 0000000000..4830afe34a --- /dev/null +++ b/evals/elsuite/function_deduction/solvers.py @@ -0,0 +1,173 @@ +from typing import Any + +from evals.elsuite.function_deduction import prompts +from evals.elsuite.function_deduction.eval import CurrentState +from evals.solvers.nested.cot_solver import CoTSolver +from evals.solvers.nested.hhh_solver import HHHSolver +from evals.solvers.solver import SolverResult, SolverSpec +from evals.task_state import Message, TaskState + + +class CustomCoT(CoTSolver): + def __init__( + self, + cot_solver: SolverSpec, + extract_solver: SolverSpec, + persistent_memory: bool = True, + registry: Any = None, + ): + super().__init__( + cot_solver=cot_solver, + extract_solver=extract_solver, + persistent_memory=persistent_memory, + ) + + def cot_template(self, task_state: TaskState) -> str: + round_ix = task_state.current_state.round_ix + if round_ix == 0: + return prompts.cot_template_first_round + else: + summary = self._get_summary(task_state.current_state) + return prompts.cot_template_later_rounds.format( + round_ix=round_ix + 1, # displayed round number starts from 1 + num_rounds=task_state.current_state.n_rounds, + summary=summary, + ) + + def _get_summary(self, current_state: CurrentState) -> str: + rows = [] + for key, val in sorted(current_state.known_values.items()): + rows.append(f"calc({key}) = {val}") + + negative_rows = [] + for key, val in sorted(current_state.negative_known_values.items()): + negative_rows.append(f"calc({key}) != {val}") + + parts = [] + if rows: + parts.append("\n".join(rows)) + if negative_rows: + msg = "Information from your incorrect guesses:\n" + parts.append(msg + "\n".join(negative_rows)) + + if not parts: + return "You don't know anything yet." + else: + return "\n\n".join(parts) + + +class BaseModelSolver(HHHSolver): + def _solve(self, task_state: TaskState): + task_state = TaskState( + task_state.task_description, + self._few_shot_messages() + task_state.messages, + task_state.current_state, + ) + result = super()._solve(task_state) + result = result.output.splitlines()[0] + return SolverResult(result) + + def _few_shot_messages(self) -> list[Message]: + role = "system" + messages = [ + (role, "I have a hidden function. What is your first action?"), + ("assistant", "40"), + (role, "calc(40) = 160"), + ("assistant", "52"), + (role, "calc(52) = 204"), + ("assistant", "68 144 272"), + (role, "Correct guess!"), + (role, "I now have a new function. Forget about the previous one, we start again."), + ] + return [Message(*row) for row in messages] + + +class BaseModelCoTSolver(CustomCoT): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def cot_solver(self): + # This is a hack to enable us to postprocess the output during the CoT generation step: + # Here, we always keep just first line of the CoT generation, otherwise the base model + # goes on to copy the extraction prompt and confuses itself. + # TODO: Once we have solvers with proper built-in support for output postprocessors, + # update this to use that instead. + def cot_solver(task_state): + result = self.get_solver("cot_solver")(task_state).output + result = result.splitlines()[0] + return SolverResult(result) + + return cot_solver + + def _solve(self, task_state: TaskState): + task_state = TaskState( + task_state.task_description, + self._few_shot_messages(task_state.current_state) + task_state.messages, + task_state.current_state, + ) + + result = super()._solve(task_state) + result = result.output.splitlines()[0] + + # Fix the interaction history so that we can have persistent_memory = True + self.interaction_cache.last_interaction.messages[-1] = Message("assistant", result) + + return SolverResult(result) + + def _few_shot_messages(self, current_state) -> list[Message]: + # This is a bit hackish, but this way we can use self.cot_template (defined on CustomCoT), + # -> we'll have exactly the same system prompts in few shot part as we'd have in the eval + cs = CurrentState(current_state.n_rounds, current_state.mode, current_state.test_inputs) + task_state = TaskState("", [], cs) + + role = "system" + + messages = [ + (role, "I have a hidden function. What is your first action?"), + ] + messages.append((role, self.cot_template(task_state))) + messages.append( + ( + "assistant", + ( + "I don't know anything about the function yet. I should start with gathering some information, " + "e.g. by asking about the return value for 40." + ), + ) + ) + messages.append((role, self.extract_template(task_state))) + messages.append(("assistant", "40")) + messages.append((role, "calc(40) = 160")) + cs.ask_update(40, 160) + + messages.append((role, self.cot_template(task_state))) + messages.append( + ( + "assistant", + "Perhaps this function does multiplication by 4? That's not certain yet, so let's try some other number, e.g. 52.", + ) + ) + messages.append((role, self.extract_template(task_state))) + messages.append(("assistant", "52")) + messages.append((role, "calc(52) = 204")) + cs.ask_update(52, 204) + + messages.append((role, self.cot_template(task_state))) + messages.append( + ( + "assistant", + ( + "Now we have two results where the ouput is the input times 4. It seems that the function multiplies by 4. " + "I will make the guess now. 17 * 4 = 68, 36 * 4 = 144 and 68 * 4 = 272, so my guess will be 68 144 272." + ), + ) + ) + messages.append((role, self.extract_template(task_state))) + messages.append(("assistant", "68 144 272")) + messages.append((role, "Correct guess!")) + messages.append( + (role, "I now have a new function. Forget about the previous one, we start again.") + ) + + return [Message(*row) for row in messages] diff --git a/evals/elsuite/function_deduction/solvers_test.py b/evals/elsuite/function_deduction/solvers_test.py new file mode 100644 index 0000000000..8fadec107f --- /dev/null +++ b/evals/elsuite/function_deduction/solvers_test.py @@ -0,0 +1,149 @@ +from evals.elsuite.function_deduction.eval import CurrentState +from evals.elsuite.function_deduction.prompts import ( + cot_template_first_round, + cot_template_later_rounds, +) +from evals.elsuite.function_deduction.solvers import BaseModelCoTSolver, CustomCoT +from evals.solvers.solver import SolverSpec +from evals.task_state import Message, TaskState + +dummy_solver_spec = SolverSpec( + { + "class": "evals.solvers.solver:DummySolver", + "args": {}, + } +) + +GUESS_INPUT = 7 +ANSWER = 0 +N_ROUNDS = 10 +ROUNDS_SIMULATED = 2 +MODE = "easy" +TEST_INPUTS = (10, 20, 30) + + +def simulate_dummy_game(solver): + # Init state + task_description = "" # Not used + msgs = [] + cs = CurrentState( + n_rounds=N_ROUNDS, + mode=MODE, + test_inputs=TEST_INPUTS, + ) + + # ROUND 1 + solver_result = solver( + TaskState( + task_description=task_description, + messages=msgs, + current_state=cs, + ) + ) + + msgs.append(Message("assistant", solver_result.output)) + msgs.append(Message("system", f"The answer to your query is {ANSWER}")) + cs.ask_update(GUESS_INPUT, ANSWER) # Collect data for input=7 + + # ROUND 2 + solver_result = solver( + TaskState( + task_description=task_description, + messages=msgs, + current_state=cs, + ) + ) + return solver + + +def test_custom_cot(): + solver = CustomCoT(dummy_solver_spec, dummy_solver_spec) + simulate_dummy_game(solver) + + # Check that the customized CoT generation prompts appear as expected + # (and that the persistent memory in fact persists) + solver_private_memory = solver.interaction_cache.last_interaction.messages + assert solver_private_memory[0].content == cot_template_first_round + assert solver_private_memory[2].content == solver._extract_template + assert solver_private_memory[5].content == cot_template_later_rounds.format( + round_ix=ROUNDS_SIMULATED, + num_rounds=N_ROUNDS, + summary=f"calc({GUESS_INPUT}) = {ANSWER}", + ) + assert solver_private_memory[7].content == solver._extract_template + + +def test_base_model_cot_solver(): + solver = BaseModelCoTSolver(dummy_solver_spec, dummy_solver_spec) + simulate_dummy_game(solver) + + # Check that the memory contains the few-shot prompts + # followed by the customized CoT generation prompts + solver_private_memory = solver.interaction_cache.last_interaction.messages + + expected_few_shot_msgs = [ + Message(role="system", content="I have a hidden function. What is your first action?"), + Message( + role="system", + content="This is the first round. Think out loud about a general strategy of solving tasks like this.", + ), + Message( + role="assistant", + content="I don't know anything about the function yet. I should start with gathering some information, e.g. by asking about the return value for 40.", + ), + Message( + role="system", + content="Given the above reasoning, the answer in the format requested by the question is:", + ), + Message(role="assistant", content="40"), + Message(role="system", content="calc(40) = 160"), + Message( + role="system", + content="This is round 2 out of 10.\n\nSummary of the information you have already gathered:\ncalc(40) = 160\n\nThink out loud about the following questions:\n* Do you have any hypothesis on what this function might be doing?\n* If yes, should you try to test it (how?), or just use it to calculate the answer?\n* If not, what additional information should you gather to be able to formulate a hypothesis?\n", + ), + Message( + role="assistant", + content="Perhaps this function does multiplication by 4? That's not certain yet, so let's try some other number, e.g. 52.", + ), + Message( + role="system", + content="Given the above reasoning, the answer in the format requested by the question is:", + ), + Message(role="assistant", content="52"), + Message(role="system", content="calc(52) = 204"), + Message( + role="system", + content="This is round 3 out of 10.\n\nSummary of the information you have already gathered:\ncalc(40) = 160\ncalc(52) = 204\n\nThink out loud about the following questions:\n* Do you have any hypothesis on what this function might be doing?\n* If yes, should you try to test it (how?), or just use it to calculate the answer?\n* If not, what additional information should you gather to be able to formulate a hypothesis?\n", + ), + Message( + role="assistant", + content="Now we have two results where the ouput is the input times 4. It seems that the function multiplies by 4. I will make the guess now. 17 * 4 = 68, 36 * 4 = 144 and 68 * 4 = 272, so my guess will be 68 144 272.", + ), + Message( + role="system", + content="Given the above reasoning, the answer in the format requested by the question is:", + ), + Message(role="assistant", content="68 144 272"), + Message(role="system", content="Correct guess!"), + Message( + role="system", + content="I now have a new function. Forget about the previous one, we start again.", + ), + ] + assert solver_private_memory[: len(expected_few_shot_msgs)] == expected_few_shot_msgs + assert ( + solver_private_memory[len(expected_few_shot_msgs) + 0].content == cot_template_first_round + ) + assert ( + solver_private_memory[len(expected_few_shot_msgs) + 2].content == solver._extract_template + ) + assert solver_private_memory[ + len(expected_few_shot_msgs) + 5 + ].content == cot_template_later_rounds.format( + round_ix=ROUNDS_SIMULATED, + num_rounds=N_ROUNDS, + summary=f"calc({GUESS_INPUT}) = {ANSWER}", + ) + assert ( + solver_private_memory[len(expected_few_shot_msgs) + 7].content == solver._extract_template + ) diff --git a/evals/registry/data/function_deduction/data.jsonl b/evals/registry/data/function_deduction/data.jsonl new file mode 100644 index 0000000000..bded32c52b --- /dev/null +++ b/evals/registry/data/function_deduction/data.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb7cd13c1f67a7be8d153de26c7436a805035053f5497b77296e3f3615023e86 +size 50468 diff --git a/evals/registry/evals/function-deduction.yaml b/evals/registry/evals/function-deduction.yaml new file mode 100644 index 0000000000..337856cd72 --- /dev/null +++ b/evals/registry/evals/function-deduction.yaml @@ -0,0 +1,37 @@ +function_deduction: + id: function_deduction.easy + metrics: [adjusted_avg_rounds, solved_ratio, solved, samples, avg_success_rounds, avg_sample_rounds_std_adjusted, avg_sample_rounds_std_no_failed, solved_ratio_if_any_solved, avg_ask_rounds, avg_guess_rounds, avg_incorrect_format_rounds, solved_avg_complexity, not_solved_avg_complexity, solved_or_not_mann_whitney_u_p_value, sem_adjusted_avg_rounds, sem_avg_success_rounds, sem_avg_guess_rounds, sem_avg_incorrect_format_rounds] + description: Test a model's ability to deduce unknown functions + +function_deduction.easy: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: easy + n_rounds: 20 + +function_deduction.easy.long: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: easy + n_rounds: 20 + n_repeat: 10 + +function_deduction.easy.dev5: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: easy + n_rounds: 20 + n_samples: 5 + +function_deduction.hard: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: hard + n_rounds: 20 + +function_deduction.hard.dev5: + class: evals.elsuite.function_deduction.eval:FunctionDeductionEval + args: + mode: hard + n_rounds: 20 + n_samples: 5 diff --git a/evals/registry/solvers/function_deduction.yaml b/evals/registry/solvers/function_deduction.yaml new file mode 100644 index 0000000000..9b8837851b --- /dev/null +++ b/evals/registry/solvers/function_deduction.yaml @@ -0,0 +1,192 @@ +# OS CHAIN OF THOUGHT +function_deduction/cot/llama-2-13b-chat: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-13b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-13b-chat-hf + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/llama-2-70b-chat: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-70b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-70b-chat-hf + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/mixtral-8x7b-instruct: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: mistralai/Mixtral-8x7B-Instruct-v0.1 + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: mistralai/Mixtral-8x7B-Instruct-v0.1 + extra_options: + temperature: 0 + max_tokens: 32 + + +# CUSTOM CHAIN OF THOUGHT +function_deduction/cot/gpt-4-1106-preview: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-1106-preview + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-1106-preview + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/gpt-4-32k: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-32k + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-32k + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/gpt-3.5-turbo-16k: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo-16k + extra_options: + temperature: 1 + max_tokens: 512 + extract_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo-16k + extra_options: + temperature: 0 + max_tokens: 32 + +function_deduction/cot/gemini-pro: + class: evals.elsuite.function_deduction.solvers:CustomCoT + args: + cot_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + extract_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + + +# BASE MODELS +function_deduction/gpt-4-base: + class: evals.elsuite.function_deduction.solvers:BaseModelSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 32 + +function_deduction/cot/gpt-4-base: + class: evals.elsuite.function_deduction.solvers:BaseModelCoTSolver + args: + cot_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + fixed_start: "Let's think step by step. " + extract_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 0 + max_tokens: 32 + + +# BASELINES +function_deduction/average_baseline: + class: evals.elsuite.function_deduction.baselines:AverageBaseline + +function_deduction/full_knowledge_random: + class: evals.elsuite.function_deduction.baselines:FullKnowledge + args: + mode: random + samples_jsonl: function_deduction/data.jsonl + +function_deduction/full_knowledge_best: + class: evals.elsuite.function_deduction.baselines:FullKnowledge + args: + mode: best + samples_jsonl: function_deduction/data.jsonl