Linting in test.
arokem committed Sep 11, 2024
1 parent 94d5077 commit 1a6bdb7
Showing 1 changed file with 7 additions and 7 deletions.
groupyr/_copt/tests/test_proximal_gradient.py (7 additions, 7 deletions)
@@ -4,7 +4,7 @@
 import pytest
 from scipy import optimize

-import groupyr._copt.loss as loss
+from groupyr._copt.loss import LogLoss, SquareLoss, HuberLoss
 from groupyr._copt.proximal_gradient import minimize_proximal_gradient

 np.random.seed(0)
@@ -24,7 +24,7 @@ def minimize_accelerated(*args, **kw):
     return minimize_proximal_gradient(*args, **kw)


-loss_funcs = [loss.LogLoss, loss.SquareLoss, loss.HuberLoss]
+loss_funcs = [LogLoss, SquareLoss, HuberLoss]


 def test_gradient():
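Note on the import change above: the commit message says only "Linting in test.", but a likely trigger is visible in the next hunk, where test_gradient reuses the name `loss` as its loop variable. With the old `import groupyr._copt.loss as loss`, that loop rebinds the module alias, which pyflakes reports as F402 ("import shadowed by loop variable"). A minimal sketch of the collision, with `math` standing in for the module:

    # `math` plays the role of the old module alias (hypothetical example).
    import math as loss

    loss_funcs = [loss.sqrt, loss.exp]
    for loss in loss_funcs:  # rebinds `loss`, discarding the module alias
        print(loss(1.0))

    # From here on, `loss` is math.exp, not a module; an attribute lookup
    # like `loss.sqrt` would raise AttributeError.

Importing LogLoss, SquareLoss, and HuberLoss directly removes the module alias, so the loop variable no longer shadows anything.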
@@ -33,9 +33,9 @@ def test_gradient():
     b = np.random.rand(10)
     for loss in loss_funcs:
         f_grad = loss(A, b).f_grad
-        f = lambda x: f_grad(x)[0]
-        grad = lambda x: f_grad(x)[1]
-        eps = optimize.check_grad(f, grad, np.random.randn(5))
+        eps = optimize.check_grad(
+            lambda x: f_grad(x)[0], lambda x: f_grad(x)[1], np.random.randn(5)
+        )
         assert eps < 0.001


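The second change folds the two assigned lambdas into the check_grad call. flake8 flags assignments like `f = lambda x: ...` as E731 ("do not assign a lambda expression, use a def"), and passing the lambdas inline avoids the assignment entirely. The def-based alternative that E731 suggests would look like this inside the loop body (a sketch of the road not taken, assuming the same f_grad from the hunk above):

    # def-based equivalent of the inlined lambdas (sketch only; assumes the
    # surrounding test_gradient scope, where f_grad is already defined).
    def f(x):
        return f_grad(x)[0]

    def grad(x):
        return f_grad(x)[1]

    eps = optimize.check_grad(f, grad, np.random.randn(5))

Inlining keeps the test shorter; the def form would be preferable if the wrappers were reused elsewhere.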
@@ -97,7 +97,7 @@ def test_callback(solver):
     def cb(_):
         return False

-    f = loss.SquareLoss(A, b)
+    f = SquareLoss(A, b)
     opt = solver(f.f_grad, np.zeros(n_features), callback=cb)
     assert opt.nit < 2

@@ -109,7 +109,7 @@ def test_line_search(solver):
     def ls_wrong(_):
         return -10

-    ls_loss = loss.SquareLoss(A, b)
+    ls_loss = SquareLoss(A, b)

     # define a function with unused arguments for the API
     def f_grad(x, r1, r2):
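Putting the hunks together, the post-commit test_gradient plausibly reads as below. The `A = ...` line falls outside the diff context, so its exact form is an assumption here, shaped to match b (10 samples) and the 5-dimensional point handed to check_grad:

    # Reconstructed post-commit test_gradient (the A = ... line is assumed).
    import numpy as np
    from scipy import optimize

    from groupyr._copt.loss import LogLoss, SquareLoss, HuberLoss

    loss_funcs = [LogLoss, SquareLoss, HuberLoss]


    def test_gradient():
        A = np.random.rand(10, 5)  # not shown in the diff; shape assumed
        b = np.random.rand(10)
        for loss in loss_funcs:
            f_grad = loss(A, b).f_grad
            # check_grad compares the analytic gradient to finite differences
            eps = optimize.check_grad(
                lambda x: f_grad(x)[0], lambda x: f_grad(x)[1], np.random.randn(5)
            )
            assert eps < 0.001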
