From 1224d2f4838445c8dfe877f43cb48dca334b9abe Mon Sep 17 00:00:00 2001
From: Katya Katsenelenbogen
Date: Sun, 4 Apr 2021 09:56:48 +0300
Subject: [PATCH] ML-303: In case key does not exist do not return features default values (#191)

* in case key does not exist do not return features default values

* pr comments
---
NOTE(review): the line breaks, Python indentation and blank context lines in
this copy were reconstructed from a whitespace-collapsed paste, chosen so that
every hunk's declared line counts (-1172,9 +1172,51 / -93,7 +93,12 /
-103,7 +108,10) add up. Continuation-line alignment is a best guess.
Verify against commit 1224d2f4 before applying.

 integration/test_aggregation_integration.py | 44 ++++++++++++++++++++++++++++++++++++++++++++-
 storey/table.py                             | 12 +++++++++++-
 2 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/integration/test_aggregation_integration.py b/integration/test_aggregation_integration.py
index 29b29b3c..49eb458f 100644
--- a/integration/test_aggregation_integration.py
+++ b/integration/test_aggregation_integration.py
@@ -1172,9 +1172,51 @@ def test_aggregate_multiple_keys(setup_teardown_test):
     actual = controller.await_termination()
     expected_results = [
         {'number_of_stuff_sum_1h': 1.0, 'first_name': 'moshe', 'last_name': 'cohen', 'some_data': 4},
-        {'number_of_stuff_sum_1h': 0, 'first_name': 'moshe', 'last_name': 'levi', 'some_data': 5},
+        {'first_name': 'moshe', 'last_name': 'levi', 'some_data': 5},
         {'number_of_stuff_sum_1h': 5.0, 'first_name': 'yosi', 'last_name': 'levi', 'some_data': 6}
     ]
 
     assert actual == expected_results, \
         f'actual did not match expected. \n actual: {actual} \n expected: {expected_results}'
+
+
+def test_read_non_existing_key(setup_teardown_test):
+    data = pd.DataFrame(
+        {
+            "first_name": ["moshe", "yosi", "yosi"],
+            "last_name": ["cohen", "levi", "levi"],
+            "some_data": [1, 2, 3],
+            "time": [test_base_time - pd.Timedelta(minutes=25), test_base_time - pd.Timedelta(minutes=30),
+                     test_base_time - pd.Timedelta(minutes=35)]
+        }
+    )
+
+    keys = 'first_name'
+    table = Table(setup_teardown_test, V3ioDriver())
+    controller = build_flow([
+        DataframeSource(data, key_field=keys),
+        AggregateByKey([FieldAggregator("number_of_stuff", "some_data", ["sum"],
+                                        SlidingWindows(['1h'], '10m'))],
+                       table),
+        WriteToTable(table),
+    ]).run()
+
+    actual = controller.await_termination()
+
+    other_table = Table(setup_teardown_test, V3ioDriver())
+    controller = build_flow([
+        Source(),
+        QueryByKey(["number_of_stuff_sum_1h"],
+                   other_table, keys="first_name"),
+        Reduce([], lambda acc, x: append_return(acc, x)),
+    ]).run()
+
+    controller.emit({'last_name': 'levi', 'some_data': 5}, 'non_existing_key')
+
+    controller.terminate()
+    actual = controller.await_termination()
+
+    print(actual[0])
+
+    assert "number_of_stuff_sum_1h" not in actual[0]
+
diff --git a/storey/table.py b/storey/table.py
index 4926f0db..c96ead09 100644
--- a/storey/table.py
+++ b/storey/table.py
@@ -93,7 +93,12 @@ async def _get_features(self, key, timestamp):
         if not self._schema:
             await self._get_or_save_schema()
 
-        return self._get_aggregations_attrs(key).get_features(timestamp)
+        attrs = self._get_aggregations_attrs(key)
+
+        if attrs is None:
+            return {}
+
+        return attrs.get_features(timestamp)
 
     def _new_aggregated_store_element(self):
         if self._aggregations_read_only:
@@ -103,7 +108,10 @@ def _new_aggregated_store_element(self):
     async def add_aggregation_by_key(self, key, base_timestamp, initial_data):
         if not self._schema:
             await self._get_or_save_schema()
-        self._set_aggregations_attrs(key, self._new_aggregated_store_element()(key, self._aggregates, base_timestamp, initial_data))
+        if self._aggregations_read_only and initial_data is None:
+            self._set_aggregations_attrs(key, None)
+        else:
+            self._set_aggregations_attrs(key, self._new_aggregated_store_element()(key, self._aggregates, base_timestamp, initial_data))
 
     async def _get_or_save_schema(self):
         self._schema = await self._storage._load_schema(self._container, self._table_path)