diff --git a/.flake8 b/.flake8 index bfc43808..46cc5d86 100644 --- a/.flake8 +++ b/.flake8 @@ -1,8 +1,7 @@ [flake8] ignore = E501 W504 E731 per-file-ignores = - iw3/cli.py: E402 - iw3/gui.py: E402 + iw3/__init__.py: E402 waifu2x/__init__.py: E402 max_line_length = 128 diff --git a/iw3/__init__.py b/iw3/__init__.py new file mode 100644 index 00000000..e62d23d2 --- /dev/null +++ b/iw3/__init__.py @@ -0,0 +1,2 @@ +import os +os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' diff --git a/iw3/cli.py b/iw3/cli.py index 08d6cbcd..3a3ed069 100644 --- a/iw3/cli.py +++ b/iw3/cli.py @@ -1,5 +1,3 @@ -import os -os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import torch from .utils import create_parser, set_state_args, iw3_main from . import models # noqa diff --git a/iw3/gui.py b/iw3/gui.py index cbcecd95..1914dcf8 100644 --- a/iw3/gui.py +++ b/iw3/gui.py @@ -1,5 +1,3 @@ -import os -os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import nunif.pythonw_fix # noqa import locale import sys