
Commit

Apply some best practices and guideline recommendations to code (EleutherAI#1363)

* raise Exception, not a string

Additional info https://peps.python.org/pep-0352/#exception-hierarchy-changes
https://docs.python.org/3.8/tutorial/errors.html#raising-exceptions
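
The pattern, as applied in lm_eval/prompts/__init__.py:

    # Before: raising a string fails at runtime in Python 3 with
    # "TypeError: exceptions must derive from BaseException"
    #     raise "Not yet implemented to accept doc_to_choice"
    # After: raise an actual Exception instance
    raise Exception("Not yet implemented to accept doc_to_choice")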

* Apply PEP8 recommendation to prefer isinstance

"Object type comparisons should always use isinstance() instead of comparing types directly"
https://peps.python.org/pep-0008/
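
A minimal sketch of the pattern used in lm_eval/evaluator.py (the example value is illustrative only):

    task_obj = ("group_name", None)  # illustrative value

    # Before: direct type comparison, which also rejects subclasses of tuple
    #     if type(task_obj) == tuple:
    # After: PEP 8 style
    if isinstance(task_obj, tuple):
        group, task_obj = task_obj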

* Remove dangerous mutable default values in arguments

https://pylint.readthedocs.io/en/stable/user_guide/messages/warning/dangerous-default-value.html
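
A sketch of the fix, mirroring add_data in lm_eval/decontamination/archiver.py (body trimmed):

    # Before: the default dict is created once at definition time and
    # shared between all calls that rely on it
    #     def add_data(self, data, meta={}) -> None:
    # After: use None as a sentinel and create a fresh dict per call
    def add_data(self, data, meta=None) -> None:
        if meta is None:
            meta = {}
        ...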

* Format logging messages with f-strings (not with str.format)

Additional info
https://pylint.readthedocs.io/en/stable/user_guide/messages/warning/logging-format-interpolation.html
There are also discussions about the performance of formatting inside logging calls and about possible unintended code execution:
pylint-dev/pylint#2395
https://stackoverflow.com/a/54368109
but at least one consistent format (the f-string one) will be used throughout the project.
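
A self-contained sketch of the pattern applied in lm_eval/evaluator.py (logger setup and value shown only for illustration):

    import logging

    eval_logger = logging.getLogger(__name__)
    reqtype = "loglikelihood"  # illustrative value

    # Before: str.format inside the logging call
    #     eval_logger.info("Running {} requests".format(reqtype))
    # After: one consistent f-string style
    eval_logger.info(f"Running {reqtype} requests")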

* Specify utf-8 encoding for `open` explicitly

If not specified, the encoding may be assumed differently in different environments, operating systems, and Python versions. See
https://peps.python.org/pep-0597/
https://docs.python.org/3.11/library/locale.html#locale.getencoding
https://docs.python.org/3.10/library/os.html#utf8-mode
https://pylint.readthedocs.io/en/stable/user_guide/messages/warning/unspecified-encoding.html

This also helps when code from English-language tasks is taken as inspiration for tasks in non-English languages.
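
The pattern, mirroring lm_eval/decontamination/decontaminate.py (a literal file name stands in for the path variable):

    import json

    # Before: the assumed encoding depends on the locale of the machine running the code
    #     info_dict = json.load(open("info.json", "r"))
    # After: explicit, portable encoding
    info_dict = json.load(open("info.json", "r", encoding="utf-8"))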

* Use inline ignore comments to pass pre-commit instead of an identity assignment

https://flake8.pycqa.org/en/3.0.1/user/ignoring-errors.html#in-line-ignoring-errors
https://www.flake8rules.com/rules/F841.html

flake8 comments are supported by ruff: https://docs.astral.sh/ruff/linter/#error-suppression
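
The pattern, excerpted from lm_eval/models/huggingface.py (surrounding context elided, so the snippet is not standalone):

    # Before: a dummy identity assignment silenced the unused-variable warning
    #     out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)
    #     out = out  # Identity process so that it passes pre-commit
    # After: suppress exactly that rule (F841) inline
    out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841
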
LSinev authored Jan 28, 2024
1 parent 97a67d2 commit 488759d
Showing 30 changed files with 83 additions and 74 deletions.
4 changes: 2 additions & 2 deletions lm_eval/__main__.py
@@ -171,7 +171,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
task_names = ALL_TASKS
elif args.tasks == "list":
eval_logger.info(
"Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS)))
f"Available Tasks:\n - {(os.linesep + ' - ').join(sorted(ALL_TASKS))}"
)
sys.exit()
else:
@@ -257,7 +257,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

if args.output_path:
output_path_file.open("w").write(dumped)
output_path_file.open("w", encoding="utf-8").write(dumped)

if args.log_samples:
for task_name, config in results["configs"].items():
8 changes: 2 additions & 6 deletions lm_eval/api/registry.py
@@ -152,18 +152,14 @@ def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
"{} not a registered aggregation metric!".format(name),
)
eval_logger.warning(f"{name} not a registered aggregation metric!")


def get_metric_aggregation(name):
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
"{} metric is not assigned a default aggregation!".format(name),
)
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


def is_higher_better(metric_name):
6 changes: 4 additions & 2 deletions lm_eval/decontamination/archiver.py
@@ -30,7 +30,9 @@ def __init__(self, file_path: str, compression_level: int = 3) -> None:
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)

def add_data(self, data, meta={}) -> None:
def add_data(self, data, meta=None) -> None:
if meta is None:
meta = {}
self.compressor.write(
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
"UTF-8"
@@ -108,7 +110,7 @@ def __init__(self, file_path) -> None:
def read_tqdm(self, update_frequency: int = 10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, "r") as fh, tqdm.tqdm(
with open(self.file_path, "r", encoding="utf-8") as fh, tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
unit="byte",
2 changes: 1 addition & 1 deletion lm_eval/decontamination/decontaminate.py
@@ -38,7 +38,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

info_dict_path = os.path.join(ngrams_path, "info.json")
info_dict = json.load(open(info_dict_path, "r"))
info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
ngrams_n_size = info_dict["ngram_size"]

janitor = Janitor()
18 changes: 10 additions & 8 deletions lm_eval/evaluator.py
@@ -25,7 +25,7 @@
def simple_evaluate(
model,
model_args=None,
tasks=[],
tasks=None,
num_fewshot=None,
batch_size=None,
max_batch_size=None,
@@ -80,6 +80,8 @@ def simple_evaluate(
1234
) # TODO: this may affect training runs that are run with evaluation mid-run.

if tasks is None:
tasks = []
assert (
tasks != []
), "No tasks specified, or no tasks found. Please verify the task names."
@@ -122,7 +124,7 @@ def simple_evaluate(
task_dict = lm_eval.tasks.get_task_dict(tasks)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
if isinstance(task_obj, tuple):
group, task_obj = task_obj
if task_obj is None:
continue
@@ -242,7 +244,7 @@ def evaluate(

# get lists of each type of request
for task_name, task in task_dict.items():
if type(task) == tuple:
if isinstance(task, tuple):
group_name, task = task
task_hierarchy[group_name].append(task_name)
versions[group_name] = "N/A"
@@ -316,7 +318,7 @@ def evaluate(
### Run LM on inputs, get all outputs ###
# execute each type of request
for reqtype, reqs in requests.items():
eval_logger.info("Running {} requests".format(reqtype))
eval_logger.info(f"Running {reqtype} requests")
# create `K` copies of each request `req` based off `K = req.repeats`
cloned_reqs = []
for req in reqs:
@@ -339,7 +341,7 @@ def evaluate(
### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_name, task in task_dict.items():
if type(task) == tuple:
if isinstance(task, tuple):
group, task = task
if task is None:
continue
@@ -350,7 +352,7 @@ def evaluate(

# unpack results and sort back in order and return control to Task
for task_name, task in task_dict.items():
if type(task) == tuple:
if isinstance(task, tuple):
group, task = task
if task is None:
continue
@@ -401,7 +403,7 @@ def evaluate(
vals_torch = collections.defaultdict(list)
for (task_name, key, metric), items in vals.items():
numitem = 0
if type(items[0]) == tuple:
if isinstance(items[0], tuple):
numitem = len(items[0])

if isinstance(items[0], (str, list, tuple)):
@@ -447,7 +449,7 @@ def evaluate(
task = task_dict[task_name]
metric_key = metric + "," + key

if type(task) == tuple:
if isinstance(task, tuple):
group_name, task = task
else:
group_name = None
4 changes: 3 additions & 1 deletion lm_eval/filters/transformation.py
@@ -24,7 +24,7 @@ def filter_set(inst)


class MapFilter(Filter):
def __init__(self, mapping_dict: dict = {}, default_value=None) -> None:
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
"""
Initializes the MapFilter with a given mapping dictionary and default value.
@@ -37,6 +37,8 @@ def __init__(self, mapping_dict: dict = {}, default_value=None) -> None:
Example:
mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
"""
if mapping_dict is None:
mapping_dict = {}
assert isinstance(
mapping_dict, dict
), "Provided mapping_dict is not a dictionary"
3 changes: 1 addition & 2 deletions lm_eval/models/huggingface.py
@@ -603,8 +603,7 @@ def forward_batch(batch_size):
(batch_size, max_length), device=self.device
).long()
for _ in range(5):
out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)
out = out # Identity process so that it passes pre-commit
out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) # noqa: F841

return batch_size

2 changes: 1 addition & 1 deletion lm_eval/prompts/__init__.py
@@ -117,7 +117,7 @@ def apply(self, doc):

# TODO need a way to process doc_to_choice
if "doc_to_choice" in self.prompt_string:
raise "Not yet implemented to accept doc_to_choice"
raise Exception("Not yet implemented to accept doc_to_choice")

text_string = utils.apply_template(doc_to_text, doc)
target_string = utils.apply_template(doc_to_target, doc)
14 changes: 7 additions & 7 deletions lm_eval/tasks/__init__.py
@@ -43,7 +43,7 @@ def register_configurable_task(config: Dict[str, str]) -> int:
if "group" in config:
if config["group"] == config["task"]:
raise ValueError("task and group name cannot be the same")
elif type(config["group"]) == str:
elif isinstance(config["group"], str):
group_name = [config["group"]]
else:
group_name = config["group"]
@@ -57,8 +57,8 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
group = config["group"]
all_task_list = config["task"]
config_list = [task for task in all_task_list if type(task) != str]
task_list = [task for task in all_task_list if type(task) == str]
config_list = [task for task in all_task_list if not isinstance(task, str)]
task_list = [task for task in all_task_list if isinstance(task, str)]

for task_config in config_list:

@@ -68,7 +68,7 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
task_name = task_config["task"]
if task_name in ALL_TASKS:
task_obj = TASK_REGISTRY[task_name]
if type(task_obj) == tuple:
if isinstance(task_obj, tuple):
_, task_obj = task_obj

if task_obj is not None:
@@ -166,10 +166,10 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
)
for config in all_configs:
if register_task:
if type(config["task"]) == str:
if isinstance(config["task"], str):
register_configurable_task(config)
else:
if type(config["task"]) == list:
if isinstance(config["task"], list):
register_configurable_group(config, yaml_path)

# Log this silently and show it only when
Expand Down Expand Up @@ -243,7 +243,7 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
task_name_from_config_dict = {}
task_name_from_object_dict = {}

if type(task_name_list) != list:
if not isinstance(task_name_list, list):
task_name_list = [task_name_list]

for task_element in task_name_list:
4 changes: 2 additions & 2 deletions lm_eval/tasks/bbh/_generate_configs.py
@@ -28,7 +28,7 @@ def parse_args():

# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
with open(args.base_yaml_path, encoding="utf-8") as f:
base_yaml = yaml.full_load(f)

base_doc_to_text = "Q: {{input}}\nA:"
@@ -70,7 +70,7 @@ def parse_args():

file_save_path = args.save_prefix_path + f"/{task}.yaml"
utils.eval_logger.info(f"Saving yaml for subset {task} to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
with open(file_save_path, "w", encoding="utf-8") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
6 changes: 3 additions & 3 deletions lm_eval/tasks/belebele/_generate_configs.py
@@ -27,13 +27,13 @@ def parse_args():

# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
with open(args.base_yaml_path, encoding="utf-8") as f:
base_yaml = yaml.full_load(f)

if args.cot_prompt_path is not None:
import json

with open(args.cot_prompt_path) as f:
with open(args.cot_prompt_path, encoding="utf-8") as f:
cot_file = json.load(f)

def query():
@@ -54,7 +54,7 @@ def query():

file_save_path = args.save_prefix_path + f"_{lang}.yaml"
logging.info(f"Saving yaml for subset {lang} to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
with open(file_save_path, "w", encoding="utf-8") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
2 changes: 1 addition & 1 deletion lm_eval/tasks/bigbench/generate_tasks.py
@@ -181,7 +181,7 @@ def main() -> None:
for task in all_subtasks:
file_name = f"{task}.yaml"
try:
with open(f"{path}/{file_name}", "w") as f:
with open(f"{path}/{file_name}", "w", encoding="utf-8") as f:
f.write("# Generated by utils.py\n")
yaml.dump(
{
2 changes: 1 addition & 1 deletion lm_eval/tasks/blimp/generate_configs.py
@@ -75,7 +75,7 @@ def main() -> None:
for task in all_subtasks:
file_name = f"{task}.yaml"
try:
with open(f"{file_name}", "w") as f:
with open(f"{file_name}", "w", encoding="utf-8") as f:
f.write("# Generated by utils.py\n")
yaml.dump(
{
6 changes: 3 additions & 3 deletions lm_eval/tasks/ceval/_generate_configs.py
@@ -79,13 +79,13 @@ def parse_args():

# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
with open(args.base_yaml_path, encoding="utf-8") as f:
base_yaml = yaml.full_load(f)

if args.cot_prompt_path is not None:
import json

with open(args.cot_prompt_path) as f:
with open(args.cot_prompt_path, encoding="utf-8") as f:
cot_file = json.load(f)

for subject_eng, subject_zh in tqdm(SUBJECTS.items()):
@@ -107,7 +107,7 @@ def parse_args():

file_save_path = args.save_prefix_path + f"_{subject_eng}.yaml"
eval_logger.info(f"Saving yaml for subset {subject_eng} to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
with open(file_save_path, "w", encoding="utf-8") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
6 changes: 3 additions & 3 deletions lm_eval/tasks/cmmlu/_generate_configs.py
@@ -94,13 +94,13 @@ def parse_args():

# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
with open(args.base_yaml_path, encoding="utf-8") as f:
base_yaml = yaml.full_load(f)

if args.cot_prompt_path is not None:
import json

with open(args.cot_prompt_path) as f:
with open(args.cot_prompt_path, encoding="utf-8") as f:
cot_file = json.load(f)

for subject_eng, subject_zh in tqdm(SUBJECTS.items()):
@@ -122,7 +122,7 @@ def parse_args():

file_save_path = args.save_prefix_path + f"_{subject_eng}.yaml"
eval_logger.info(f"Saving yaml for subset {subject_eng} to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
with open(file_save_path, "w", encoding="utf-8") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
2 changes: 1 addition & 1 deletion lm_eval/tasks/code_x_glue/code-text/bleu.py
@@ -184,7 +184,7 @@ def splitPuncts(line):
def computeMaps(predictions, goldfile):
predictionMap: Dict[str, list] = {}
goldMap: Dict[str, list] = {}
gf = open(goldfile, "r")
gf = open(goldfile, "r", encoding="utf-8")

for row in predictions:
cols = row.strip().split("\t")
4 changes: 2 additions & 2 deletions lm_eval/tasks/csatqa/_generate_configs.py
@@ -25,7 +25,7 @@ def parse_args():

# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
with open(args.base_yaml_path, encoding="utf-8") as f:
base_yaml = yaml.full_load(f)

for name in tqdm(SUBSETS):
@@ -39,7 +39,7 @@ def parse_args():

file_save_path = args.save_prefix_path + f"_{name.lower()}.yaml"
eval_logger.info(f"Saving yaml for subset {name} to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
with open(file_save_path, "w", encoding="utf-8") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
0 comments on commit 488759d