Skip to content

Commit

Permalink
Allow specifying either user or user_id
Browse files Browse the repository at this point in the history
  • Loading branch information
singingwolfboy committed Feb 28, 2015
1 parent 0c72e88 commit 530fed2
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 16 deletions.
64 changes: 50 additions & 14 deletions flask_dance/consumer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def __init__(self, name, import_name,
)

self.user = None
self.user_id = None
self.set_token_storage_session()

self.logged_in_funcs = []
self.from_config = {}
self.before_app_request(self.load_config)
Expand Down Expand Up @@ -134,7 +136,8 @@ def set_token(token):
def delete_token():
del flask.session[key]

def set_token_storage_sqlalchemy(self, model, session, user=None, cache=None):
def set_token_storage_sqlalchemy(self, model, session,
user=None, user_id=None, cache=None):
"""
A helper method to set up the blueprint to store and retrieve OAuth
tokens using SQLAlchemy. This will overwrite any custom token
Expand All @@ -154,6 +157,9 @@ def set_token_storage_sqlalchemy(self, model, session, user=None, cache=None):
in user. This argument is optional; if not provided,
OAuth tokens will not be associated with specific users in
your application.
user_id: The ID of the current logged in user, if any. This argument
is optional, and is used in the same way as the ``user`` argument.
You do not need to specify both.
cache: An instance of `Flask-Cache`_. This is optional, but highly
recommended for performance reasons.
Expand All @@ -166,19 +172,29 @@ def set_token_storage_sqlalchemy(self, model, session, user=None, cache=None):

bp = self
outer_user = user
outer_user_id = user_id

def make_cache_key(name=None, user=None):
u = first(_get_real_user(ref) for ref in (user, outer_user, bp.user))
return "flask_dance_token|{name}|{user}".format(
name=self.name, user=getattr(u, "id", u),
def make_cache_key(name=None, user=None, user_id=None):
uid = first([user_id, outer_user_id, bp.user_id])
if not uid:
u = first(_get_real_user(ref) for ref in (user, outer_user, bp.user))
uid = getattr(u, "id", u)
return "flask_dance_token|{name}|{user_id}".format(
name=self.name, user_id=uid,
)

@cache.memoize()
def get_token(user=None):
def get_token(user=None, user_id=None):
query = session.query(model).filter_by(provider=self.name)
if hasattr(model, "user"):
u = first(_get_real_user(ref) for ref in (user, outer_user, bp.user))
# check for user ID
uid = first([user_id, outer_user_id, bp.user_id])
if hasattr(model, "user_id") and uid:
query = query.filter_by(user_id=uid)
# check for user (relationship property)
u = first(_get_real_user(ref) for ref in (user, outer_user, bp.user))
if hasattr(model, "user") and u:
query = query.filter_by(user=u)
# run query
try:
return query.one().token
except NoResultFound:
Expand All @@ -187,20 +203,31 @@ def get_token(user=None):
self.token_getter(get_token)

@self.token_setter
def set_token(token, user=None):
has_user = hasattr(model, "user")
def set_token(token, user=None, user_id=None):
# if there was an existing model, delete it
existing_query = session.query(model).filter_by(provider=self.name)
# check for user ID
has_user_id = hasattr(model, "user_id")
if has_user_id:
uid = first([user_id, outer_user_id, bp.user_id])
if uid:
existing_query = existing_query.filter_by(user_id=uid)
# check for user (relationship property)
has_user = hasattr(model, "user")
if has_user:
u = first(_get_real_user(ref) for ref in (user, outer_user, bp.user))
existing_query = existing_query.filter_by(user=u)
if u:
existing_query = existing_query.filter_by(user=u)
# queue up delete query -- won't be run until commit()
existing_query.delete()
# create a new model for this token
kwargs = {
"provider": self.name,
"token": token,
}
if has_user:
if has_user_id and uid:
kwargs["user_id"] = uid
if has_user and u:
kwargs["user"] = u
session.add(model(**kwargs))
# commit to delete and add simultaneously
Expand All @@ -209,12 +236,21 @@ def set_token(token, user=None):
cache.delete_memoized(self.get_token)

@self.token_deleter
def delete_token(user=None):
def delete_token(user=None, user_id=None):
query = session.query(model).filter_by(provider=self.name)
# check for user ID
if hasattr(model, "user_id"):
uid = first([user_id, outer_user_id, bp.user_id])
if uid:
query = query.filter_by(user_id=uid)
# check for user (relationship property)
if hasattr(model, "user"):
u = first(_get_real_user(ref) for ref in (user, outer_user, bp.user))
query = query.filter_by(user=u)
if u:
query = query.filter_by(user=u)
# run query
query.delete()
session.commit()
# invalidate cache
cache.delete_memoized(self.get_token)

Expand Down
3 changes: 2 additions & 1 deletion flask_dance/consumer/oauth1.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ def prepare_request(self, request):
request.url = self.base_url.relative(request.url)
return super(OAuth1Session, self).prepare_request(request)

def load_token(self, user):
def load_token(self, user=None, user_id=None):
if self.blueprint:
self.blueprint.user = user
self.blueprint.user_id = user_id
self.blueprint.load_token()

# backwards compatibility
Expand Down
3 changes: 2 additions & 1 deletion flask_dance/consumer/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ def request(self, method, url, data=None, headers=None, **kwargs):
method=method, url=url, data=data, headers=headers, **kwargs
)

def load_token(self, user):
def load_token(self, user=None, user_id=None):
if self.blueprint:
self.blueprint.user = user
self.blueprint.user_id = user_id
self.blueprint.load_token()

# backwards compatibility
Expand Down
4 changes: 4 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ def done():
sess.load_token(sue)
assert sess.token == sue_token

# load for user ID as well
sess.load_token(user_id=bob.id)
assert sess.token == bob_token

def test_model_with_flask_login(app, db, blueprint, request):
login_manager = LoginManager(app)

Expand Down

0 comments on commit 530fed2

Please sign in to comment.