Skip to content

Commit

Permalink
Make CompoundRegion behave like Python iterable.
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-slac committed Nov 20, 2024
1 parent da67747 commit d6791ae
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
31 changes: 31 additions & 0 deletions python/lsst/sphgeom/_compoundRegion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,42 @@ std::unique_ptr<_CompoundRegion> _args_factory(const py::args& args) {
return std::make_unique<_CompoundRegion>(std::move(operands));
}

// Iterator for CompoundRegion.
class CompoundIterator {
public:
CompoundIterator(CompoundRegion const& compound, size_t pos) : _compound(compound), _pos(pos) {}

Region const& operator*() const { return _compound.getOperand(_pos); }

CompoundIterator& operator++() {
++ _pos;
return *this;
}

bool operator==(CompoundIterator const& other) const {
return &_compound == &other._compound and _pos == other._pos;
}

private:
CompoundRegion const& _compound;
size_t _pos;
};

} // namespace

template <>
void defineClass(py::class_<CompoundRegion, std::unique_ptr<CompoundRegion>, Region> &cls) {
cls.def("nOperands", &CompoundRegion::nOperands);
cls.def("__len__", &CompoundRegion::nOperands);
cls.def(
"__iter__",
[](CompoundRegion const& region) {
return py::make_iterator(
CompoundIterator(region, 0U), CompoundIterator(region, region.nOperands())
);
},
py::return_value_policy::reference_internal // Keeps region alive while iterator is in use.
);
cls.def(
"cloneOperand",
[](CompoundRegion const &self, std::ptrdiff_t n) {
Expand Down
9 changes: 9 additions & 0 deletions tests/test_CompoundRegion.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ def testOperands(self):
"""Test the cloneOperands accessor."""
self.assertOperandsEqual(self.instance, self.operands)

def testIterator(self):
"""Test Python iteration."""
self.assertEqual(len(self.instance), len(self.operands))
it = iter(self.instance)
self.assertEqual(next(it), self.operands[0])
self.assertEqual(next(it), self.operands[1])
with self.assertRaises(StopIteration):
next(it)

def testCodec(self):
"""Test that encode and decode round-trip."""
s = self.instance.encode()
Expand Down

0 comments on commit d6791ae

Please sign in to comment.