diff --git a/groupyr/_copt/tests/test_proximal_gradient.py b/groupyr/_copt/tests/test_proximal_gradient.py
index 0e8e8bc..ee38a24 100644
--- a/groupyr/_copt/tests/test_proximal_gradient.py
+++ b/groupyr/_copt/tests/test_proximal_gradient.py
@@ -4,7 +4,7 @@
 import pytest
 from scipy import optimize
 
-import groupyr._copt.loss as loss
+from groupyr._copt.loss import LogLoss, SquareLoss, HuberLoss
 from groupyr._copt.proximal_gradient import minimize_proximal_gradient
 
 np.random.seed(0)
@@ -24,7 +24,7 @@ def minimize_accelerated(*args, **kw):
     return minimize_proximal_gradient(*args, **kw)
 
 
-loss_funcs = [loss.LogLoss, loss.SquareLoss, loss.HuberLoss]
+loss_funcs = [LogLoss, SquareLoss, HuberLoss]
 
 
 def test_gradient():
@@ -33,9 +33,9 @@ def test_gradient():
         b = np.random.rand(10)
         for loss in loss_funcs:
             f_grad = loss(A, b).f_grad
-            f = lambda x: f_grad(x)[0]
-            grad = lambda x: f_grad(x)[1]
-            eps = optimize.check_grad(f, grad, np.random.randn(5))
+            eps = optimize.check_grad(
+                lambda x: f_grad(x)[0], lambda x: f_grad(x)[1], np.random.randn(5)
+            )
             assert eps < 0.001
 
 
@@ -97,7 +97,7 @@ def test_callback(solver):
     def cb(_):
         return False
 
-    f = loss.SquareLoss(A, b)
+    f = SquareLoss(A, b)
     opt = solver(f.f_grad, np.zeros(n_features), callback=cb)
     assert opt.nit < 2
 
@@ -109,7 +109,7 @@ def test_line_search(solver):
     def ls_wrong(_):
         return -10
 
-    ls_loss = loss.SquareLoss(A, b)
+    ls_loss = SquareLoss(A, b)
 
     # define a function with unused arguments for the API
     def f_grad(x, r1, r2):