From 1b2588158bf2a568441a2a3cbf8b9735de4d3079 Mon Sep 17 00:00:00 2001 From: Kristen Armes <6732445+kristenarmes@users.noreply.github.com> Date: Fri, 3 Sep 2021 11:20:30 -0700 Subject: [PATCH] feat: Adding ability to include conditions along with node and relation types in the neo4j staleness removal task (#1464) * Adding ability to include conditions along with node and relation types in the neo4j staleness removal task Signed-off-by: Kristen Armes * Addressing PR comments Signed-off-by: Kristen Armes * Fixing up query templates Signed-off-by: Kristen Armes * Fixing type issue in test Signed-off-by: Kristen Armes * Fixing formatting requirements Signed-off-by: Kristen Armes --- databuilder/README.md | 26 ++ .../task/neo4j_staleness_removal_task.py | 202 +++++++------ .../task/test_neo4j_staleness_removal_task.py | 267 ++++++++++++------ 3 files changed, 330 insertions(+), 165 deletions(-) diff --git a/databuilder/README.md b/databuilder/README.md index d2f4745645..9a70a9c8e5 100644 --- a/databuilder/README.md +++ b/databuilder/README.md @@ -1865,6 +1865,32 @@ You can think this approach as TTL based eviction. This is particularly useful w Above configuration is trying to delete stale usage relation (READ, READ_BY), by deleting READ or READ_BY relation that has not been published past 3 days. If number of elements to be removed is more than 10% per type, this task will be aborted without executing any deletion. +#### Using node and relation conditions to remove stale data +You may want to remove stale nodes and relations that meet certain conditions rather than all of a given type. To do this, you can specify the inputs to be a list of **TargetWithCondition** objects that each define a target type and a condition. Only stale nodes or relations of that type and that meet the condition will be removed when using this type of input. + +Node conditions can make use of the predefined variable `target` which represents the node. Relation conditions can include the variables `target`, `start_node`, and `end_node` where `target` represents the relation and `start_node`/`end_node` represent the nodes on either side of the target relation. For some examples of conditions see below. + + from databuilder.task.neo4j_staleness_removal_task import TargetWithCondition + + task = Neo4jStalenessRemovalTask() + job_config_dict = { + 'job.identifier': 'remove_stale_data_job', + 'task.remove_stale_data.neo4j_endpoint': neo4j_endpoint, + 'task.remove_stale_data.neo4j_user': neo4j_user, + 'task.remove_stale_data.neo4j_password': neo4j_password, + 'task.remove_stale_data.staleness_max_pct': 10, + 'task.remove_stale_data.target_nodes': [TargetWithCondition('Table', '(target)-[:COLUMN]->(:Column)'), # All Table nodes that have a directional COLUMN relation to a Column node + TargetWithCondition('Column', '(target)-[]-(:Table) AND target.name=\'column_name\'')], # All Column nodes named 'column_name' that have some relation to a Table node + 'task.remove_stale_data.target_relations': [TargetWithCondition('COLUMN', '(start_node:Table)-[target]->(end_node:Column)'), # All COLUMN relations that connect from a Table node to a Column node + TargetWithCondition('COLUMN', '(start_node:Column)-[target]-(end_node)')], # All COLUMN relations that connect any direction between a Column node and another node + 'task.remove_stale_data.milliseconds_to_expire': 86400000 * 3 + } + job_config = ConfigFactory.from_dict(job_config_dict) + job = DefaultJob(conf=job_config, task=task) + job.launch() + +You can include multiple inputs of the same type with different conditions as seen in the **target_relations** list above. Attribute checks can also be added as shown in the **target_nodes** list. + #### Dry run Deletion is always scary and it's better to perform dryrun before put this into action. You can use Dry run to see what sort of Cypher query will be executed. diff --git a/databuilder/databuilder/task/neo4j_staleness_removal_task.py b/databuilder/databuilder/task/neo4j_staleness_removal_task.py index 738f9f2779..3b5db0cc82 100644 --- a/databuilder/databuilder/task/neo4j_staleness_removal_task.py +++ b/databuilder/databuilder/task/neo4j_staleness_removal_task.py @@ -5,7 +5,7 @@ import textwrap import time from typing import ( - Any, Dict, Iterable, + Any, Dict, Iterable, Union, ) import neo4j @@ -54,6 +54,12 @@ MARKER_VAR_NAME = 'marker' +class TargetWithCondition: + def __init__(self, target_type: str, condition: str) -> None: + self.target_type = target_type + self.condition = condition + + class Neo4jStalenessRemovalTask(Task): """ A Specific task that is to remove stale nodes and relations in Neo4j. @@ -64,6 +70,31 @@ class Neo4jStalenessRemovalTask(Task): """ + delete_stale_nodes_statement = textwrap.dedent(""" + MATCH (target:{{type}}) + WHERE {staleness_condition}{{extra_condition}} + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """) + delete_stale_relations_statement = textwrap.dedent(""" + MATCH (start_node)-[target:{{type}}]-(end_node) + WHERE {staleness_condition}{{extra_condition}} + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """) + validate_node_staleness_statement = textwrap.dedent(""" + MATCH (target:{{type}}) + WHERE {staleness_condition}{{extra_condition}} + RETURN count(*) as count + """) + validate_relation_staleness_statement = textwrap.dedent(""" + MATCH (start_node)-[target:{{type}}]-(end_node) + WHERE {staleness_condition}{{extra_condition}} + RETURN count(*) as count + """) + def __init__(self) -> None: pass @@ -123,14 +154,8 @@ def validate(self) -> None: self._validate_relation_staleness_pct() def _delete_stale_nodes(self) -> None: - statement = textwrap.dedent(""" - MATCH (n:{{type}}) - WHERE {} - WITH n LIMIT $batch_size - DETACH DELETE (n) - RETURN COUNT(*) as count; - """) - self._batch_delete(statement=self._decorate_staleness(statement), targets=self.target_nodes) + self._batch_delete(statement=self._decorate_staleness(self.delete_stale_nodes_statement), + targets=self.target_nodes) def _decorate_staleness(self, statement: str @@ -141,27 +166,21 @@ def _decorate_staleness(self, :return: """ if self.ms_to_expire: - return statement.format(textwrap.dedent(f""" - n.publisher_last_updated_epoch_ms < (timestamp() - ${MARKER_VAR_NAME}) - OR NOT EXISTS(n.publisher_last_updated_epoch_ms)""")) + return statement.format(staleness_condition=textwrap.dedent(f"""\ + (target.publisher_last_updated_epoch_ms < (timestamp() - ${MARKER_VAR_NAME}) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms))""")) - return statement.format(textwrap.dedent(f""" - n.published_tag <> ${MARKER_VAR_NAME} - OR NOT EXISTS(n.published_tag)""")) + return statement.format(staleness_condition=textwrap.dedent(f"""\ + (target.published_tag <> ${MARKER_VAR_NAME} + OR NOT EXISTS(target.published_tag))""")) def _delete_stale_relations(self) -> None: - statement = textwrap.dedent(""" - MATCH ()-[n:{{type}}]-() - WHERE {} - WITH n LIMIT $batch_size - DELETE n - RETURN count(*) as count; - """) - self._batch_delete(statement=self._decorate_staleness(statement), targets=self.target_relations) + self._batch_delete(statement=self._decorate_staleness(self.delete_stale_relations_statement), + targets=self.target_relations) def _batch_delete(self, statement: str, - targets: Iterable[str] + targets: Union[Iterable[str], Iterable[TargetWithCondition]] ) -> None: """ Performing huge amount of deletion could degrade Neo4j performance. Therefore, it's taking batch deletion here. @@ -170,10 +189,18 @@ def _batch_delete(self, :return: """ for t in targets: - LOGGER.info('Deleting stale data of %s with batch size %i', t, self.batch_size) + if isinstance(t, TargetWithCondition): + target_type = t.target_type + extra_condition = ' AND ' + t.condition + else: + target_type = t + extra_condition = '' + + LOGGER.info('Deleting stale data of %s with batch size %i', target_type, self.batch_size) total_count = 0 while True: - results = self._execute_cypher_query(statement=statement.format(type=t), + results = self._execute_cypher_query(statement=statement.format(type=target_type, + extra_condition=extra_condition), param_dict={'batch_size': self.batch_size, MARKER_VAR_NAME: self.marker}, dry_run=self.dry_run) @@ -182,75 +209,78 @@ def _batch_delete(self, total_count = total_count + count if count == 0: break - LOGGER.info('Deleted %i stale data of %s', total_count, t) + LOGGER.info('Deleted %i stale data of %s', total_count, target_type) def _validate_staleness_pct(self, - total_records: Iterable[Dict[str, Any]], - stale_records: Iterable[Dict[str, Any]], - types: Iterable[str] + total_record_count: int, + stale_record_count: int, + target_type: str ) -> None: - total_count_dict = {record['type']: int(record['count']) for record in total_records} - - for record in stale_records: - type_str = record['type'] - if type_str not in types: - continue - - stale_count = record['count'] - if stale_count == 0: - continue + if total_record_count == 0 or stale_record_count == 0: + return - node_count = total_count_dict[type_str] - stale_pct = stale_count * 100 / node_count + stale_pct = stale_record_count * 100 / total_record_count - threshold = self.staleness_pct_dict.get(type_str, self.staleness_pct) - if stale_pct >= threshold: - raise Exception(f'Staleness percentage of {type_str} is {stale_pct} %. ' - f'Stopping due to over threshold {threshold} %') + threshold = self.staleness_pct_dict.get(target_type, self.staleness_pct) + if stale_pct >= threshold: + raise Exception(f'Staleness percentage of {target_type} is {stale_pct} %. ' + f'Stopping due to over threshold {threshold} %') def _validate_node_staleness_pct(self) -> None: - total_nodes_statement = textwrap.dedent(""" - MATCH (n) - WITH DISTINCT labels(n) as node, count(*) as count - RETURN head(node) as type, count - """) - - stale_nodes_statement = textwrap.dedent(""" - MATCH (n) - WHERE {} - WITH DISTINCT labels(n) as node, count(*) as count - RETURN head(node) as type, count - """) - - stale_nodes_statement = textwrap.dedent(self._decorate_staleness(stale_nodes_statement)) - - total_records = self._execute_cypher_query(statement=total_nodes_statement) - stale_records = self._execute_cypher_query(statement=stale_nodes_statement, - param_dict={MARKER_VAR_NAME: self.marker}) - self._validate_staleness_pct(total_records=total_records, - stale_records=stale_records, - types=self.target_nodes) + total_nodes_statement = textwrap.dedent( + self.validate_node_staleness_statement.format(staleness_condition='true')) + stale_nodes_statement = textwrap.dedent( + self._decorate_staleness(self.validate_node_staleness_statement)) + + for t in self.target_nodes: + if isinstance(t, TargetWithCondition): + target_type = t.target_type + extra_condition = ' AND ' + t.condition + else: + target_type = t + extra_condition = '' + + total_records = self._execute_cypher_query( + statement=total_nodes_statement.format(type=target_type, + extra_condition=extra_condition)) + stale_records = self._execute_cypher_query( + statement=stale_nodes_statement.format(type=target_type, + extra_condition=extra_condition), + param_dict={MARKER_VAR_NAME: self.marker}) + + total_record_value = next(iter(total_records), None) + stale_record_value = next(iter(stale_records), None) + self._validate_staleness_pct(total_record_count=total_record_value['count'] if total_record_value else 0, + stale_record_count=stale_record_value['count'] if stale_record_value else 0, + target_type=target_type) def _validate_relation_staleness_pct(self) -> None: - total_relations_statement = textwrap.dedent(""" - MATCH ()-[r]-() - RETURN type(r) as type, count(*) as count; - """) - - stale_relations_statement = textwrap.dedent(""" - MATCH ()-[n]-() - WHERE {} - RETURN type(n) as type, count(*) as count - """) - - stale_relations_statement = textwrap.dedent(self._decorate_staleness(stale_relations_statement)) - - total_records = self._execute_cypher_query(statement=total_relations_statement) - stale_records = self._execute_cypher_query(statement=stale_relations_statement, - param_dict={MARKER_VAR_NAME: self.marker}) - self._validate_staleness_pct(total_records=total_records, - stale_records=stale_records, - types=self.target_relations) + total_relations_statement = textwrap.dedent( + self.validate_relation_staleness_statement.format(staleness_condition='true')) + stale_relations_statement = textwrap.dedent( + self._decorate_staleness(self.validate_relation_staleness_statement)) + + for t in self.target_relations: + if isinstance(t, TargetWithCondition): + target_type = t.target_type + extra_condition = ' AND ' + t.condition + else: + target_type = t + extra_condition = '' + + total_records = self._execute_cypher_query( + statement=total_relations_statement.format(type=target_type, + extra_condition=extra_condition)) + stale_records = self._execute_cypher_query( + statement=stale_relations_statement.format(type=target_type, + extra_condition=extra_condition), + param_dict={MARKER_VAR_NAME: self.marker}) + + total_record_value = next(iter(total_records), None) + stale_record_value = next(iter(stale_records), None) + self._validate_staleness_pct(total_record_count=total_record_value['count'] if total_record_value else 0, + stale_record_count=stale_record_value['count'] if stale_record_value else 0, + target_type=target_type) def _execute_cypher_query(self, statement: str, diff --git a/databuilder/tests/unit/task/test_neo4j_staleness_removal_task.py b/databuilder/tests/unit/task/test_neo4j_staleness_removal_task.py index 08bd352137..e235236459 100644 --- a/databuilder/tests/unit/task/test_neo4j_staleness_removal_task.py +++ b/databuilder/tests/unit/task/test_neo4j_staleness_removal_task.py @@ -14,7 +14,7 @@ from databuilder.publisher import neo4j_csv_publisher from databuilder.task import neo4j_staleness_removal_task -from databuilder.task.neo4j_staleness_removal_task import Neo4jStalenessRemovalTask +from databuilder.task.neo4j_staleness_removal_task import Neo4jStalenessRemovalTask, TargetWithCondition class TestRemoveStaleData(unittest.TestCase): @@ -36,10 +36,12 @@ def test_validation_failure(self) -> None: }) task.init(job_config) - total_records = [{'type': 'foo', 'count': 100}] - stale_records = [{'type': 'foo', 'count': 50}] - targets = {'foo'} - task._validate_staleness_pct(total_records=total_records, stale_records=stale_records, types=targets) + total_record_count = 100 + stale_record_count = 50 + target_type = 'foo' + task._validate_staleness_pct(total_record_count=total_record_count, + stale_record_count=stale_record_count, + target_type=target_type) def test_validation(self) -> None: @@ -55,10 +57,10 @@ def test_validation(self) -> None: }) task.init(job_config) - total_records = [{'type': 'foo', 'count': 100}] - stale_records = [{'type': 'foo', 'count': 50}] - targets = {'foo'} - self.assertRaises(Exception, task._validate_staleness_pct, total_records, stale_records, targets) + total_record_count = 100 + stale_record_count = 50 + target_type = 'foo' + self.assertRaises(Exception, task._validate_staleness_pct, total_record_count, stale_record_count, target_type) def test_validation_threshold_override(self) -> None: @@ -75,12 +77,12 @@ def test_validation_threshold_override(self) -> None: }) task.init(job_config) - total_records = [{'type': 'foo', 'count': 100}, - {'type': 'bar', 'count': 100}] - stale_records = [{'type': 'foo', 'count': 50}, - {'type': 'bar', 'count': 3}] - targets = {'foo', 'bar'} - task._validate_staleness_pct(total_records=total_records, stale_records=stale_records, types=targets) + task._validate_staleness_pct(total_record_count=100, + stale_record_count=50, + target_type='foo') + task._validate_staleness_pct(total_record_count=100, + stale_record_count=3, + target_type='bar') def test_marker(self) -> None: with patch.object(GraphDatabase, 'driver'): @@ -123,6 +125,7 @@ def test_validation_statement_publish_tag(self) -> None: f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', }) @@ -131,30 +134,27 @@ def test_validation_statement_publish_tag(self) -> None: mock_execute.assert_called() mock_execute.assert_any_call(statement=textwrap.dedent(""" - MATCH (n) - WITH DISTINCT labels(n) as node, count(*) as count - RETURN head(node) as type, count + MATCH (target:Foo) + WHERE true + RETURN count(*) as count """)) mock_execute.assert_any_call(param_dict={'marker': u'foo'}, statement=textwrap.dedent(""" - MATCH (n) - WHERE{} - n.published_tag <> $marker - OR NOT EXISTS(n.published_tag) - WITH DISTINCT labels(n) as node, count(*) as count - RETURN head(node) as type, count - """.format(' '))) + MATCH (target:Foo) + WHERE (target.published_tag <> $marker + OR NOT EXISTS(target.published_tag)) + RETURN count(*) as count + """)) task._validate_relation_staleness_pct() mock_execute.assert_any_call(param_dict={'marker': u'foo'}, statement=textwrap.dedent(""" - MATCH ()-[n]-() - WHERE{} - n.published_tag <> $marker - OR NOT EXISTS(n.published_tag) - RETURN type(n) as type, count(*) as count - """.format(' '))) + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag <> $marker + OR NOT EXISTS(target.published_tag)) + RETURN count(*) as count + """)) def test_validation_statement_ms_to_expire(self) -> None: with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ @@ -166,6 +166,8 @@ def test_validation_statement_ms_to_expire(self) -> None: f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], f'{task.get_scope()}.{neo4j_staleness_removal_task.MS_TO_EXPIRE}': 9876543210 }) @@ -174,30 +176,99 @@ def test_validation_statement_ms_to_expire(self) -> None: mock_execute.assert_called() mock_execute.assert_any_call(statement=textwrap.dedent(""" - MATCH (n) - WITH DISTINCT labels(n) as node, count(*) as count - RETURN head(node) as type, count + MATCH (target:Foo) + WHERE true + RETURN count(*) as count """)) mock_execute.assert_any_call(param_dict={'marker': 9876543210}, statement=textwrap.dedent(""" - MATCH (n) - WHERE{} - n.publisher_last_updated_epoch_ms < (timestamp() - $marker) - OR NOT EXISTS(n.publisher_last_updated_epoch_ms) - WITH DISTINCT labels(n) as node, count(*) as count - RETURN head(node) as type, count - """.format(' '))) + MATCH (target:Foo) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms)) + RETURN count(*) as count + """)) task._validate_relation_staleness_pct() mock_execute.assert_any_call(param_dict={'marker': 9876543210}, statement=textwrap.dedent(""" - MATCH ()-[n]-() - WHERE{} - n.publisher_last_updated_epoch_ms < (timestamp() - $marker) - OR NOT EXISTS(n.publisher_last_updated_epoch_ms) - RETURN type(n) as type, count(*) as count - """.format(' '))) + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms)) + RETURN count(*) as count + """)) + + def test_validation_statement_with_target_condition(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'foobar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': [TargetWithCondition('Foo', '(target)-[:BAR]->(:Foo) AND target.name=\'foo_name\'')], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': [TargetWithCondition('BAR', '(start_node:Foo)-[target]->(end_node:Foo)')], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + task._validate_node_staleness_pct() + + mock_execute.assert_called() + mock_execute.assert_any_call(statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE true AND (target)-[:BAR]->(:Foo) AND target.name=\'foo_name\' + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(param_dict={'marker': u'foo'}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.published_tag <> $marker + OR NOT EXISTS(target.published_tag)) AND (target)-[:BAR]->(:Foo) AND target.name=\'foo_name\' + RETURN count(*) as count + """)) + + task._validate_relation_staleness_pct() + mock_execute.assert_any_call(param_dict={'marker': u'foo'}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag <> $marker + OR NOT EXISTS(target.published_tag)) AND (start_node:Foo)-[target]->(end_node:Foo) + RETURN count(*) as count + """)) + + def test_validation_receives_correct_counts(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'foobar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': ['Foo'], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': ['BAR'], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + + with patch.object(Neo4jStalenessRemovalTask, '_validate_staleness_pct') as mock_validate: + mock_execute.side_effect = [[{'count': 100}], [{'count': 50}]] + task._validate_node_staleness_pct() + mock_validate.assert_called_with(total_record_count=100, + stale_record_count=50, + target_type='Foo') + + mock_execute.side_effect = [[{'count': 100}], [{'count': 50}]] + task._validate_relation_staleness_pct() + mock_validate.assert_called_with(total_record_count=100, + stale_record_count=50, + target_type='BAR') def test_delete_statement_publish_tag(self) -> None: with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ @@ -222,26 +293,24 @@ def test_delete_statement_publish_tag(self) -> None: mock_execute.assert_any_call(dry_run=False, param_dict={'marker': u'foo', 'batch_size': 100}, statement=textwrap.dedent(""" - MATCH (n:Foo) - WHERE{} - n.published_tag <> $marker - OR NOT EXISTS(n.published_tag) - WITH n LIMIT $batch_size - DETACH DELETE (n) - RETURN COUNT(*) as count; - """.format(' '))) + MATCH (target:Foo) + WHERE (target.published_tag <> $marker + OR NOT EXISTS(target.published_tag)) + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """)) mock_execute.assert_any_call(dry_run=False, param_dict={'marker': u'foo', 'batch_size': 100}, statement=textwrap.dedent(""" - MATCH ()-[n:BAR]-() - WHERE{} - n.published_tag <> $marker - OR NOT EXISTS(n.published_tag) - WITH n LIMIT $batch_size - DELETE n - RETURN count(*) as count; - """.format(' '))) + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag <> $marker + OR NOT EXISTS(target.published_tag)) + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """)) def test_delete_statement_ms_to_expire(self) -> None: with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ @@ -266,26 +335,66 @@ def test_delete_statement_ms_to_expire(self) -> None: mock_execute.assert_any_call(dry_run=False, param_dict={'marker': 9876543210, 'batch_size': 100}, statement=textwrap.dedent(""" - MATCH (n:Foo) - WHERE{} - n.publisher_last_updated_epoch_ms < (timestamp() - $marker) - OR NOT EXISTS(n.publisher_last_updated_epoch_ms) - WITH n LIMIT $batch_size - DETACH DELETE (n) - RETURN COUNT(*) as count; - """.format(' '))) + MATCH (target:Foo) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms)) + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """)) mock_execute.assert_any_call(dry_run=False, param_dict={'marker': 9876543210, 'batch_size': 100}, statement=textwrap.dedent(""" - MATCH ()-[n:BAR]-() - WHERE{} - n.publisher_last_updated_epoch_ms < (timestamp() - $marker) - OR NOT EXISTS(n.publisher_last_updated_epoch_ms) - WITH n LIMIT $batch_size - DELETE n - RETURN count(*) as count; - """.format(' '))) + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.publisher_last_updated_epoch_ms < (timestamp() - $marker) + OR NOT EXISTS(target.publisher_last_updated_epoch_ms)) + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """)) + + def test_delete_statement_with_target_condition(self) -> None: + with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jStalenessRemovalTask, '_execute_cypher_query') \ + as mock_execute: + mock_execute.return_value.single.return_value = {'count': 0} + task = Neo4jStalenessRemovalTask() + job_config = ConfigFactory.from_dict({ + f'job.identifier': 'remove_stale_data_job', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_END_POINT_KEY}': 'foobar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_USER}': 'foo', + f'{task.get_scope()}.{neo4j_staleness_removal_task.NEO4J_PASSWORD}': 'bar', + f'{task.get_scope()}.{neo4j_staleness_removal_task.STALENESS_MAX_PCT}': 5, + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_NODES}': [TargetWithCondition('Foo', '(target)-[:BAR]->(:Foo) AND target.name=\'foo_name\'')], + f'{task.get_scope()}.{neo4j_staleness_removal_task.TARGET_RELATIONS}': [TargetWithCondition('BAR', '(start_node:Foo)-[target]->(end_node:Foo)')], + neo4j_csv_publisher.JOB_PUBLISH_TAG: 'foo', + }) + + task.init(job_config) + task._delete_stale_nodes() + task._delete_stale_relations() + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': u'foo', 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (target:Foo) + WHERE (target.published_tag <> $marker + OR NOT EXISTS(target.published_tag)) AND (target)-[:BAR]->(:Foo) AND target.name=\'foo_name\' + WITH target LIMIT $batch_size + DETACH DELETE (target) + RETURN count(*) as count + """)) + + mock_execute.assert_any_call(dry_run=False, + param_dict={'marker': u'foo', 'batch_size': 100}, + statement=textwrap.dedent(""" + MATCH (start_node)-[target:BAR]-(end_node) + WHERE (target.published_tag <> $marker + OR NOT EXISTS(target.published_tag)) AND (start_node:Foo)-[target]->(end_node:Foo) + WITH target LIMIT $batch_size + DELETE target + RETURN count(*) as count + """)) def test_ms_to_expire_too_small(self) -> None: with patch.object(GraphDatabase, 'driver'):