Skip to content

Commit

Permalink
Added helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
maximebf committed Jun 11, 2024
1 parent 60fc72e commit 7992b3e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
19 changes: 0 additions & 19 deletions sqlorm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import inspect
import urllib.parse
import functools
from blinker import Namespace
from .sql import render, ParametrizedStmt
from .resultset import ResultSet, CompositeResultSet, CompositionMap
Expand Down Expand Up @@ -579,21 +578,3 @@ def _signal_rv(signal_rv):
if rv:
final_rv = rv
return final_rv


def connect_via_engine(engine, signal, func=None):
def decorator(func):
@functools.wraps(func)
def wrapper(sender, **kw):
matches = False
if isinstance(sender, Engine):
matches = sender is engine
elif isinstance(sender, Session):
matches = sender.engine is engine
elif isinstance(sender, Transaction):
matches = sender.session.engine is engine
if matches:
return func(sender, **kw)
signal.connect(wrapper, weak=False)
return wrapper
return decorator(func) if func else decorator
32 changes: 32 additions & 0 deletions sqlorm/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .engine import Engine, Session, Transaction, ensure_session
import functools


def connect_via_engine(engine, signal, func=None):
def decorator(func):
@functools.wraps(func)
def wrapper(sender, **kw):
matches = False
if isinstance(sender, Engine):
matches = sender is engine
elif isinstance(sender, Session):
matches = sender.engine is engine
elif isinstance(sender, Transaction):
matches = sender.session.engine is engine
if matches:
return func(sender, **kw)
signal.connect(wrapper, weak=False)
return wrapper
return decorator(func) if func else decorator


def after_commit(func, *args, **kwargs):
with ensure_session() as session:
def rollback_listener(session):
Session.after_commit.disconnect(commit_listener)
Session.after_rollback.disconnect(rollback_listener)
def commit_listener(session):
func(*args, **kwargs)
rollback_listener(session)
Session.after_commit.connect(commit_listener, session, weak=False)
Session.after_rollback.connect(rollback_listener, session, weak=False)

0 comments on commit 7992b3e

Please sign in to comment.