Skip to content

Commit

Permalink
Sky model tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lmachadopolettivalle committed Nov 16, 2023
1 parent d96509f commit 67111eb
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions karabo/test/test_skymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,6 @@
from karabo.simulation.sky_model import Polarisation, SkyModel


def test_init(sky_data_with_ids: NDArray[np.object_]):
sky1 = SkyModel()
sky1.add_point_sources(sky_data_with_ids)
sky2 = SkyModel(sky_data_with_ids)
# test if sources are inside now (-1 because ids are in `xarray.DataArray.coord`)
assert sky_data_with_ids.shape[1] - 1 == sky1.sources.shape[1]
assert sky_data_with_ids.shape[1] - 1 == sky2.sources.shape[1]


def test_not_full_array():
sky1 = SkyModel()
sky_data = xr.DataArray([[20.0, -30.0, 1], [20.0, -30.5, 3], [20.5, -30.5, 3]])
sky1.add_point_sources(sky_data)
sky2 = SkyModel(sky_data)
# test if doc shape were expanded
assert sky1.sources.shape == (sky_data.shape[0], 12)
assert sky2.sources.shape == (sky_data.shape[0], 12)


def test_filter_sky_model():
sky = SkyModel.get_GLEAM_Sky([76])
phase_center = [250, -80] # ra,dec
Expand All @@ -55,6 +36,25 @@ def test_filter_sky_model():
assert len(filtered_sky_euclidean_approx.sources) == len(filtered_sky.sources)


def test_init(sky_data_with_ids: NDArray[np.object_]):
sky1 = SkyModel()
sky1.add_point_sources(sky_data_with_ids)
sky2 = SkyModel(sky_data_with_ids)
# test if sources are inside now (-1 because ids are in `xarray.DataArray.coord`)
assert sky_data_with_ids.shape[1] - 1 == sky1.sources.shape[1]
assert sky_data_with_ids.shape[1] - 1 == sky2.sources.shape[1]


def test_not_full_array():
sky1 = SkyModel()
sky_data = xr.DataArray([[20.0, -30.0, 1], [20.0, -30.5, 3], [20.5, -30.5, 3]])
sky1.add_point_sources(sky_data)
sky2 = SkyModel(sky_data)
# test if doc shape were expanded
assert sky1.sources.shape == (sky_data.shape[0], 12)
assert sky2.sources.shape == (sky_data.shape[0], 12)


def test_filter_sky_model_h5():
sky = SkyModel.get_BATTYE_sky(which="diluted")
phase_center = [21.44213503, -30.70729488]
Expand Down

0 comments on commit 67111eb

Please sign in to comment.