Skip to content

Commit

Permalink
adding unit tests for parameter gorup
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorStoneAstro committed Jul 27, 2023
1 parent bb2aa21 commit 923154e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
1 change: 1 addition & 0 deletions autophot/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .core_model import *
from .model_object import *
from .parameter_object import *
from .parameter_group import *
from .galaxy_model_object import *
from .ray_model import *
from .sersic_model import *
Expand Down
40 changes: 38 additions & 2 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import unittest
from autophot.models import Parameter
from autophot.models import Parameter, Parameter_Group
import torch
import numpy as np


class TestParameter(unittest.TestCase):
@ torch.no_grad()
@torch.no_grad()
def test_parameter_setting(self):
base_param = Parameter("base param")
base_param.set_value(1.0)
Expand Down Expand Up @@ -154,6 +155,41 @@ def test_parameter_state(self):

S = str(P)

class TestParameterGroup(unittest.TestCase):

def test_generation(self):
P = Parameter("state", value = 1., uncertainty = 0.5, limits = (-1, 1), locked = True, prof = 1.)

P2 = Parameter("v2")
P2.set_state(P.get_state())

PG = Parameter_Group("group", parameters = [P,P2])

PG_copy = PG.copy()

def test_vectors(self):
P1 = Parameter("test1", value = 1., uncertainty = 0.5, limits = (-1, 1), locked = False, prof = 1.)

P2 = Parameter("test2", value = 2., uncertainty = 5., limits = (None, 1), locked = False)

PG = Parameter_Group("group", parameters = [P1,P2])

names = PG.get_name_vector()
self.assertEqual(names, ["test1", "test2"], "get name vector should produce ordered list of names")

uncertainty = PG.get_uncertainty_vector()
self.assertTrue(np.all(uncertainty.detach().cpu().numpy() == np.array([0.5,5.])), "get uncertainty vector should track uncertainty")

def test_inspection(self):
P1 = Parameter("test1", value = 1., uncertainty = 0.5, limits = (-1, 1), locked = False, prof = 1.)

P2 = Parameter("test2", value = 2., uncertainty = 5., limits = (None, 1), locked = False)

PG = Parameter_Group("group", parameters = [P1,P2])

self.assertEqual(len(PG), 2, "parameter group should only have two parameters here")

string = str(PG)

if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,15 @@ def test_conversion_functions(self):
),
msg="Error computing inverse sersic function (torch)",
)
def test_general_derivative(self):

res = ap.utils.conversions.functions.general_uncertainty_prop(
tuple(torch.tensor(a) for a in (1.0, 1.0, 1.0, 0.5)),
tuple(torch.tensor(a) for a in (0.1, 0.1, 0.1, 0.1)),
ap.utils.conversions.functions.sersic_Ie_to_flux_torch,
)

self.assertAlmostEqual(res.detach().cpu().numpy(), 1.8105, 3, "General uncertianty prop should compute uncertainty")


class TestInterpolate(unittest.TestCase):
Expand Down Expand Up @@ -465,5 +474,6 @@ def test_angle_operation_functions(self):
)



if __name__ == "__main__":
unittest.main()

0 comments on commit 923154e

Please sign in to comment.