diff --git a/grand/backends/_sqlbackend.py b/grand/backends/_sqlbackend.py index 2f71493..fea4ef1 100644 --- a/grand/backends/_sqlbackend.py +++ b/grand/backends/_sqlbackend.py @@ -1,11 +1,10 @@ -from typing import Hashable, Generator, Optional, Iterable +from typing import Hashable, Generator import time import pandas as pd import sqlalchemy -from sqlalchemy.pool import NullPool -from sqlalchemy.sql import select -from sqlalchemy import and_, or_, func, Index +from sqlalchemy.sql import delete, select +from sqlalchemy import or_, func, Index from .backend import Backend @@ -192,6 +191,29 @@ def _upsert_node(self, node_name: Hashable, metadata: dict) -> Hashable: parameters={self._primary_key: node_name, "_metadata": metadata}, ) + def remove_node(self, u: Hashable) -> None: + """ + Removes nodes and related edges for name. + + Args: + u (Hashable): id of the node + """ + + # Remove nodes + statement = delete(self._node_table).where( + self._node_table.c[self._primary_key] == str(u) + ) + self._connection.execute(statement) + + # Remove edges for node + statement = delete(self._edge_table).where( + or_( + self._edge_table.c[self._edge_source_key] == str(u), + self._edge_table.c[self._edge_target_key] == str(u) + ) + ) + self._connection.execute(statement) + def all_nodes_as_iterable(self, include_metadata: bool = False) -> Generator: """ Get a generator of all of the nodes in this graph. @@ -233,7 +255,7 @@ def has_node(self, u: Hashable) -> bool: self._node_table.c[self._primary_key] == str(u) ) ).fetchall() - ) + ) > 0 def add_edge(self, u: Hashable, v: Hashable, metadata: dict): """ diff --git a/grand/backends/test_backends.py b/grand/backends/test_backends.py index cb0d098..786e263 100644 --- a/grand/backends/test_backends.py +++ b/grand/backends/test_backends.py @@ -151,6 +151,17 @@ def test_sqlite_persistence(self): nodes = list(backend.all_nodes_as_iterable()) # assert assert node0 in nodes + + # test remove_node + backend = SQLBackend(db_url=url, directed=True) + node1, node2 = backend.add_node("A", {}), backend.add_node("B", {}) + backend.add_edge(node1, node2, {}) + assert backend.has_node(node1) + assert backend.has_edge(node1, node2) + backend.remove_node(node1) + assert not backend.has_node(node1) + assert not backend.has_edge(node1, node2) + # cleanup os.remove(dbpath) diff --git a/setup.py b/setup.py index 3050951..a0027ee 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ "sql": ["SQLAlchemy>=1.3"], "dynamodb": ["boto3"], "igraph": ["igraph"], - "networkit": ["cmake", "cython", "networkit"], + "networkit": ["cmake", "cython", "networkit", "numpy<2.0.0"], }, classifiers=[ "Programming Language :: Python :: 3",