Skip to content

Commit

Permalink
Switch Linting to ruff (EleutherAI#1166)
Browse files Browse the repository at this point in the history
* add ruff and isort. remove black and flake8

* remove unnecessary dependencies

* remove dependency from table

* change order

* ran ruff

* check 3.9

* exclude evaluator

* update CI workflow

* use ruff config in pyproject.toml

* test

* add isort rules to ruff

* sort imports

* import `make_table`

* try stages for no-commit-to-branch

* turn on mypy for pre-commit

* test

* test

* test

* change no-commit-to-branch to default

* nits

* fixed dependency
  • Loading branch information
baberabb authored Dec 20, 2023
1 parent 21d4ae9 commit 65b8761
Show file tree
Hide file tree
Showing 84 changed files with 389 additions and 446 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/new_tasks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
run: |
python -m pip install --upgrade pip
pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[dev]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
Expand Down
29 changes: 11 additions & 18 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,22 @@ jobs:
linter:
name: Linters
runs-on: ubuntu-latest
timeout-minutes: 20
timeout-minutes: 5

steps:
- name: Checkout Code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.8
cache: pip
cache-dependency-path: setup.py
- name: Install dependencies
run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu ; export SKIP=no-commit-to-branch # env var deactivates --no-commit-to-branch
cache-dependency-path: pyproject.toml
- name: Pre-Commit
env:
SKIP: "no-commit-to-branch,mypy"

uses: pre-commit/[email protected]
- name: Lint with pylint
run: python -m pylint --disable=all -e W0311 --jobs=0 --indent-string=' ' **/*.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# # mypy turned off for now
# - name: Lint with mypy
# run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
Expand All @@ -53,17 +46,17 @@ jobs:
timeout-minutes: 30
steps:
- name: Checkout Code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: setup.py
cache-dependency-path: pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[dev,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
Expand Down
16 changes: 9 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ repos:
args: [--remove]
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://github.com/pycqa/flake8
rev: 3.7.9
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.8
hooks:
- id: flake8
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
# Run the linter.
- id: ruff
args:
- --fix
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
Expand Down
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ pip install -e .
We also provide a number of optional dependencies for extended functionality. Extras can be installed via `pip install -e ".[NAME]"`

| Name | Use |
| ------------- | ------------------------------------- |
|---------------|---------------------------------------|
| anthropic | For using Anthropic's models |
| dev | You probably don't want to use this |
| gptq | For loading models with GPTQ |
| testing | You probably don't want to use this |
| dev | You probably don't want to use this |
| multilingual | For multilingual tokenizers |
| openai | For using OpenAI's models |
| promptsource | For using PromptSource prompts |
Expand Down
19 changes: 10 additions & 9 deletions lm_eval/__main__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import argparse
import json
import logging
import os
import re
import sys
import json
import logging
import argparse
import numpy as np

from pathlib import Path
from typing import Union

import numpy as np

from lm_eval import evaluator, utils
from lm_eval.tasks import initialize_tasks, include_path
from lm_eval.api.registry import ALL_TASKS
from lm_eval.tasks import include_path, initialize_tasks
from lm_eval.utils import make_table


def _handle_non_serializable(o):
Expand Down Expand Up @@ -170,7 +171,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
task_names = ALL_TASKS
elif args.tasks == "list":
eval_logger.info(
"Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS)))
"Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS)))
)
sys.exit()
else:
Expand Down Expand Up @@ -271,9 +272,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(evaluator.make_table(results))
print(make_table(results))
if "groups" in results:
print(evaluator.make_table(results, "groups"))
print(make_table(results, "groups"))


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions lm_eval/api/filter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import List

from lm_eval.api.instance import Instance
from datasets import Dataset

from lm_eval.api.instance import Instance


class Filter:
"""
Expand Down Expand Up @@ -42,7 +43,6 @@ class FilterEnsemble:
filters: List[Filter]

def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:

resps = [
inst.resps for inst in instances
] # operate just on the model responses
Expand Down
9 changes: 5 additions & 4 deletions lm_eval/api/metrics.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import logging
import math
import random
from collections.abc import Iterable

import evaluate
import numpy as np
import sacrebleu
import sklearn.metrics
import random
import evaluate

from lm_eval.api.registry import register_metric, register_aggregation
from lm_eval.api.registry import register_aggregation, register_metric

import logging

eval_logger = logging.getLogger("lm-eval")


# Register Aggregations First
@register_aggregation("mean")
def mean(arr):
Expand Down
10 changes: 4 additions & 6 deletions lm_eval/api/model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import abc
import hashlib
import json
import logging
import os
from typing import List, Optional, Tuple, Type, TypeVar

import torch
from typing import Union, List, Tuple, Optional, Type, TypeVar
from sqlitedict import SqliteDict
import json
import hashlib

from tqdm import tqdm

from lm_eval import utils

import logging

eval_logger = logging.getLogger("lm-eval")

Expand Down
10 changes: 3 additions & 7 deletions lm_eval/api/registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import logging

import evaluate

from lm_eval.api.model import LM

import logging

eval_logger = logging.getLogger("lm-eval")

Expand Down Expand Up @@ -91,7 +92,6 @@ def decorate(fn):
def register_metric(**args):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):

assert "metric" in args
name = args["metric"]

Expand All @@ -100,7 +100,6 @@ def decorate(fn):
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY),
]:

if key in args:
value = args[key]
assert (
Expand All @@ -120,7 +119,6 @@ def decorate(fn):


def get_metric(name, hf_evaluate_metric=False):

if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
Expand Down Expand Up @@ -151,7 +149,6 @@ def decorate(fn):


def get_aggregation(name):

try:
return AGGREGATION_REGISTRY[name]
except KeyError:
Expand All @@ -161,7 +158,6 @@ def get_aggregation(name):


def get_metric_aggregation(name):

try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
Expand Down
10 changes: 5 additions & 5 deletions lm_eval/api/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def get_context(self, doc, num_fewshot):
self.doc_to_text(doc)
if (
self.config.doc_to_choice is None
or type(self.doc_to_text(doc)) is str
or isinstance(self.doc_to_text(doc), str)
)
else self.doc_to_choice(doc)[self.doc_to_text(doc)]
)
+ self.target_delimiter
+ (
str(self.doc_to_target(doc)[0])
if type(self.doc_to_target(doc)) is list
if isinstance(self.doc_to_target(doc), list)
else self.doc_to_target(doc)
if (
self.config.doc_to_choice is None
or type(self.doc_to_target(doc)) is str
or isinstance(self.doc_to_target(doc), str)
)
else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
)
Expand All @@ -77,8 +77,8 @@ def sample(self, n) -> None:
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
"""
assert n <= len(
self.docs
assert (
n <= len(self.docs)
), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
return self.docs[:n]

Expand Down
Loading

0 comments on commit 65b8761

Please sign in to comment.