Skip to content

Commit

Permalink
Add test case for passing a callable into an aggregate exp fragment
Browse files Browse the repository at this point in the history
  • Loading branch information
hartytp committed Dec 9, 2024
1 parent df0d018 commit ce2268c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
8 changes: 7 additions & 1 deletion test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,13 @@ class AddOneAggregate(AggregateExpFragment):
def build_fragment(self) -> None:
self.setattr_fragment("a", AddOneFragment)
self.setattr_fragment("b", AddOneFragment)
return super().build_fragment([self.a, self.b])

self.setattr_result("sum", FloatChannel)

def push_sum():
self.sum.push(self.a.value.get() + 1 + self.b.value.get() + 1)

return super().build_fragment([self.a, self.b, push_sum])


class TrivialKernelAggregate(AggregateExpFragment):
Expand Down
1 change: 1 addition & 0 deletions test/test_experiment_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_aggregate(self):
self.assertEqual(parent.b.num_prepare_calls, 1)
self.assertEqual(result[parent.a.result], 1.0)
self.assertEqual(result[parent.b.result], 1.0)
self.assertEqual(result[parent.sum], 2.0)

def test_kernel(self):
parent = self.create(TrivialKernelAggregate, [])
Expand Down

0 comments on commit ce2268c

Please sign in to comment.