Skip to content

Commit

Permalink
Merge pull request #204 from SpinW/implment_bounds_in_ndbase_minimisers
Browse files Browse the repository at this point in the history
Refactor simplex to use wrapper class handling bound and fixed parameters
  • Loading branch information
RichardWaiteSTFC authored Nov 1, 2024
2 parents e41f848 + f60fc5c commit 73f604a
Show file tree
Hide file tree
Showing 6 changed files with 515 additions and 264 deletions.
135 changes: 135 additions & 0 deletions +sw_tests/+unit_tests/unittest_ndbase_cost_function_wrapper.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
classdef unittest_ndbase_cost_function_wrapper < sw_tests.unit_tests.unittest_super
% Runs through unit test for ndbase optimisers, atm only simplex passes
% these tests

properties
fcost = @(p) (p(1)-1)^2 + (p(2)-2)^2
params = [2,4]
end

properties (TestParameter)
bound_param_name = {'lb', 'ub'}
no_lower_bound = {[], [-inf, -inf], [NaN, -inf]};
no_upper_bound = {[], [inf, inf], [inf, NaN]};
errors = {ones(1,3), [], zeros(1,3), 'NoField'}
end

methods
function [pfree, pbound, cost_val] = get_pars_and_cost_val(testCase, cost_func_wrap)
pfree = cost_func_wrap.get_free_parameters(testCase.params);
pbound = cost_func_wrap.get_bound_parameters(pfree);
cost_val = cost_func_wrap.eval_cost_function(pfree);
end
end

methods (Test)

function test_init_with_fcost_no_bounds(testCase)
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, testCase.params);
testCase.verify_val(pbound, testCase.params);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
end

function test_init_with_fcost_no_bounds_name_value_passed(testCase, no_lower_bound, no_upper_bound)
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', no_lower_bound, 'ub', no_upper_bound);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, testCase.params);
testCase.verify_val(pbound, testCase.params);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
end

function test_init_with_fcost_lower_bound_only(testCase)
% note first param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [3, 1]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, [0, 3.8730], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [3, 4], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
end

function test_init_with_fcost_upper_bound_only(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'ub', [3, 1]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, [1.7320, 0], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [2, 1], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
end

function test_init_with_fcost_both_bounds(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1, 2], 'ub', [3, 2.5]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, [0, 1.5708], 'abs_tol', 1e-4);
testCase.verify_val(pbound, [2, 2.5], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
end

function test_init_with_fcost_both_bounds_with_fixed_param(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1, 2.5], 'ub', [3, 2.5]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, 0, 'abs_tol', 1e-4); % only first param free
testCase.verify_val(pbound, [2, 2.5], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.ifixed, 2);
testCase.verify_val(cost_func_wrap.ifree, 1);
testCase.verify_val(cost_func_wrap.pars_fixed, 2.5);
end


function test_init_with_fcost_both_bounds_with_fixed_param_using_ifix(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1, 2], 'ub', [3, 2.5], 'ifix', [2]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, 0, 'abs_tol', 1e-4); % only first param free
testCase.verify_val(pbound, [2, 2.5], 'abs_tol', 1e-4);
testCase.verify_val(cost_val, testCase.fcost(pbound), 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.ifixed, 2);
testCase.verify_val(cost_func_wrap.ifree, 1);
testCase.verify_val(cost_func_wrap.pars_fixed, 2.5);
end

function test_init_with_fcost_no_bounds_with_fixed_param_using_ifix(testCase)
% note second param outside bounds
cost_func_wrap = ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'ifix', [2]);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, testCase.params(1), 'abs_tol', 1e-4); % only first param free
testCase.verify_val(pbound, testCase.params, 'abs_tol', 1e-4);
testCase.verify_val(cost_func_wrap.ifixed, 2);
testCase.verify_val(cost_func_wrap.ifree, 1);
testCase.verify_val(cost_func_wrap.pars_fixed, testCase.params(2));
end

function test_init_with_data(testCase, errors)
% all errors passed lead to unweighted residuals (either as
% explicitly ones or the default weights if invalid errors)
if ischar(errors) && errors == "NoField"
dat = struct('x', 1:3);
else
dat = struct('x', 1:3, 'e', errors);
end
dat.y = polyval(testCase.params, dat.x);
cost_func_wrap = ndbase.cost_function_wrapper(@(x, p) polyval(p, x), testCase.params, 'data', dat);
[pfree, pbound, cost_val] = testCase.get_pars_and_cost_val(cost_func_wrap);
testCase.verify_val(pfree, testCase.params, 'abs_tol', 1e-4);
testCase.verify_val(pbound, testCase.params, 'abs_tol', 1e-4);
testCase.verify_val(cost_val, 0, 'abs_tol', 1e-4);
end

function test_wrong_size_bounds(testCase, bound_param_name)
testCase.verifyError(...
@() ndbase.cost_function_wrapper(testCase.fcost, testCase.params, bound_param_name, ones(3)), ...
'ndbase:cost_function_wrapper:WrongInput');
end

function test_incompatible_bounds(testCase)
testCase.verifyError(...
@() ndbase.cost_function_wrapper(testCase.fcost, testCase.params, 'lb', [1,1,], 'ub', [0,0]), ...
'ndbase:cost_function_wrapper:WrongInput');
end

end
end
86 changes: 86 additions & 0 deletions +sw_tests/+unit_tests/unittest_ndbase_optimisers.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
classdef unittest_ndbase_optimisers < sw_tests.unit_tests.unittest_super
% Runs through unit test for ndbase optimisers using bounded parameter
% transformations.

properties
rosenbrock = @(x) (1-x(1)).^2 + 100*(x(2) - x(1).^2).^2;
rosenbrock_minimum = [1, 1];
end

properties (TestParameter)
optimiser = {@ndbase.simplex};
poly_func = {@(x, p) polyval(p, x), '@(x, p) polyval(p, x)'}
end

methods (Test)
function test_optimise_data_struct(testCase, optimiser, poly_func)
linear_pars = [2, 1];
dat = struct('x', 1:3, 'e', ones(1,3));
dat.y = polyval(linear_pars, dat.x);
[pars_fit, cost_val, ~] = optimiser(dat, poly_func, [-1,-1]);
testCase.verify_val(pars_fit, linear_pars, 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 2.5e-7, 'abs_tol', 1e-8);
end

function test_optimise_rosen_free(testCase, optimiser)
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1]);
testCase.verify_val(pars_fit, testCase.rosenbrock_minimum, 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 0, 'abs_tol', 1e-6);
end

function test_optimise_rosen_lower_bound_minimum_accessible(testCase, optimiser)
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-2, -2]);
testCase.verify_val(pars_fit, testCase.rosenbrock_minimum, 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 0, 'abs_tol', 1e-6);
end

function test_optimise_rosen_lower_bound_minimum_not_accessible(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-inf, 2]);
testCase.verify_val(pars_fit, [-1.411, 2], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 5.821, 'abs_tol', 1e-3);
end

function test_optimise_rosen_upper_bound_minimum_accessible(testCase, optimiser)
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'ub', [2, 2]);
testCase.verify_val(pars_fit, testCase.rosenbrock_minimum, 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 0, 'abs_tol', 1e-6);
end

function test_optimise_rosen_upper_bound_minimum_not_accessible(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'ub', [0, inf]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end

function test_optimise_rosen_both_bounds_minimum_accessible(testCase, optimiser)
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-2, -2], 'ub', [2, 2]);
testCase.verify_val(pars_fit, testCase.rosenbrock_minimum, 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 0, 'abs_tol', 1e-6);
end

function test_optimise_rosen_both_bounds_minimum_not_accessible(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [-0.5, -0.5], 'ub', [0, 0]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end

function test_optimise_rosen_parameter_fixed_minimum_not_accessible(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [0, -0.5], 'ub', [0, 0]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end

function test_optimise_rosen_parameter_all_fixed(testCase, optimiser)
% note intital guess is outside bounds
[pars_fit, cost_val, ~] = optimiser([], testCase.rosenbrock, [-1,-1], 'lb', [0, 0], 'ub', [0, 0]);
testCase.verify_val(pars_fit, [0, 0], 'abs_tol', 1e-3);
testCase.verify_val(cost_val, 1, 'abs_tol', 1e-6);
end

end
end
% do parameterised test with all minimisers
8 changes: 4 additions & 4 deletions +sw_tests/+unit_tests/unittest_spinw_fitspec.m
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ function del_mode_file(testCase)
methods (Test)
function test_fitspec(testCase)
fitout = testCase.swobj.fitspec(testCase.fitpar);
testCase.verify_val(fitout.x, 1.0, 'abs_tol', 0.25);
testCase.verify_val(fitout.redX2, 0.0, 'abs_tol', 10);
testCase.verify_val(fitout.x, 0.7, 'abs_tol', 0.25);
testCase.verify_val(fitout.redX2, 243, 'abs_tol', 10);
end
function test_fitspec_twin(testCase)
% Checks that twins are handled correctly
Expand All @@ -50,8 +50,8 @@ function test_fitspec_twin(testCase)
% If twins not handled correctly, the fit will be bad.
swobj.addtwin('axis', [1 1 1], 'phid', 54, 'vol', 0.01);
fitout = swobj.fitspec(testCase.fitpar);
testCase.verify_val(fitout.x, 1.0, 'abs_tol', 0.25);
testCase.verify_val(fitout.redX2, 0.0, 'abs_tol', 10);
testCase.verify_val(fitout.x, 0.7, 'abs_tol', 0.25);
testCase.verify_val(fitout.redX2, 243, 'abs_tol', 10);
end
end
end
4 changes: 2 additions & 2 deletions .github/workflows/build_pyspinw.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,14 @@ jobs:
run: |
pip install scipy
cd ${{ github.workspace }}/python
pip install build/*whl
pip install wheelhouse/*whl
cd tests
python -m unittest
- name: Create wheel artifact
uses: actions/upload-artifact@v4
with:
name: pySpinW Wheel
path: ${{ github.workspace }}/python/build/*.whl
path: ${{ github.workspace }}/python/wheelhouse/*.whl
- name: Upload release wheels
if: ${{ github.event_name == 'release' }}
run: |
Expand Down
Loading

0 comments on commit 73f604a

Please sign in to comment.