Skip to content

Commit

Permalink
Merge pull request #208 from SpinW/add_new_LM_minimiser
Browse files Browse the repository at this point in the history
Add new lm minimiser
  • Loading branch information
RichardWaiteSTFC authored Nov 18, 2024
2 parents 73f604a + 1fc202a commit 7e87dee
Show file tree
Hide file tree
Showing 8 changed files with 516 additions and 71 deletions.
51 changes: 41 additions & 10 deletions +sw_tests/+unit_tests/unittest_ndbase_cost_function_wrapper.m
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,26 @@ function test_init_with_fcost_lower_bound_only(testCase)
% note first param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [3, 1]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, [0, 3.8730], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [3, 4], 'abs_tol', 1e-4);
testCase.verify_val(pfree, [1.1180, 3.8730], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [3.5, 4], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
end

function test_init_with_fcost_upper_bound_only(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'ub', [3, 1]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, [1.7320, 0], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [2, 1], 'abs_tol', 1e-4);
testCase.verify_val(pfree, [1.7320, 1.1180], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [2, 0.5], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
end

function test_init_with_fcost_both_bounds(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1, 2], 'ub', [3, 2.5]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, [0, 1.5708], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [2, 2.5], 'abs_tol', 1e-4);
testCase.verify_val(pfree, [0, 0], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [2, 2.25], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
end

Expand All @@ -80,22 +80,31 @@ function test_init_with_fcost_both_bounds_with_fixed_param(testCase)
end


function test_init_with_fcost_both_bounds_with_fixed_param_using_ifix(testCase)
function test_init_with_fcost_both_bounds_fixed_invalid_param_using_ifix(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1, 2], 'ub', [3, 2.5], 'ifix', [2]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, 0, 'abs_tol', 1e-4); % only first param free
testCase.verify_val(pbound, [2, 2.5], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [2, 2.25], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.ifixed, 2);
testCase.verify_val(cost_func_wrap.ifree, 1);
testCase.verify_val(cost_func_wrap.pars_fixed, 2.5);
testCase.verify_val(cost_func_wrap.pars_fixed, 2.25);
end

function test_init_with_fcost_both_bounds_fixed_param_using_ifix(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1, 2], 'ub', [3, 6], 'ifix', [2]);
[pfree, pbound, ~] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, 0, 'abs_tol', 1e-4); % only first param free
testCase.verify_val(pbound, testCase.params, 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.pars_fixed, testCase.params(2));
end

function test_init_with_fcost_no_bounds_with_fixed_param_using_ifix(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'ifix', [2]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
[pfree, pbound, ~] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, testCase.params(1), 'abs_tol', 1e-4); % only first param free
testCase.verify_val(pbound, testCase.params, 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.ifixed, 2);
Expand Down Expand Up @@ -130,6 +139,28 @@ function test_incompatible_bounds(testCase)
@() ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1,1,], 'ub', [0,0]), ...
'ndbase:cost_function_wrapper:WrongInput');
end

function test_init_with_resid_handle(testCase)
x = 1:3;
y = polyval(testCase.params, x);
cost_func_wrap = ndbase.cost_function_wrapper(@(p) y - polyval(p, x), testCase.params, 'resid_handle', true);
[~, ~, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(cost_val, 0, 'abs_tol', 1e-4);
end

function test_init_with_fcost_all_params_fixed(testCase)
% note second param outside bounds
ifixed = 1:2;
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'ifix', ifixed);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verifyEmpty(pfree);
testCase.verify_val(pbound, testCase.params);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.ifixed, ifixed);
testCase.verifyEmpty(cost_func_wrap.ifree);
testCase.verify_val(cost_func_wrap.pars_fixed, testCase.params);
end


end
end
39 changes: 30 additions & 9 deletions +sw_tests/+unit_tests/unittest_ndbase_optimisers.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
end

properties (TestParameter)
optimiser = {@ndbase.simplex};
optimiser = {@ndbase.simplex, @ndbase.lm4};
poly_func = {@(x, p) polyval(p, x), '@(x, p) polyval(p, x)'}
end

Expand All @@ -18,14 +18,23 @@ function test_optimise_data_struct(testCase, optimiser, poly_func)
dat = struct('x', 1:3, 'e', ones(1,3));
dat.y = polyval(linear_pars, dat.x);
[pars_fit, cost_val, ~] = optimiser(dat, poly_func, [-1,-1]);
testCase.verify_val(pars_fit, linear_pars, 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 2.5e-7, 'abs_tol', 1e-8);
testCase.verify_val(pars_fit, linear_pars, 'abs_tol', 2e-4);
testCase.verify_val(cost_val, 0, 'abs_tol', 5e-7);
end

function test_optimise_residual_array_lm(testCase, optimiser)
linear_pars = [2, 1];
x = 1:3;
y = polyval(linear_pars, x);
[pars_fit, cost_val, ~] = optimiser([], @(p) y - polyval(p, x), [-1,-1], 'resid_handle', true);
testCase.verify_val(pars_fit, linear_pars, 'abs_tol', 2e-4);
testCase.verify_val(cost_val, 0, 'abs_tol', 5e-7);
end

function test_optimise_rosen_free(testCase, optimiser)
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1]);
testCase.verify_val(pars_fit, testCase.rosenbrock_minimum, 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 0, 'abs_tol', 1e-6);
testCase.verify_val(cost_val, 0, 'abs_tol', 2e-7);
end

function test_optimise_rosen_lower_bound_minimum_accessible(testCase, optimiser)
Expand All @@ -48,10 +57,9 @@ function test_optimise_rosen_upper_bound_minimum_accessible(testCase, optimiser)
end

function test_optimise_rosen_upper_bound_minimum_not_accessible(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'ub', [0, inf]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-4);
end

function test_optimise_rosen_both_bounds_minimum_accessible(testCase, optimiser)
Expand All @@ -62,7 +70,7 @@ function test_optimise_rosen_both_bounds_minimum_accessible(testCase, optimiser)

function test_optimise_rosen_both_bounds_minimum_not_accessible(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-0.5, -0.5], 'ub', [0, 0]);
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-2, -2], 'ub', [0, 0]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end
Expand All @@ -74,13 +82,26 @@ function test_optimise_rosen_parameter_fixed_minimum_not_accessible(testCase, op
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end

function test_optimise_rosen_parameter_fixed_minimum_not_accessible_with_vary_arg(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [0,-1], 'lb', [nan, -0.5], 'ub', [nan, 0], 'vary', [false, true]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end

function test_optimise_rosen_parameter_all_fixed(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [0, 0], 'ub', [0, 0]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end

function test_optimise_rosen_parameter_all_fixed_with_vary_arg(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [0, 0], 'vary', [false, false]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end

end
end
% do parameterised test with all minimisers
end
8 changes: 6 additions & 2 deletions +sw_tests/+unit_tests/unittest_sw_fitpowder.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
default_modQ_cens_1d = 3.55:0.1:5.45; % integrate over nQ pts
end

properties (TestParameter)
fit_params = {{}, {'resid_handle', true}};
end

methods (TestClassSetup)
function setup_spinw_obj_and_expected_result(testCase)
% setup spinw object
Expand Down Expand Up @@ -321,13 +325,13 @@ function test_estimate_constant_background(testCase)
testCase.verify_results(out, expected_fitpow);
end

function test_fit_background(testCase)
function test_fit_background(testCase, fit_params)
out = sw_fitpowder(testCase.swobj, testCase.data_2d, ...
testCase.fit_func, testCase.j1);
out.y(1) = 10; % higher so other bins are background
out.fix_bg_parameters(1:2); % fix slopes of background to 0
out.set_bg_parameters(3, 1.5); % initial guess
out.fit_background()
out.fit_background(fit_params{:})
expected_fitpow = testCase.default_fitpow;
expected_fitpow.y(1) = 10;
expected_fitpow.ibg = [3;6;2;5;4];
Expand Down
Loading

0 comments on commit 7e87dee

Please sign in to comment.