Skip to content

Commit

Permalink
chore(client): enable flake8 bugbear for python code (#2897)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut authored Oct 25, 2023
1 parent 0728ec6 commit 8e1edb8
Show file tree
Hide file tree
Showing 36 changed files with 85 additions and 72 deletions.
3 changes: 3 additions & 0 deletions client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"B", # flake8-bugbear
]
ignore = [
"E501", # line too long, handled by black
Expand All @@ -68,6 +69,7 @@ ignore = [
"E731", # Do not assign a `lambda` expression, use a `def`
"E741", # ambiguous variable name
"W605", # invalid escape sequence '\#'
"B018", # Found useless expression.
]
exclude = [
".eggs",
Expand Down Expand Up @@ -99,6 +101,7 @@ exclude = [
"starwhale/base/data_type.py" = ["E721"]
"starwhale/core/dataset/store.py" = ["E721"]
"tests/sdk/test_dataset_sdk.py" = ["E721"]
"starwhale/api/_impl/job/handler.py" = ["B009"]

#TODO: replace isort with ruff.isort, currently ruff doesn't support all isort options
[tool.isort]
Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,7 @@ def decompress(self, source: Path) -> Iterator[Path]:

# get all the compressors in this module
compressors: Dict[str, Compressor] = {}
for name, obj in inspect.getmembers(sys.modules[__name__]):
for _, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and issubclass(obj, Compressor) and obj != Compressor:
compressors[obj.__name__] = obj()

Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/api/_impl/dataset/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def _get_data_loader(

return _loader

@lru_cache(maxsize=32)
@lru_cache(maxsize=32) # noqa: B019
def _get_datastore_revision(self, uri: Resource) -> DatastoreRevision:
if uri.typ != ResourceType.dataset:
raise NoSupportError(
Expand Down
4 changes: 2 additions & 2 deletions client/starwhale/api/_impl/job/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def wrapper(*args: t.Any, **kwargs: t.Any) -> None:
if (
idx == 0
and args
and hasattr(getattr(args[0], func.__name__), "__call__")
and callable(getattr(args[0], func.__name__))
): # if the first argument has a function with the same name it is considered as self
continue
required = _p.default is inspect._empty or (
Expand All @@ -240,7 +240,7 @@ def wrapper(*args: t.Any, **kwargs: t.Any) -> None:
if (
idx == 0
and args
and hasattr(getattr(args[0], func.__name__), "__call__")
and callable(getattr(args[0], func.__name__))
):
continue
parsed_args = {
Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/api/_impl/track/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def _log_artifacts(
return

_data = flatten(data, extract_sequence=True)
for k, v in _data.items():
for v in _data.values():
# TODO: support more artifacts type
if not isinstance(v, BaseArtifact):
raise NoSupportError(f"v({v}:{type(v)}) not support for artifacts")
Expand Down
6 changes: 4 additions & 2 deletions client/starwhale/base/blob/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import shutil
import typing as t
import fnmatch
from abc import ABC
from abc import ABC, abstractmethod
from pathlib import Path

from starwhale.utils import console, pretty_bytes
Expand All @@ -18,9 +18,11 @@


class ObjectStore(ABC):
@abstractmethod
def get(self, key: str) -> t.Any:
raise NotImplementedError

@abstractmethod
def put(self, path: PathLike, key: t.Optional[str]) -> t.Any:
raise NotImplementedError

Expand Down Expand Up @@ -114,7 +116,7 @@ def _search_all_py_envs(cls, path: PathLike) -> t.List[Path]:
Returns: a list of virtual environments or conda environments
"""
envs = []
for root, dirs, files in os.walk(path):
for root, _, _ in os.walk(path):
if cls._check_if_is_py_env(root):
envs.append(Path(root))
return envs
Expand Down
4 changes: 2 additions & 2 deletions client/starwhale/base/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def __eq__(self, other: t.Any) -> bool:
def get_runtime_context(cls) -> Context:
try:
val: Context = cls._context_holder.value # type: ignore
except AttributeError:
except AttributeError as e:
raise RuntimeError(
"Starwhale does not set Context yet, please check if the get_runtime_context function is used at the right time."
)
) from e

if not isinstance(val, Context):
raise RuntimeError(
Expand Down
8 changes: 4 additions & 4 deletions client/starwhale/base/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,10 @@ def _do_validate(self) -> None:
def to_pil(self) -> t.Any:
try:
from PIL import Image as PILImage
except ImportError: # pragma: no cover
except ImportError as e: # pragma: no cover
raise MissingDependencyError(
"pillow is required to convert Starwhale Image to Pillow Image, please install pillow with 'pip install pillow' or 'pip install starwhale[image]'."
)
) from e

return PILImage.open(io.BytesIO(self.to_bytes()))

Expand Down Expand Up @@ -369,10 +369,10 @@ def _do_validate(self) -> None:
def to_numpy(self) -> numpy.ndarray:
try:
import soundfile
except ImportError: # pragma: no cover
except ImportError as e: # pragma: no cover
raise MissingDependencyError(
"soundfile is required to convert Starwhale Auto to numpy ndarray, please install soundfile with 'pip install soundfile' or 'pip install starwhale[audio]'."
)
) from e

array, _ = soundfile.read(
io.BytesIO(self.to_bytes()), dtype=self.dtype.name, always_2d=True
Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/base/scheduler/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,6 @@ def execute(self) -> TaskResult:
f"finish {self.context}, status:{self.status}, error:{self.exception}"
)
loop.close()
return TaskResult(
return TaskResult( # noqa: B012
id=self.index, status=self.status, exception=self.exception
)
4 changes: 2 additions & 2 deletions client/starwhale/cli/assistance/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ async def write_command_file(
count = f.write(offset, data, 1)
return {"count": count}
except (ClosedError, InvalidOffsetError) as e:
raise HTTPException(400, str(e))
raise HTTPException(400, str(e)) from e


@app.post("/session/{session_id}/command/{command_id}/file/{filename}/closeRead")
Expand Down Expand Up @@ -267,7 +267,7 @@ def read_command_file(
try:
data = f.read(offset, 1)
except (ClosedError, InvalidOffsetError) as e:
raise HTTPException(400, str(e))
raise HTTPException(400, str(e)) from e
if data is None:
return Response(
headers={"x-wait-for-data": ""},
Expand Down
10 changes: 7 additions & 3 deletions client/starwhale/cli/assistance/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,16 @@ def stop(self) -> None:
self.stopped = True

def run_until_success(
self, f: Callable[[], Any], ignore_stopped: bool = False
self,
f: Callable,
ignore_stopped: bool = False,
*f_args: Any,
**f_kwargs: Any,
) -> Any:
backoff = 0.1
while not self.stopped or ignore_stopped:
try:
return f()
return f(*f_args, **f_kwargs)
except UnrecoverableError:
raise
except Exception:
Expand Down Expand Up @@ -314,7 +318,7 @@ def run(self) -> None:
return
d = data
while not self.stopped and len(d) > 0:
count = self.run_until_success(lambda: self._write_to_broker(d))
count = self.run_until_success(self._write_to_broker, False, d)
self.offset += count
d = d[count:]
except Exception as e:
Expand Down
10 changes: 5 additions & 5 deletions client/starwhale/cli/board/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ def reload(self) -> None:
data = self._orderby.sort(self.data)
for idx, item in enumerate(data):

def try_render(col: Column) -> t.Any:
if col.render:
return col.render(idx, item)
return get_field(item, col.key)
def try_render(col_: Column, idx_: int, item_: t.Any) -> t.Any:
if col_.render:
return col_.render(idx_, item_)
return get_field(item_, col_.key)

self.table.add_row(*[try_render(i) for i in self.render_fn])
self.table.add_row(*[try_render(i, idx, item) for i in self.render_fn])

self.highlight_row(self.cursor_line)
self.refresh()
Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/core/dataset/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def _make_file(
_BFType = t.TypeVar("_BFType", bound="BaseBufferedFileLike")


class BaseBufferedFileLike(metaclass=ABCMeta):
class BaseBufferedFileLike:
def __init__(self, buffer_size: int) -> None:
self._read_buffer = BytesBuffer(buffer_size)

Expand Down
6 changes: 4 additions & 2 deletions client/starwhale/core/model/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ async def _send_request(
params: t.Dict[str, t.Any] | None = None,
content: str | bytes | None = None,
json: t.Any | None = None,
headers: t.Dict[str, str] = {},
headers: t.Dict[str, str] | None = None,
) -> httpx.Response | None:
headers = headers or {}
resp = await _httpx_client.get().request(
method,
url,
Expand Down Expand Up @@ -120,9 +121,10 @@ async def _http_request(
params: t.Dict[str, t.Any] | None = None,
content: str | bytes | None = None,
json: t.Any | None = None,
headers: t.Dict[str, str] = {},
headers: t.Dict[str, str] | None = None,
replace: bool = True,
) -> httpx.Response:
headers = headers or {}
instance = _instance.get()
headers = dict(headers)
if url is None:
Expand Down
4 changes: 3 additions & 1 deletion client/starwhale/integrations/huggingface/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
try:
import datasets as hf_datasets
except ImportError: # pragma: no cover
raise ImportError("Please install huggingface/datasets with `pip install datasets`")
raise ImportError(
"Please install huggingface/datasets with `pip install datasets`"
) from None

from starwhale.utils import console
from starwhale.base.data_type import Audio, Image, MIMEType
Expand Down
7 changes: 5 additions & 2 deletions client/starwhale/utils/http.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing as t
from http import HTTPStatus
from functools import wraps
Expand All @@ -13,8 +15,9 @@ def wrap_sw_error_resp(
header: str,
use_raise: bool = False,
silent: bool = False,
ignore_status_codes: t.List[int] = [],
ignore_status_codes: t.List[int] | None = None,
) -> None:
ignore_status_codes = ignore_status_codes or []
if silent:
_print: t.Callable = lambda x: x
else:
Expand All @@ -35,7 +38,7 @@ def wrap_sw_error_resp(
msg += f":dragon: error message: {_resp['message']}"
finally:
if r.status_code in ignore_status_codes:
return
return # noqa: B012

_print(Panel.fit(msg, title=":space_invader: error details")) # type: ignore

Expand Down
4 changes: 2 additions & 2 deletions client/tests/base/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_runtime_copy_c2l(self, rm: Mocker, *args: t.Any) -> None:
st = StandaloneTag(Resource("pytorch", typ=ResourceType.runtime))
assert set(st.list()) == {"latest", "t1", "t2", "t3", "v0"}

with self.assertRaises(Exception):
with self.assertRaisesRegex(KeyError, "local"):
BundleCopy(
src_uri=cloud_uri,
dest_uri="local/project/self/pytorch-new-alias",
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def test_dataset_copy_c2l(self, rm: Mocker, *args: MagicMock) -> None:
assert swds_manifest_path.exists()
assert swds_manifest_path.is_file()

with self.assertRaises(Exception):
with self.assertRaisesRegex(KeyError, "local"):
DatasetCopy(
src_uri=cloud_uri,
dest_uri="local/project/self/mnist-new-alias",
Expand Down
2 changes: 1 addition & 1 deletion client/tests/base/uri/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_instance(self, load_conf: MagicMock) -> None:
Instance(uri="https://foo.com")

# use both alias and url
with self.assertRaises(Exception):
with self.assertRaisesRegex(Exception, "alias and uri can not both set"):
Instance(instance_alias="foo", uri="https://bar.com")

# path
Expand Down
2 changes: 1 addition & 1 deletion client/tests/base/uri/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_parse_from_full_uri(self, rm: Mocker, load_conf: MagicMock) -> None:
"foo/project/myproject/dataset/mnist": "myproject",
}

for uri, project in tests.items():
for uri, _ in tests.items():
p = Project.parse_from_full_uri(uri, ignore_rc_type=False)
assert p.id == "1"
assert p.name == "myproject"
Expand Down
6 changes: 4 additions & 2 deletions client/tests/base/uri/test_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_resource_base(self) -> None:
assert r.version == "bar"
assert r.full_uri == "local/project/self/dataset/foo/version/bar"

with self.assertRaises(Exception):
with self.assertRaisesRegex(Exception, "invalid uri without version field"):
Resource(
uri="mnist/foo/bar",
typ=ResourceType.dataset,
Expand Down Expand Up @@ -244,7 +244,9 @@ def response_of_get(*args, **kwargs):
expect.version = "latest of the version"
assert expect == Resource(url)

with self.assertRaises(Exception):
with self.assertRaisesRegex(
NoMatchException, "Can not find the exact match item"
):
Resource("https://foo.com/projects/1/model") # model missing the tail 's'

@Mocker()
Expand Down
2 changes: 1 addition & 1 deletion client/tests/cli/assistance/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_garbage_collection(self) -> None:
},
resp.json(),
)
for i in range(5):
for _ in range(5):
self._read(0)
self._write(b"", 0)
time.sleep(0.5)
Expand Down
2 changes: 1 addition & 1 deletion client/tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_no_support_type_for_build_handler(self) -> None:
dataset_uri = Resource(name, typ=ResourceType.dataset)

def _iter_rows() -> t.Generator:
for i in range(0, 5):
for _ in range(0, 5):
yield {"a": {1: "a", b"b": "b"}}

sd = StandaloneDataset(dataset_uri)
Expand Down
2 changes: 1 addition & 1 deletion client/tests/sdk/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def test_standalone_tdsc(
assert tdsc._todo_queue.qsize() == 11

rt_tasks = []
for i in range(0, 11):
for _ in range(0, 11):
task = tdsc._todo_queue.get()
rt_tasks.append(task)
tdsc._todo_queue.put(task)
Expand Down
2 changes: 1 addition & 1 deletion client/tests/sdk/test_job_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ def report_handler(self): ...
"mock_user_module:MockReport.report_handler",
} == {r.name for r in results}
assert all([r.status == "success" for r in results])
assert {1, 1, 10, 1} == {len(r.task_results) for r in results}
assert {1, 10} == {len(r.task_results) for r in results}
for r in results:
if r.name == "mock_user_module:predict_handler":
assert {i for i in range(10)} == {t.id for t in r.task_results}
Expand Down
2 changes: 1 addition & 1 deletion client/tests/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_http_retry(self, request_mock: Mocker) -> None:

assert self._do_urllib_raise.retry.statistics["attempt_number"] == 6

with self.assertRaises(Exception):
with self.assertRaisesRegex(Exception, "dummy"):
self._do_raise()
assert self._do_raise.retry.statistics["attempt_number"] == 2
assert self._do_raise.retry.statistics["idle_for"] == 1.0
Expand Down
2 changes: 1 addition & 1 deletion client/tests/utils/test_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_extract_tar(self, m_open: MagicMock) -> None:
invalid_member,
valid_member,
]
with self.assertRaises(Exception):
with self.assertRaisesRegex(Exception, "Attempted path traversal in tar file"):
extract_tar(tar_path, target_dir, force=True)

m_open.reset_mock()
Expand Down
4 changes: 2 additions & 2 deletions example/LLM/llama/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def get_accelerate_model(
),
torch_dtype=torch_dtype,
)
setattr(model, "model_parallel", True)
setattr(model, "is_parallelizable", True)
model.model_parallel = True
model.is_parallelizable = True
model.config.torch_dtype = torch_dtype
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model.gradient_checkpointing_enable()
Expand Down
Loading

0 comments on commit 8e1edb8

Please sign in to comment.