Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rbac middleware #26347

Merged
merged 12 commits into from
Nov 27, 2024
194 changes: 194 additions & 0 deletions ee/api/rbac/access_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from typing import TYPE_CHECKING, cast


from rest_framework import exceptions, serializers, status
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet

from ee.models.rbac.access_control import AccessControl
from posthog.models.scopes import API_SCOPE_OBJECTS, APIScopeObjectOrNotSupported
from posthog.models.team.team import Team
from posthog.rbac.user_access_control import (
ACCESS_CONTROL_LEVELS_RESOURCE,
UserAccessControl,
default_access_level,
highest_access_level,
ordered_access_levels,
)


if TYPE_CHECKING:
_GenericViewSet = GenericViewSet
else:
_GenericViewSet = object


class AccessControlSerializer(serializers.ModelSerializer):
access_level = serializers.CharField(allow_null=True)

class Meta:
model = AccessControl
fields = [
"access_level",
"resource",
"resource_id",
"organization_member",
"role",
"created_by",
"created_at",
"updated_at",
]
read_only_fields = ["id", "created_at", "created_by"]

# Validate that resource is a valid option from the API_SCOPE_OBJECTS
def validate_resource(self, resource):
if resource not in API_SCOPE_OBJECTS:
raise serializers.ValidationError("Invalid resource. Must be one of: {}".format(API_SCOPE_OBJECTS))

return resource

# Validate that access control is a valid option
def validate_access_level(self, access_level):
if access_level and access_level not in ordered_access_levels(self.initial_data["resource"]):
raise serializers.ValidationError(
f"Invalid access level. Must be one of: {', '.join(ordered_access_levels(self.initial_data['resource']))}"
)

return access_level

def validate(self, data):
context = self.context

# Ensure that only one of organization_member or role is set
if data.get("organization_member") and data.get("role"):
raise serializers.ValidationError("You can not scope an access control to both a member and a role.")

access_control = cast(UserAccessControl, self.context["view"].user_access_control)
resource = data["resource"]
resource_id = data.get("resource_id")

# We assume the highest level is required for the given resource to edit access controls
required_level = highest_access_level(resource)
team = context["view"].team
the_object = context["view"].get_object()

if resource_id:
# Check that they have the right access level for this specific resource object
if not access_control.check_can_modify_access_levels_for_object(the_object):
raise exceptions.PermissionDenied(f"Must be {required_level} to modify {resource} permissions.")
else:
# If modifying the base resource rules then we are checking the parent membership (project or organization)
# NOTE: Currently we only support org level in the UI so its simply an org level check
if not access_control.check_can_modify_access_levels_for_object(team):
raise exceptions.PermissionDenied("Must be an Organization admin to modify project-wide permissions.")

return data


class AccessControlViewSetMixin(_GenericViewSet):
"""
Adds an "access_controls" action to the viewset that handles access control for the given resource

Why a mixin? We want to easily add this to any existing resource, including providing easy helpers for adding access control info such
as the current users access level to any response.
"""

# 1. Know that the project level access is covered by the Permission check
# 2. Get the actual object which we can pass to the serializer to check if the user created it
# 3. We can also use the serializer to check the access level for the object

def _get_access_control_serializer(self, *args, **kwargs):
kwargs.setdefault("context", self.get_serializer_context())
return AccessControlSerializer(*args, **kwargs)

def _get_access_controls(self, request: Request, is_global=False):
resource = cast(APIScopeObjectOrNotSupported, getattr(self, "scope_object", None))
user_access_control = cast(UserAccessControl, self.user_access_control) # type: ignore
team = cast(Team, self.team) # type: ignore

if is_global and resource != "project" or not resource or resource == "INTERNAL":
raise exceptions.NotFound("Role based access controls are only available for projects.")

obj = self.get_object()
resource_id = obj.id

if is_global:
# If role based then we are getting all controls for the project that aren't specific to a resource
access_controls = AccessControl.objects.filter(team=team, resource_id=None).all()
else:
# Otherwise we are getting all controls for the specific resource
access_controls = AccessControl.objects.filter(team=team, resource=resource, resource_id=resource_id).all()

serializer = self._get_access_control_serializer(instance=access_controls, many=True)
user_access_level = user_access_control.access_level_for_object(obj, resource)

return Response(
{
"access_controls": serializer.data,
# NOTE: For Role based controls we are always configuring resource level items
"available_access_levels": ACCESS_CONTROL_LEVELS_RESOURCE
if is_global
else ordered_access_levels(resource),
"default_access_level": "editor" if is_global else default_access_level(resource),
"user_access_level": user_access_level,
"user_can_edit_access_levels": user_access_control.check_can_modify_access_levels_for_object(obj),
}
)

def _update_access_controls(self, request: Request, is_global=False):
resource = getattr(self, "scope_object", None)
obj = self.get_object()
resource_id = str(obj.id)
team = cast(Team, self.team) # type: ignore

# Generically validate the incoming data
if not is_global:
# If not role based we are deriving from the viewset
data = request.data
data["resource"] = resource
data["resource_id"] = resource_id

partial_serializer = self._get_access_control_serializer(data=request.data)
partial_serializer.is_valid(raise_exception=True)
params = partial_serializer.validated_data

instance = AccessControl.objects.filter(
team=team,
resource=params["resource"],
resource_id=params.get("resource_id"),
organization_member=params.get("organization_member"),
role=params.get("role"),
).first()

if params["access_level"] is None:
if instance:
instance.delete()
return Response(status=status.HTTP_204_NO_CONTENT)

# Perform the upsert
if instance:
serializer = self._get_access_control_serializer(instance, data=request.data)
else:
serializer = self._get_access_control_serializer(data=request.data)

serializer.is_valid(raise_exception=True)
serializer.validated_data["team"] = team
serializer.save()

return Response(serializer.data, status=status.HTTP_200_OK)

@action(methods=["GET", "PUT"], detail=True)
def access_controls(self, request: Request, *args, **kwargs):
if request.method == "PUT":
return self._update_access_controls(request)

return self._get_access_controls(request)

@action(methods=["GET", "PUT"], detail=True)
def global_access_controls(self, request: Request, *args, **kwargs):
if request.method == "PUT":
return self._update_access_controls(request, is_global=True)

return self._get_access_controls(request, is_global=True)
25 changes: 2 additions & 23 deletions ee/api/rbac/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from rest_framework import mixins, serializers, viewsets
from rest_framework.permissions import SAFE_METHODS, BasePermission

from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role, RoleMembership
from posthog.api.organization_member import OrganizationMemberSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.models import OrganizationMembership
from posthog.models.feature_flag import FeatureFlag
from posthog.models.user import User


Expand All @@ -38,7 +36,6 @@ def has_permission(self, request, view):
class RoleSerializer(serializers.ModelSerializer):
created_by = UserBasicSerializer(read_only=True)
members = serializers.SerializerMethodField()
associated_flags = serializers.SerializerMethodField()

class Meta:
model = Role
Expand All @@ -49,7 +46,6 @@ class Meta:
"created_at",
"created_by",
"members",
"associated_flags",
]
read_only_fields = ["id", "created_at", "created_by"]

Expand All @@ -75,29 +71,12 @@ def get_members(self, role: Role):
members = RoleMembership.objects.filter(role=role)
return RoleMembershipSerializer(members, many=True).data

def get_associated_flags(self, role: Role):
associated_flags: list[dict] = []

role_access_objects = FeatureFlagRoleAccess.objects.filter(role=role).values_list("feature_flag_id")
flags = FeatureFlag.objects.filter(id__in=role_access_objects)
for flag in flags:
associated_flags.append({"id": flag.id, "key": flag.key})
return associated_flags


class RoleViewSet(
TeamAndOrgViewSetMixin,
mixins.ListModelMixin,
mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
viewsets.GenericViewSet,
):
class RoleViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
scope_object = "organization"
permission_classes = [RolePermissions]
serializer_class = RoleSerializer
queryset = Role.objects.all()
permission_classes = [RolePermissions]

def safely_get_queryset(self, queryset):
return queryset.filter(**self.request.GET.dict())
Expand Down
Loading
Loading