Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SQL backend to support SQLAlchemy 2.x #45

Merged
merged 3 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 115 additions & 84 deletions grand/backends/_sqlbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sqlalchemy
from sqlalchemy.pool import NullPool
from sqlalchemy.sql import select
from sqlalchemy import and_, or_, func
from sqlalchemy import and_, or_, func, Index

from .backend import Backend

Expand Down Expand Up @@ -60,51 +60,48 @@ def __init__(
self._connection = self._engine.connect()
self._metadata = sqlalchemy.MetaData()

if not self._engine.dialect.has_table(self._connection, self._node_table_name):
self._node_table = sqlalchemy.Table(
self._node_table_name,
self._metadata,
sqlalchemy.Column(
self._primary_key,
sqlalchemy.String(_DEFAULT_SQL_STR_LEN),
primary_key=True,
),
sqlalchemy.Column("_metadata", sqlalchemy.JSON),
)
self._node_table.create(self._engine)
else:
self._node_table = sqlalchemy.Table(
self._node_table_name,
self._metadata,
autoload=True,
autoload_with=self._engine,
)
# Create nodes table
self._node_table = sqlalchemy.Table(
self._node_table_name,
self._metadata,
sqlalchemy.Column(
self._primary_key,
sqlalchemy.String(_DEFAULT_SQL_STR_LEN),
primary_key=True,
),
sqlalchemy.Column("_metadata", sqlalchemy.JSON),
)
self._node_table.create(self._engine, checkfirst=True)

if not self._engine.dialect.has_table(self._connection, self._edge_table_name):
self._edge_table = sqlalchemy.Table(
self._edge_table_name,
self._metadata,
sqlalchemy.Column(
self._primary_key,
sqlalchemy.String(_DEFAULT_SQL_STR_LEN),
primary_key=True,
),
sqlalchemy.Column("_metadata", sqlalchemy.JSON),
sqlalchemy.Column(
self._edge_source_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
),
sqlalchemy.Column(
self._edge_target_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
),
)
self._edge_table.create(self._engine)
else:
self._edge_table = sqlalchemy.Table(
self._edge_table_name,
self._metadata,
autoload=True,
autoload_with=self._engine,
)
source_column = sqlalchemy.Column(
self._edge_source_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
)

target_column = sqlalchemy.Column(
self._edge_target_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
)

# Create edges table
self._edge_table = sqlalchemy.Table(
self._edge_table_name,
self._metadata,
sqlalchemy.Column(
self._primary_key,
sqlalchemy.String(_DEFAULT_SQL_STR_LEN),
primary_key=True,
),
sqlalchemy.Column("_metadata", sqlalchemy.JSON),
source_column,
target_column
)
self._edge_table.create(self._engine, checkfirst=True)

# Create source and target index
sindex = Index("edge_source", source_column)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👌

sindex.create(self._engine, checkfirst=True)

tindex = Index("edge_target", target_column)
tindex.create(self._engine, checkfirst=True)

def is_directed(self) -> bool:
"""
Expand Down Expand Up @@ -147,17 +144,25 @@ def add_node(self, node_name: Hashable, metadata: dict) -> Hashable:
existing_metadata.update(metadata)
self._connection.execute(
self._node_table.update().where(
self._node_table.c[self._primary_key] == node_name
self._node_table.c[self._primary_key] == str(node_name)
),
**{"_metadata": existing_metadata},
parameters={"_metadata": existing_metadata},
)
else:
self._connection.execute(
self._node_table.insert(),
**{self._primary_key: node_name, "_metadata": metadata},
parameters={self._primary_key: node_name, "_metadata": metadata},
)
return node_name

def add_nodes_from(self, nodes_for_adding, **attr):
nodes = [{
self._primary_key: node,
"_metadata": {**attr, **metadata},
} for node, metadata in nodes_for_adding]

self._connection.execute(self._node_table.insert(), nodes)

def _upsert_node(self, node_name: Hashable, metadata: dict) -> Hashable:
"""
Add a new node to the graph, or update an existing one.
Expand All @@ -174,14 +179,14 @@ def _upsert_node(self, node_name: Hashable, metadata: dict) -> Hashable:
if node_exists:
self._connection.execute(
self._node_table.update().where(
self._node_table.c[self._primary_key] == node_name
self._node_table.c[self._primary_key] == str(node_name)
),
**{"_metadata": metadata},
parameters={"_metadata": metadata},
)
else:
self._connection.execute(
self._node_table.insert(),
**{self._primary_key: node_name, "_metadata": metadata},
parameters={self._primary_key: node_name, "_metadata": metadata},
)

def all_nodes_as_iterable(self, include_metadata: bool = False) -> Generator:
Expand All @@ -196,10 +201,16 @@ def all_nodes_as_iterable(self, include_metadata: bool = False) -> Generator:
Generator: A generator of all nodes (arbitrary sort)

"""
results = self._connection.execute(self._node_table.select()).fetchall()
if include_metadata:
return [(row[self._primary_key], row["_metadata"]) for row in results]
return [row[self._primary_key] for row in results]
sql = self._node_table.select()
else:
sql = self._node_table.select().with_only_columns(self._node_table.c[self._primary_key])

results = []
for x in self._connection.execute(sql):
results.append(x if include_metadata else x[0])

return results

def has_node(self, u: Hashable) -> bool:
"""
Expand All @@ -214,7 +225,7 @@ def has_node(self, u: Hashable) -> bool:
return len(
self._connection.execute(
self._node_table.select().where(
self._node_table.c[self._primary_key] == u
self._node_table.c[self._primary_key] == str(u)
)
).fetchall()
)
Expand Down Expand Up @@ -245,7 +256,7 @@ def add_edge(self, u: Hashable, v: Hashable, metadata: dict):
try:
self._connection.execute(
self._edge_table.insert(),
**{
parameters={
self._primary_key: pk,
self._edge_source_key: u,
self._edge_target_key: v,
Expand All @@ -260,11 +271,21 @@ def add_edge(self, u: Hashable, v: Hashable, metadata: dict):
self._edge_table.update().where(
self._edge_table.c[self._primary_key] == pk
),
**{"_metadata": existing_metadata},
parameters={"_metadata": existing_metadata},
)

return pk

def add_edges_from(self, ebunch_to_add, **attr):
edges = [{
self._primary_key: f"__{u}__{v}",
self._edge_source_key: u,
self._edge_target_key: v,
"_metadata": {**attr, **metadata},
} for u, v, metadata in ebunch_to_add]

self._connection.execute(self._edge_table.insert(), edges)

def all_edges_as_iterable(self, include_metadata: bool = False) -> Generator:
"""
Get a list of all edges in this graph, arbitrary sort.
Expand All @@ -274,16 +295,18 @@ def all_edges_as_iterable(self, include_metadata: bool = False) -> Generator:

Returns:
Generator: A generator of all edges (arbitrary sort)

"""
return iter(
[
(e.Source, e.Target, e._metadata)
if include_metadata
else (e.Source, e.Target)
for e in self._connection.execute(self._edge_table.select()).fetchall()
]
)

columns = [
self._node_table.c[self._edge_source_key],
self._node_table.c[self._edge_target_key]
]

if include_metadata:
columns.append(self._node_table.c["_metadata"])

sql = self._node_table.select().with_only_columns(columns)
return self._connection.execute(sql).fetchall()

def get_node_by_id(self, node_name: Hashable):
"""
Expand All @@ -296,10 +319,11 @@ def get_node_by_id(self, node_name: Hashable):
dict: The metadata associated with this node

"""

res = (
self._connection.execute(
self._node_table.select().where(
self._node_table.c[self._primary_key] == node_name
self._node_table.c[self._primary_key] == str(node_name)
)
)
.fetchone()
Expand Down Expand Up @@ -357,22 +381,25 @@ def get_node_neighbors(
Generator

"""

if self._directed:
res = self._connection.execute(
self._edge_table.select().where(
self._edge_table.c[self._edge_source_key] == u
self._edge_table.c[self._edge_source_key] == str(u)
)
).fetchall()
else:
res = self._connection.execute(
self._edge_table.select().where(
or_(
(self._edge_table.c[self._edge_source_key] == u),
(self._edge_table.c[self._edge_target_key] == u),
(self._edge_table.c[self._edge_source_key] == str(u)),
(self._edge_table.c[self._edge_target_key] == str(u)),
)
)
).fetchall()

res = [x._asdict() for x in res]

if include_metadata:
return {
(
Expand Down Expand Up @@ -410,19 +437,21 @@ def get_node_predecessors(
if self._directed:
res = self._connection.execute(
self._edge_table.select().where(
self._edge_table.c[self._edge_target_key] == u
self._edge_table.c[self._edge_target_key] == str(u)
)
).fetchall()
else:
res = self._connection.execute(
self._edge_table.select().where(
or_(
(self._edge_table.c[self._edge_target_key] == u),
(self._edge_table.c[self._edge_source_key] == u),
(self._edge_table.c[self._edge_target_key] == str(u)),
(self._edge_table.c[self._edge_source_key] == str(u)),
)
)
).fetchall()

res = [x._asdict() for x in res]

if include_metadata:
return {
(
Expand Down Expand Up @@ -456,7 +485,7 @@ def get_node_count(self) -> Iterable:

"""
return self._connection.execute(
select([func.count()]).select_from(self._node_table)
select(func.count()).select_from(self._node_table)
).scalar()

def out_degrees(self, nbunch=None):
Expand All @@ -474,30 +503,31 @@ def out_degrees(self, nbunch=None):
if nbunch is None:
where_clause = None
elif isinstance(nbunch, (list, tuple)):
where_clause = self._edge_table.c[self._edge_source_key].in_(nbunch)
where_clause = self._edge_table.c[self._edge_source_key].in_([str(x) for x in nbunch])
else:
# single node:
where_clause = self._edge_table.c[self._edge_source_key] == nbunch
where_clause = self._edge_table.c[self._edge_source_key] == str(nbunch)

if self._directed:
query = (
select([self._edge_table.c[self._edge_source_key], func.count()])
select(self._edge_table.c[self._edge_source_key], func.count())
.select_from(self._edge_table)
.group_by(self._edge_table.c[self._edge_source_key])
)
else:
query = (
select([self._edge_table.c[self._edge_source_key], func.count()])
select(self._edge_table.c[self._edge_source_key], func.count())
.select_from(self._edge_table)
.group_by(self._edge_table.c[self._edge_source_key])
)

if where_clause is not None:
query = query.where(where_clause)

results = [x._asdict() for x in self._connection.execute(query).fetchall()]
results = {
r[self._edge_source_key]: r[1]
for r in self._connection.execute(query).fetchall()
for r in results
}

if nbunch and not isinstance(nbunch, (list, tuple)):
Expand All @@ -519,30 +549,31 @@ def in_degrees(self, nbunch=None):
if nbunch is None:
where_clause = None
elif isinstance(nbunch, (list, tuple)):
where_clause = self._edge_table.c[self._edge_target_key].in_(nbunch)
where_clause = self._edge_table.c[self._edge_target_key].in_([str(x) for x in nbunch])
else:
# single node:
where_clause = self._edge_table.c[self._edge_target_key] == nbunch
where_clause = self._edge_table.c[self._edge_target_key] == str(nbunch)

if self._directed:
query = (
select([self._edge_table.c[self._edge_target_key], func.count()])
select(self._edge_table.c[self._edge_target_key], func.count())
.select_from(self._edge_table)
.group_by(self._edge_table.c[self._edge_target_key])
)
else:
query = (
select([self._edge_table.c[self._edge_target_key], func.count()])
select(self._edge_table.c[self._edge_target_key], func.count())
.select_from(self._edge_table)
.group_by(self._edge_table.c[self._edge_target_key])
)

if where_clause is not None:
query = query.where(where_clause)

results = [x._asdict() for x in self._connection.execute(query).fetchall()]
results = {
r[self._edge_target_key]: r[1]
for r in self._connection.execute(query).fetchall()
for r in results
}

if nbunch and not isinstance(nbunch, (list, tuple)):
Expand Down
Loading
Loading