Commit 9f9a65c: More updates

Kush Dave Jain committed Jan 20, 2025
1 parent 3355bae commit 9f9a65c
Showing 3 changed files with 51 additions and 57 deletions.
35 changes: 15 additions & 20 deletions evaluation/benchmarks/testgeneval/eval_infer.py
@@ -67,7 +67,9 @@ def get_config(instance: pd.Series) -> AppConfig:
use_host_network=False,
timeout=1800,
api_key=os.environ.get('ALLHANDS_API_KEY'),
- remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
+ remote_runtime_api_url=os.environ.get(
+     'SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000'
+ ),
),
workspace_base=None,
workspace_mount_path=None,
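
Note: with the new default, a missing SANDBOX_REMOTE_RUNTIME_API_URL now resolves to the local runtime instead of None. A quick standalone check of that fallback, using only the standard library:

import os

# Simulate an environment where the variable is unset.
os.environ.pop('SANDBOX_REMOTE_RUNTIME_API_URL', None)
url = os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000')
assert url == 'http://localhost:8000'  # falls back to the default
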
@@ -104,7 +106,7 @@ def compute_lexical_metrics(pred_suite, gold_suite):

def run_command(runtime, command, timeout=600):
action = CmdRunAction(command=command)
- action.timeout = timeout
+ action.set_hard_timeout(timeout)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
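
Note: the remaining hunks in this commit largely repeat one migration, replacing the raw attribute write action.timeout = N with the setter call action.set_hard_timeout(N). A minimal usage sketch, assuming the CmdRunAction import path used elsewhere in this repository:

from openhands.events.action import CmdRunAction

# New style: configure the hard timeout through the setter rather than
# assigning to the attribute directly.
action = CmdRunAction(command='echo hello')
action.set_hard_timeout(600)
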
@@ -113,10 +115,8 @@ def run_command(runtime, command, timeout=600):


def run_tests(runtime, instance, test_script, log_file='/tmp/test_output.log'):
- action = CmdRunAction(
-     command=f'bash {test_script} > {log_file} 2>&1 & echo $!', keep_prompt=False
- )
- action.timeout = 60
+ action = CmdRunAction(command=f'bash {test_script} > {log_file} 2>&1 & echo $!')
+ action.set_hard_timeout(60)
obs = runtime.run_action(action)

assert isinstance(obs, CmdOutputObservation), 'Failed to start test script.'
@@ -132,9 +132,7 @@ def run_tests(runtime, instance, test_script, log_file='/tmp/test_output.log'):
instance['test_result']['report']['test_timeout'] = True
break

- check_action = CmdRunAction(
-     command=f'ps -p {pid} > /dev/null; echo $?', keep_prompt=False
- )
+ check_action = CmdRunAction(command=f'ps -p {pid} > /dev/null; echo $?')
check_obs = runtime.run_action(check_action)
if (
isinstance(check_obs, CmdOutputObservation)
@@ -144,8 +142,8 @@ def run_tests(runtime, instance, test_script, log_file='/tmp/test_output.log'):
break
time.sleep(30)

- test_action = CmdRunAction(command=f'cat {log_file}', keep_prompt=False)
- test_action.timeout = 300
+ test_action = CmdRunAction(command=f'cat {log_file}')
+ test_action.set_hard_timeout(300)
test_obs = runtime.run_action(test_action)
assert isinstance(test_obs, CmdOutputObservation), 'Failed to retrieve test output.'
return test_obs.exit_code, test_obs.content, elapsed_time
@@ -154,10 +152,8 @@
def run_mutation_testing(
runtime, instance, mutation_script, log_file='/tmp/mutation_output.log'
):
- action = CmdRunAction(
-     command=f'bash {mutation_script} > {log_file} 2>&1 & echo $!', keep_prompt=False
- )
- action.timeout = 60
+ action = CmdRunAction(command=f'bash {mutation_script} > {log_file} 2>&1 & echo $!')
+ action.set_hard_timeout(60)
obs = runtime.run_action(action)

assert isinstance(obs, CmdOutputObservation), 'Failed to start test script.'
@@ -173,9 +169,7 @@
instance['test_result']['report']['mutation_timeout'] = True
break

- check_action = CmdRunAction(
-     command=f'ps -p {pid} > /dev/null; echo $?', keep_prompt=False
- )
+ check_action = CmdRunAction(command=f'ps -p {pid} > /dev/null; echo $?')
check_obs = runtime.run_action(check_action)
if (
isinstance(check_obs, CmdOutputObservation)
@@ -186,8 +180,8 @@
time.sleep(30)

assert isinstance(obs, CmdOutputObservation), 'Failed to run mutation script.'
- mutation_action = CmdRunAction(command=f'cat {log_file}', keep_prompt=False)
- mutation_action.timeout = 300
+ mutation_action = CmdRunAction(command=f'cat {log_file}')
+ mutation_action.set_hard_timeout(300)
mutation_obs = runtime.run_action(mutation_action)
assert isinstance(
mutation_obs, CmdOutputObservation
@@ -308,6 +302,7 @@ def process_instance(
id = instance.id
logger.info(f'Starting evaluation for instance {id}.')

+ instance['test_result']['id'] = id
instance['test_result']['report'] = {
'test_output': '',
'coverage_output': '',
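Note: run_tests and run_mutation_testing above share a launch-and-poll pattern: start the script in the background, echo its PID, poll ps -p until the process exits or a deadline passes, then cat the log. A standalone sketch of that pattern using only the standard library; the helper name and poll interval are illustrative, not part of this commit:

import subprocess
import time

def run_script_with_polling(script, log_file, timeout=1800, poll_interval=30):
    # Launch in the background, send all output to the log file, and capture
    # the shell's report of the background PID (the `echo $!` trick).
    launch = subprocess.run(
        f'bash {script} > {log_file} 2>&1 & echo $!',
        shell=True, capture_output=True, text=True,
    )
    pid = int(launch.stdout.strip())

    start = time.time()
    timed_out = False
    while True:
        if time.time() - start > timeout:
            timed_out = True
            break
        # ps -p exits non-zero once the process has finished.
        if subprocess.run(['ps', '-p', str(pid)], capture_output=True).returncode != 0:
            break
        time.sleep(poll_interval)

    with open(log_file) as f:
        return f.read(), timed_out
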
8 changes: 7 additions & 1 deletion evaluation/benchmarks/testgeneval/prompt.py
@@ -45,6 +45,10 @@
Finally proceed to writing a test suite at {test_file} that tests {code_file}.
IMPORT {code_file} in the test suite you write, DO NOT import the general library.
If you are unable to run coverage report, FIX YOUR IMPORTS to use the same style as other test files.
+ You should NOT modify any existing test case files. You SHOULD add new tests in a NEW file to reproduce the issue.
+ You should NEVER use web browsing or any other web-based tools.
@@ -57,7 +61,9 @@
Then run coverage report -m --include {code_file} to see how well your test suite covers the code.
- Focus on generating passing tests first, then on improving coverage. REMOVE failing tests.
+ Focus on generating passing tests first, then on improving coverage.
+ Try to fix failing tests 3 times, then REMOVE failing tests and add NEW tests.
When you are trying to improve coverage pick a part of the code that is not covered (indicated by lines on coverage report), examine the code and then
try to generate a test for it. Feel free to use a code interpreter to understand the input output behavior. ONLY add tests
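Note: the updated prompt asks the agent to run its suite under coverage and then inspect per-file line coverage before iterating. A minimal sketch of that loop with the coverage.py CLI; both paths are placeholders rather than values from this commit:

import subprocess

code_file = 'src/pkg/module.py'     # placeholder: module under test
test_file = 'tests/test_module.py'  # placeholder: generated test suite

# Run the suite under coverage, then show missed line numbers (-m)
# for the target file only.
subprocess.run(['coverage', 'run', '-m', 'pytest', test_file], check=False)
subprocess.run(['coverage', 'report', '-m', '--include', code_file], check=False)
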
65 changes: 29 additions & 36 deletions evaluation/benchmarks/testgeneval/run_infer.py
@@ -12,7 +12,6 @@
import openhands.agenthub
from evaluation.benchmarks.testgeneval.constants import MAP_REPO_VERSION_TO_SPECS
from evaluation.benchmarks.testgeneval.prompt import (
- CODEACT_SWE_TESTGEN_PROMPT,
CODEACT_TESTGEN_PROMPT,
)
from evaluation.benchmarks.testgeneval.utils import (
@@ -24,6 +23,8 @@
EvalOutput,
assert_and_raise,
codeact_user_response,
+ get_metrics,
+ is_fatal_evaluation_error,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
@@ -50,7 +51,6 @@

AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
- 'CodeActSWEAgent': codeact_user_response,
}


@@ -66,7 +66,7 @@ def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:


def get_instruction(instance: pd.Series, metadata: EvalMetadata):
- workspace_dir_name = _get_swebench_workspace_dir_name(instance)
+ # workspace_dir_name = _get_swebench_workspace_dir_name(instance)
# Prepare instruction
coverage_command = ' '.join(
[
@@ -77,20 +77,12 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
]
)

- if metadata.agent_class == 'CodeActSWEAgent':
-     instruction = CODEACT_SWE_TESTGEN_PROMPT.format(
-         workspace_dir_name=workspace_dir_name,
-         test_file=instance.test_file,
-         code_file=instance.code_file,
-         coverage_command=coverage_command,
-     )
- else:
-     # Testing general agents
-     instruction = CODEACT_TESTGEN_PROMPT.format(
-         code_file=instance.code_file,
-         test_file=instance.test_file,
-         coverage_command=coverage_command,
-     )
+ # Testing general agents
+ instruction = CODEACT_TESTGEN_PROMPT.format(
+     code_file=instance.code_file,
+     test_file=instance.test_file,
+     coverage_command=coverage_command,
+ )

if RUN_WITH_BROWSING:
instruction += (
@@ -141,7 +133,9 @@ def get_config(
# Add platform to the sandbox config to solve issue 4401
platform='linux/amd64',
api_key=os.environ.get('ALLHANDS_API_KEY', None),
- remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
+ remote_runtime_api_url=os.environ.get(
+     'SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000'
+ ),
keep_runtime_alive=False,
remote_runtime_init_timeout=3600,
),
@@ -158,6 +152,7 @@
codeact_enable_jupyter=False,
codeact_enable_browsing=RUN_WITH_BROWSING,
codeact_enable_llm_editor=False,
+ condenser=metadata.condenser_config,
)
config.set_agent_config(agent_config)
return config
@@ -181,7 +176,7 @@ def initialize_runtime(
action = CmdRunAction(
command=f"""echo 'export SWE_INSTANCE_ID={instance['instance_id']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc"""
)
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -190,7 +185,7 @@
)

action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -201,7 +196,7 @@

# inject the instance info
action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -232,14 +227,14 @@ def initialize_runtime(
'/swe_util/',
)
action = CmdRunAction(command='cat ~/.bashrc')
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {str(obs)}')

action = CmdRunAction(command='source ~/.bashrc')
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -248,7 +243,7 @@
assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}')

action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
- action.timeout = 3600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -258,7 +253,7 @@
)

action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -268,7 +263,7 @@
)

action = CmdRunAction(command='git reset --hard')
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -277,7 +272,7 @@
action = CmdRunAction(
command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
)
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -305,7 +300,7 @@ def complete_runtime(
workspace_dir_name = _get_swebench_workspace_dir_name(instance)

action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -315,7 +310,7 @@
)

action = CmdRunAction(command=f'cat {instance.test_file}')
- action.timeout = 600
+ action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -327,7 +322,7 @@
test_suite = obs.content.strip()

# action = CmdRunAction(command='git add -A')
- # action.timeout = 600
+ # action.set_hard_timeout(600)
# logger.info(action, extra={'msg_type': 'ACTION'})
# obs = runtime.run_action(action)
# logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -376,11 +371,7 @@ def process_instance(
)

# if fatal error, throw EvalError to trigger re-run
- if (
-     state.last_error
-     and 'fatal error during agent execution' in state.last_error
-     and 'stuck in a loop' not in state.last_error
- ):
+ if is_fatal_evaluation_error(state.last_error):
raise EvalException('Fatal error detected: ' + state.last_error)

# ======= THIS IS SWE-Bench specific =======
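
Note: the inline condition above now lives behind the is_fatal_evaluation_error helper imported near the top of this diff. Judging from the removed lines, the helper plausibly looks like the sketch below; the real implementation is in the shared evaluation utilities and may differ:

def is_fatal_evaluation_error(last_error):
    # Mirrors the removed inline check: a fatal agent error should trigger
    # a re-run unless the agent merely got stuck in a loop.
    if not last_error:
        return False
    return (
        'fatal error during agent execution' in last_error
        and 'stuck in a loop' not in last_error
    )
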
@@ -406,7 +397,7 @@
raise ValueError('State should not be None.')

histories = [event_to_dict(event) for event in state.history]
- metrics = state.metrics.get() if state.metrics else None
+ metrics = get_metrics(state)

# Save the output
output = EvalOutput(
@@ -464,6 +455,8 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
llm_config.log_completions = True
+ # modify_params must be False for evaluation purposes, for reproducibility and accuracy of results
+ llm_config.modify_params = False

if llm_config is None:
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
