From e0a883123b239d5ecbc3a439ea37a24fe2db983e Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Fri, 5 Jul 2024 15:25:44 +0200 Subject: [PATCH] style: run pre-commit --- .github/workflows/warnings_tests.yaml | 2 +- src/linsolve/linsolve.py | 4 ++-- tests/test_linsolve.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/warnings_tests.yaml b/.github/workflows/warnings_tests.yaml index dec1e99..cf66876 100644 --- a/.github/workflows/warnings_tests.yaml +++ b/.github/workflows/warnings_tests.yaml @@ -20,4 +20,4 @@ jobs: - name: Run Tests run: | - pytest -W error \ No newline at end of file + pytest -W error diff --git a/src/linsolve/linsolve.py b/src/linsolve/linsolve.py index 89aa7cf..550c6c7 100644 --- a/src/linsolve/linsolve.py +++ b/src/linsolve/linsolve.py @@ -564,7 +564,7 @@ def _invert_solve(self, A, y, rcond): methods. """ # As of numpy 1.8, solve works on stacks of matrices - # Change in numpy 2.0: + # Change in numpy 2.0: # The b array is only treated as a shape (M,) column vector if it is # exactly 1-dimensional. In all other instances it is treated as a stack # of (M, K) matrices. Previously b would be treated as a stack of (M,) @@ -696,7 +696,7 @@ def eval(self, sol, keys=None): def _chisq(self, sol, data, wgts, evaluator): """Internal adaptable chisq calculator.""" if len(wgts) == 0: - sigma2 = {k: 1.0 for k in list(data.keys())} # equal weights + sigma2 = dict.fromkeys(data.keys(), value=1.0) # equal weights else: sigma2 = {k: wgts[k] ** -1 for k in list(wgts.keys())} evaluated = evaluator(sol, keys=data) diff --git a/tests/test_linsolve.py b/tests/test_linsolve.py index 64ae692..df14d97 100644 --- a/tests/test_linsolve.py +++ b/tests/test_linsolve.py @@ -490,7 +490,9 @@ def test_init(self): np.testing.assert_almost_equal(eval(k), 0.002) assert len(ls.ls.prms) == 3 - ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse, build_solver=False) + ls = linsolve.LinProductSolver( + d, sol0, w, sparse=self.sparse, build_solver=False + ) assert not hasattr(ls, "ls") assert ls.dtype == np.complex64