Skip to content

Commit

Permalink
move argparser, model/benchmark retrieval, helper methods to core
Browse files Browse the repository at this point in the history
  • Loading branch information
mschrimpf committed Nov 5, 2023
1 parent f39b5d0 commit 87a2db4
Showing 1 changed file with 12 additions and 63 deletions.
75 changes: 12 additions & 63 deletions brainscore_language/submission/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import argparse
from typing import List, Union, Dict

from brainscore_core import Score, Benchmark
from brainscore_core.submission import UserManager, RunScoringEndpoint, DomainPlugins
from brainscore_core.submission import RunScoringEndpoint, DomainPlugins
from brainscore_core.submission.endpoints import make_argparser, retrieve_models_and_benchmarks, get_user_id, \
send_email_to_submitter as send_email_to_submitter_core
from brainscore_language import load_model, load_benchmark, score
from brainscore_language.submission import config

Expand All @@ -22,81 +23,29 @@ def score(self, model_identifier: str, benchmark_identifier: str) -> Score:
run_scoring_endpoint = RunScoringEndpoint(language_plugins, db_secret=config.get_database_secret())


def send_email_to_submitter(uid: int, domain: str, pr_number: str,
mail_username: str, mail_password: str):
""" Send submitter an email if their web-submitted PR fails. """
subject = "Brain-Score submission failed"
body = f"Your Brain-Score submission did not pass checks. Please review the test results and update the PR at https://github.com/brain-score/{domain}/pull/{pr_number} or send in an updated submission via the website."
user_manager = UserManager(db_secret=config.get_database_secret())
return user_manager.send_user_email(uid, body, mail_username, mail_password)


def get_user_id(email: str) -> int:
user_manager = UserManager(db_secret=config.get_database_secret())
user_id = user_manager.get_uid(email)
return user_id


def _get_ids(args_dict: Dict[str, Union[str, List]], key: str) -> Union[List, str, None]:
return args_dict[key] if key in args_dict else None


def run_scoring(args_dict: Dict[str, Union[str, List]]):
    """
    Prepare parameters for the `run_scoring_endpoint` and trigger scoring.

    :param args_dict: parsed command-line arguments (see `make_argparser` in
        brainscore_core); must contain `jenkins_id`, `user_id`, `public`, and
        `competition`, plus the model/benchmark selection keys consumed by
        `retrieve_models_and_benchmarks`.
    """
    # Selection of which models to run on which benchmarks (e.g. new submissions
    # vs. ALL_PUBLIC) is centralized in brainscore_core.
    benchmarks, models = retrieve_models_and_benchmarks(args_dict)

    run_scoring_endpoint(domain="language", jenkins_id=args_dict["jenkins_id"],
                         models=models, benchmarks=benchmarks, user_id=args_dict["user_id"],
                         model_type="artificialsubject", public=args_dict["public"],
                         competition=args_dict["competition"])


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument('jenkins_id', type=int,
help='The id of the current jenkins run')
parser.add_argument('public', type=bool, nargs='?', default=True,
help='Public (or private) submission?')
parser.add_argument('--competition', type=str, nargs='?', default=None,
help='Name of competition for which submission is being scored')
parser.add_argument('--user_id', type=int, nargs='?', default=None,
help='ID of submitting user in the postgres DB')
parser.add_argument('--author_email', type=str, nargs='?', default=None,
help='email associated with PR author GitHub username')
parser.add_argument('--specified_only', type=bool, nargs='?', default=False,
help='Only score the plugins specified by new_models and new_benchmarks')
parser.add_argument('--new_models', type=str, nargs='*', default=None,
help='The identifiers of newly submitted models to score on all benchmarks')
parser.add_argument('--new_benchmarks', type=str, nargs='*', default=None,
help='The identifiers of newly submitted benchmarks on which to score all models')
args, remaining_args = parser.parse_known_args()

return args
def send_email_to_submitter(uid: int, domain: str, pr_number: str,
                            mail_username: str, mail_password: str):
    """
    Send submitter an email if their web-submitted PR fails.

    Thin wrapper around the core implementation that supplies this domain's
    database secret.

    :param uid: ID of the submitting user in the postgres DB
    :param domain: repository domain name, e.g. 'language'
    :param pr_number: number of the failing pull request
    :param mail_username: SMTP username used to send the notification
    :param mail_password: SMTP password used to send the notification
    """
    # Return the core call's result to stay backward-compatible with the
    # previous implementation, which returned the send result.
    return send_email_to_submitter_core(uid=uid, domain=domain, pr_number=pr_number,
                                        db_secret=config.get_database_secret(),
                                        mail_username=mail_username, mail_password=mail_password)


if __name__ == '__main__':
    # Argument definitions live in brainscore_core; unknown args are tolerated
    # so Jenkins can pass extra flags without breaking this entry point.
    parser = make_argparser()
    args, remaining_args = parser.parse_known_args()
    args_dict = vars(args)

    # Web submissions provide user_id directly; GitHub-PR submissions only
    # provide the author's email, so resolve it to a DB user id here.
    if args_dict.get('user_id') is None:
        user_id = get_user_id(args_dict['author_email'], db_secret=config.get_database_secret())
        args_dict['user_id'] = user_id

    run_scoring(args_dict)

0 comments on commit 87a2db4

Please sign in to comment.