From d6791ae54a2c477689dddc72c5e47824c597fe87 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Mon, 18 Nov 2024 15:41:50 -0800 Subject: [PATCH] Make CompoundRegion behave like Python iterable. --- python/lsst/sphgeom/_compoundRegion.cc | 31 ++++++++++++++++++++++++++ tests/test_CompoundRegion.py | 9 ++++++++ 2 files changed, 40 insertions(+) diff --git a/python/lsst/sphgeom/_compoundRegion.cc b/python/lsst/sphgeom/_compoundRegion.cc index 34a6e05..cc9cdc9 100644 --- a/python/lsst/sphgeom/_compoundRegion.cc +++ b/python/lsst/sphgeom/_compoundRegion.cc @@ -67,11 +67,42 @@ std::unique_ptr<_CompoundRegion> _args_factory(const py::args& args) { return std::make_unique<_CompoundRegion>(std::move(operands)); } +// Iterator for CompoundRegion. +class CompoundIterator { +public: + CompoundIterator(CompoundRegion const& compound, size_t pos) : _compound(compound), _pos(pos) {} + + Region const& operator*() const { return _compound.getOperand(_pos); } + + CompoundIterator& operator++() { + ++ _pos; + return *this; + } + + bool operator==(CompoundIterator const& other) const { + return &_compound == &other._compound and _pos == other._pos; + } + +private: + CompoundRegion const& _compound; + size_t _pos; +}; + } // namespace template <> void defineClass(py::class_, Region> &cls) { cls.def("nOperands", &CompoundRegion::nOperands); + cls.def("__len__", &CompoundRegion::nOperands); + cls.def( + "__iter__", + [](CompoundRegion const& region) { + return py::make_iterator( + CompoundIterator(region, 0U), CompoundIterator(region, region.nOperands()) + ); + }, + py::return_value_policy::reference_internal // Keeps region alive while iterator is in use. + ); cls.def( "cloneOperand", [](CompoundRegion const &self, std::ptrdiff_t n) { diff --git a/tests/test_CompoundRegion.py b/tests/test_CompoundRegion.py index 72fbad0..aee36d6 100644 --- a/tests/test_CompoundRegion.py +++ b/tests/test_CompoundRegion.py @@ -122,6 +122,15 @@ def testOperands(self): """Test the cloneOperands accessor.""" self.assertOperandsEqual(self.instance, self.operands) + def testIterator(self): + """Test Python iteration.""" + self.assertEqual(len(self.instance), len(self.operands)) + it = iter(self.instance) + self.assertEqual(next(it), self.operands[0]) + self.assertEqual(next(it), self.operands[1]) + with self.assertRaises(StopIteration): + next(it) + def testCodec(self): """Test that encode and decode round-trip.""" s = self.instance.encode()