Skip to content

Commit

Permalink
ML-303: In case key does not exist do not return features default values (#191)
Browse files Browse the repository at this point in the history

* in case key does not exist do not return features default values

* pr comments
  • Loading branch information
katyakats authored Apr 4, 2021
1 parent 15afc6e commit 1224d2f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
44 changes: 43 additions & 1 deletion integration/test_aggregation_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,9 +1172,51 @@ def test_aggregate_multiple_keys(setup_teardown_test):
actual = controller.await_termination()
expected_results = [
{'number_of_stuff_sum_1h': 1.0, 'first_name': 'moshe', 'last_name': 'cohen', 'some_data': 4},
{'number_of_stuff_sum_1h': 0, 'first_name': 'moshe', 'last_name': 'levi', 'some_data': 5},
{'first_name': 'moshe', 'last_name': 'levi', 'some_data': 5},
{'number_of_stuff_sum_1h': 5.0, 'first_name': 'yosi', 'last_name': 'levi', 'some_data': 6}
]

assert actual == expected_results, \
f'actual did not match expected. \n actual: {actual} \n expected: {expected_results}'


def test_read_non_existing_key(setup_teardown_test):
    """Querying a key that was never written must not add aggregate features.

    First writes aggregations for the keys ``moshe`` and ``yosi``, then runs a
    second flow that queries ``number_of_stuff_sum_1h`` for the key
    ``non_existing_key`` and checks that the feature is absent from the
    resulting event (rather than being filled with a default value).
    """
    # NOTE(review): the "time" column is not passed to DataframeSource as a
    # time field here - confirm whether event time is meant to come from it.
    data = pd.DataFrame(
        {
            "first_name": ["moshe", "yosi", "yosi"],
            "last_name": ["cohen", "levi", "levi"],
            "some_data": [1, 2, 3],
            "time": [test_base_time - pd.Timedelta(minutes=25), test_base_time - pd.Timedelta(minutes=30),
                     test_base_time - pd.Timedelta(minutes=35)]
        }
    )

    keys = 'first_name'
    table = Table(setup_teardown_test, V3ioDriver())
    controller = build_flow([
        DataframeSource(data, key_field=keys),
        AggregateByKey([FieldAggregator("number_of_stuff", "some_data", ["sum"],
                                        SlidingWindows(['1h'], '10m'))],
                       table),
        WriteToTable(table),
    ]).run()

    # Wait for the write flow to finish; its return value is not needed.
    controller.await_termination()

    # A separate Table object over the same path, so the query flow does not
    # share the first Table instance.
    other_table = Table(setup_teardown_test, V3ioDriver())
    controller = build_flow([
        Source(),
        QueryByKey(["number_of_stuff_sum_1h"],
                   other_table, keys=keys),
        Reduce([], lambda acc, x: append_return(acc, x)),
    ]).run()

    controller.emit({'last_name': 'levi', 'some_data': 5}, 'non_existing_key')

    controller.terminate()
    actual = controller.await_termination()

    # Fail with a readable message instead of an IndexError if nothing came back.
    assert actual, f'expected at least one event from the query flow, got: {actual}'
    assert "number_of_stuff_sum_1h" not in actual[0], \
        f'feature must not be returned for a non-existing key. \n actual: {actual[0]}'

12 changes: 10 additions & 2 deletions storey/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ async def _get_features(self, key, timestamp):
if not self._schema:
await self._get_or_save_schema()

return self._get_aggregations_attrs(key).get_features(timestamp)
attrs = self._get_aggregations_attrs(key)

if attrs is None:
return {}

return attrs.get_features(timestamp)

def _new_aggregated_store_element(self):
if self._aggregations_read_only:
Expand All @@ -103,7 +108,10 @@ def _new_aggregated_store_element(self):
async def add_aggregation_by_key(self, key, base_timestamp, initial_data):
    """Store the aggregation state for ``key``.

    Loads the schema first when it has not been loaded yet. If aggregations
    are read-only and no ``initial_data`` was supplied, ``None`` is stored for
    the key; otherwise a new aggregated store element is built and stored.

    :param key: key whose aggregation state is being registered.
    :param base_timestamp: passed through to the new store element.
    :param initial_data: previously loaded data for the key, or ``None``.
    """
    if not self._schema:
        await self._get_or_save_schema()

    if self._aggregations_read_only and initial_data is None:
        # Nothing was loaded for this key and we may not write: record None
        # instead of building an element, so the element is never constructed
        # only to be discarded.
        element = None
    else:
        element_factory = self._new_aggregated_store_element()
        element = element_factory(key, self._aggregates, base_timestamp, initial_data)

    # Single store call: the state is set exactly once per invocation.
    self._set_aggregations_attrs(key, element)

async def _get_or_save_schema(self):
self._schema = await self._storage._load_schema(self._container, self._table_path)
Expand Down

0 comments on commit 1224d2f

Please sign in to comment.