Skip to content

Commit

Permalink
allow get_user_groups / auth_state_groups_key to be async
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Aug 23, 2024
1 parent cd00a8d commit 6247c6e
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import uuid
from functools import reduce
from inspect import isawaitable
from urllib.parse import quote, urlencode, urlparse, urlunparse

import jwt
Expand Down Expand Up @@ -361,6 +362,7 @@ class OAuthenticator(Authenticator):
Can be a string key name (use periods for nested keys), or a callable
that accepts the auth state (as a dict) and returns the groups list.
Callables may be async.
Requires `manage_groups` to also be `True`.
""",
Expand Down Expand Up @@ -1086,18 +1088,22 @@ def build_auth_state_dict(self, token_info, user_info):
self.user_auth_state_key: user_info,
}

def get_user_groups(self, auth_state: dict):
async def get_user_groups(self, auth_state: dict):
"""
Returns a set of groups the user belongs to based on auth_state_groups_key
and provided auth_state.
- If auth_state_groups_key is a callable, it returns the list of groups directly.
Callable may be async.
- If auth_state_groups_key is a nested dictionary key like
"permissions.groups", this function returns
auth_state["permissions"]["groups"].
"""
if callable(self.auth_state_groups_key):
return set(self.auth_state_groups_key(auth_state))
groups = self.auth_state_groups_key(auth_state)
if isawaitable(groups):
groups = await groups
return set(groups)
try:
return set(
reduce(dict.get, self.auth_state_groups_key.split("."), auth_state)
Expand Down Expand Up @@ -1126,6 +1132,8 @@ async def update_auth_model(self, auth_model):
if self.manage_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if isawaitable(user_groups):
user_groups = await user_groups

auth_model["groups"] = sorted(user_groups)

Expand Down Expand Up @@ -1223,6 +1231,8 @@ async def check_allowed(self, username, auth_model):
if self.manage_groups and self.allowed_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if isawaitable(user_groups):
user_groups = await user_groups
if any(user_groups & self.allowed_groups):
return True

Expand Down

0 comments on commit 6247c6e

Please sign in to comment.