diff --git a/crates/lox-orbits/src/python.rs b/crates/lox-orbits/src/python.rs index 49789935..9a7dc1fb 100644 --- a/crates/lox-orbits/src/python.rs +++ b/crates/lox-orbits/src/python.rs @@ -906,7 +906,7 @@ impl PyElevationMask { } fn __getnewargs__(&self) -> (Option>, Option>, Option) { - (self.azimuth(), self.elevation(), self.min_elevation()) + (self.azimuth(), self.elevation(), self.fixed_elevation()) } fn azimuth(&self) -> Option> { @@ -923,12 +923,16 @@ impl PyElevationMask { } } - fn min_elevation(&self) -> Option { + fn fixed_elevation(&self) -> Option { match &self.0 { ElevationMask::Fixed(min_elevation) => Some(*min_elevation), ElevationMask::Variable(_) => None, } } + + fn min_elevation(&self, azimuth: f64) -> f64 { + self.0.min_elevation(azimuth) + } } #[pyclass(name = "Observables", module = "lox_space", frozen)] diff --git a/crates/lox-space/lox_space.pyi b/crates/lox-space/lox_space.pyi index 438f517a..3a187233 100644 --- a/crates/lox-space/lox_space.pyi +++ b/crates/lox-space/lox_space.pyi @@ -11,9 +11,14 @@ class Ensemble: def __new__(cls, ensemble: dict[str, Trajectory]): ... class ElevationMask: - def __new__(cls, azimuth: np.ndarray, elevation: np.ndarray): ... + @classmethod + def variable(cls, azimuth: np.ndarray, elevation: np.ndarray) -> Self: ... @classmethod def fixed(cls, min_elevation: float) -> Self: ... + def azimuth(self) -> list[float] | None: ... + def elevation(self) -> list[float] | None: ... + def fixed_elevation(self) -> float | None: ... + def min_elevation(self, azimuth: float) -> float: ... def find_events( func: Callable[[float], float], start: Time, times: list[float] diff --git a/crates/lox-space/tests/test_ground.py b/crates/lox-space/tests/test_ground.py index 5d264f0b..0cf7bf39 100644 --- a/crates/lox-space/tests/test_ground.py +++ b/crates/lox-space/tests/test_ground.py @@ -20,3 +20,10 @@ def test_observables(): assert observables.range_rate() == pytest.approx(expected_range_rate, rel=1e-2) assert observables.azimuth() == pytest.approx(expected_azimuth, rel=1e-2) assert observables.elevation() == pytest.approx(expected_elevation, rel=1e-2) + + +def test_elevation_mask(): + mask = lox.ElevationMask.variable( + np.array([-np.pi, 0.0, np.pi]), np.array([0.0, 5.0, 0.0]) + ) + assert mask.min_elevation(np.pi / 2) == 2.5