From ebc567a7e556cd8bcaf8c7ffe5d1ac5d53352fbb Mon Sep 17 00:00:00 2001 From: Phil Weir Date: Tue, 23 Jan 2024 21:28:41 +0000 Subject: [PATCH] cover case where a user creates a resource of their own, introducing a 'principal user' (naive implementation) --- coral/management/commands/apply_sets.py | 125 +---------------- .../commands/print_permissions_table.py | 1 + coral/permissions/casbin.py | 62 ++++++--- coral/utils/casbin.py | 130 ++++++++++++++++++ coral/wsgi.py | 16 ++- 5 files changed, 198 insertions(+), 136 deletions(-) diff --git a/coral/management/commands/apply_sets.py b/coral/management/commands/apply_sets.py index 4df6e4f50..eff9ac6df 100644 --- a/coral/management/commands/apply_sets.py +++ b/coral/management/commands/apply_sets.py @@ -16,22 +16,10 @@ along with this program. If not, see . """ -from urllib.parse import parse_qs import logging -import time import readline from django.core.management.base import BaseCommand from django.contrib.auth.models import User, Group as DjangoGroup -from arches.app.utils.permission_backend import assign_perm -from arches.app.models.system_settings import settings -from arches.app.models.resource import Resource -from arches_orm.models import Set, LogicalSet -from coral.permissions.casbin import CasbinPermissionFramework -from arches.app.search.components.base import SearchFilterFactory -from arches.app.views.search import build_search -from arches.app.search.elasticsearch_dsl_builder import Bool, Match, Query, Nested, Terms, MaxAgg, Aggregation, UpdateByQuery -from arches.app.search.search_engine_factory import SearchEngineFactory -from arches.app.search.mappings import RESOURCES_INDEX from coral import settings @@ -55,115 +43,14 @@ def add_arguments(self, parser): def handle(self, *args, **options): - self.print_statistics = True if options["print_statistics"] else False + # Cannot be imported until Django ready + from coral.permissions.casbin import CasbinPermissionFramework + from coral.utils.casbin import SetApplicator - table = self.apply_sets() + print_statistics = True if options["print_statistics"] else False - def _apply_set(self, se, set_id, set_query): - results = [] - for add_not_remove in (True, False): - dsl = Query(se=se) - bool_query = Bool() - if add_not_remove: - bool_query.must_not(set_query()) - bool_query.must(Nested(path="sets", query=Terms(field="sets.id", terms=[str(set_id)]))) - sets = [str(set_id)] - source = """ - if (ctx._source.sets != null) { - for (int i=ctx._source.sets.length-1; i>=0; i--) { - if (params.logicalSets.contains(ctx._source.sets[i].id)) { - ctx._source.sets.remove(i); - } - } - } - """ - else: - bool_query.must(set_query()) - bool_query.must_not(Nested(path="sets", query=Terms(field="sets.id", terms=[str(set_id)]))) - source = "ctx._source.sets.addAll(params.logicalSets)" - sets = [{"id": str(set_id)}] - dsl.add_query(bool_query) - update_by_query = UpdateByQuery(se=se, query=dsl, script={ - "lang": "painless", - "source": source, - "params": { - "logicalSets": sets - } - }) - results.append(update_by_query.run(index=RESOURCES_INDEX, wait_for_completion=False)) - return results - - def apply_sets(self, resourceinstanceid=None): - """Apply set mappings to resources. - - Run update-by-queries to mark/unmark sets against resources in Elasticsearch. - """ - - from arches.app.search.search_engine_factory import SearchEngineInstance as _se - - logical_sets = LogicalSet.all() - results = [] - print("Print statistics?", self.print_statistics) - for logical_set in logical_sets: - if logical_set.member_definition: - # user=True is shorthand for "do not restrict by user" - parameters = parse_qs(logical_set.member_definition) - for key, value in parameters.items(): - if len(value) != 1: - raise RuntimeError("Each filter type must appear exactly once") - parameters[key] = value[0] - def _logical_set_query(): - _, _, inner_dsl = build_search( - for_export=False, - pages=False, - total=None, - resourceinstanceid=None, - load_tiles=False, - user=True, - provisional_filter=[], - parameters=parameters, - permitted_nodegroups=True # This should be ignored as user==True - ) - return inner_dsl.dsl["query"] - if self.print_statistics: - dsl = Query(se=_se) - dsl.add_query(_logical_set_query()) - count = dsl.count(index=RESOURCES_INDEX) - print("Logical Set:", logical_set.id) - print("Definition:", logical_set.member_definition) - print("Count:", count) - results = self._apply_set(_se, f"l:{logical_set.id}", _logical_set_query) - self.wait_for_completion(_se, results) - if self.print_statistics: - dsl = Query(se=_se) - dsl.add_query(Nested(path="sets", query=Terms(field="sets.id", terms=[f"l:{logical_set.id}"]))) - count = dsl.count(index=RESOURCES_INDEX) - print("Applies to by search:", count) - - sets = Set.all() - for regular_set in sets: - if regular_set.members: - # user=True is shorthand for "do not restrict by user" - def _regular_set_query(): - query = Query(se=_se) - bool_query = Bool() - bool_query.must(Terms(field="_id", terms=[str(member.id) for member in regular_set.members])) - query.add_query(bool_query) - return query.dsl["query"] - results = self._apply_set(_se, f"r:{regular_set.id}", _regular_set_query) - self.wait_for_completion(_se, results) + set_applicator = SetApplicator(print_statistics=print_statistics, wait_for_completion=True) + set_applicator.apply_sets() framework = CasbinPermissionFramework() framework.recalculate_table() - - def wait_for_completion(self, _se, results): - tasks_client = _se.make_tasks_client() - while results: - result = results[0] - task_id = result["task"] - status = tasks_client.get(task_id=task_id) - if status["completed"]: - results.remove(result) - else: - print(task_id, "not yet completed") - time.sleep(0.5) diff --git a/coral/management/commands/print_permissions_table.py b/coral/management/commands/print_permissions_table.py index 5af526238..3561a8e8f 100644 --- a/coral/management/commands/print_permissions_table.py +++ b/coral/management/commands/print_permissions_table.py @@ -49,6 +49,7 @@ def get_table(self): enforcer.model.logger.setLevel(logging.INFO) framework.recalculate_table() enforcer.model.print_policy() + print(enforcer.get_implicit_permissions_for_user("u:310")) #group_tree = {} #set_tree = {} #group_x_set = [] diff --git a/coral/permissions/casbin.py b/coral/permissions/casbin.py index 4b1ca461c..13c9ef59a 100644 --- a/coral/permissions/casbin.py +++ b/coral/permissions/casbin.py @@ -26,7 +26,7 @@ from arches.app.search.elasticsearch_dsl_builder import Query from arches.app.search.mappings import RESOURCES_INDEX from arches.app.utils.permission_backend import PermissionFramework, NotUserNorGroup as ArchesNotUserNorGroup -from arches.app.permissions.arches_standard import get_nodegroups_by_perm_for_user_or_group +from arches.app.permissions.arches_standard import get_nodegroups_by_perm_for_user_or_group, assign_perm from arches_orm.models import Person, Organization, Set, LogicalSet, Group from arches_orm.wrapper import ResourceWrapper @@ -38,6 +38,14 @@ "70415d03-b11b-48a6-b989-933d788ffc88": ["view_resourceinstance"], "45d54859-bf3c-48f2-a387-55a0050ff572": ["execute_resourceinstance"], } +GRAPH_REMAPPINGS = { + "809598ac-6dc5-498e-a7af-52b1381942a4": "models.write_nodegroup", + "33a0218b-b1cc-42d8-9a79-31a6b2147893": "models.write_nodegroup", + "70415d03-b11b-48a6-b989-933d788ffc88": "models.read_nodegroup", + "45d54859-bf3c-48f2-a387-55a0050ff572": "models.write_nodegroup", +} +REV_GRAPH_REMAPPINGS = {v: k for k, v in GRAPH_REMAPPINGS.items()} +RESOURCE_TO_GRAPH_REMAPPINGS = {v[0]: GRAPH_REMAPPINGS[k] for k, v in REMAPPINGS.items()} class NoSubjectError(RuntimeError): @@ -300,28 +308,39 @@ def remove_perm(self, perm, user_or_group=None, obj=None): # This is slow and should be avoided where possible. def get_perms(self, user_or_group, obj): - return { - act for sub, tobj, act in - self._get_perms(user_or_group, obj) - } + perms = set() + for sub, tobj, act in self._get_perms(user_or_group, obj): + perms |= set(REMAPPINGS.get(act, act)) + return perms def get_group_perms(self, user_or_group, obj): # FIXME: what should this do if a group is passed? - return { - act for sub, tobj, act in - self._get_perms(user_or_group, obj) - if sub != f"u:{user_or_group.pk}" - } + perms = set() + for sub, tobj, act in self._get_perms(user_or_group, obj): + if sub != f"u:{user_or_group.pk}": + perms |= set(REMAPPINGS.get(act, act)) + return perms def get_user_perms(self, user, obj): - return { - act for sub, tobj, act in - self._get_perms(user, obj) - if sub == f"u:{user.pk}" - } + perms = set() + for sub, tobj, act in self._get_perms(user, obj): + if sub == f"u:{user.pk}": + perms |= set(REMAPPINGS.get(act, act)) + return perms def _get_perms(self, user_or_group, obj): if obj is not None: + if isinstance(obj, ResourceInstance): + if isinstance(user_or_group, User) and user_or_group.id and obj.principaluser_id and int(user_or_group.id) == int(obj.principaluser_id): + return { + resource_perm for perm, resource_perm in REV_GRAPH_REMAPPINGS.items() + if self.user_has_resource_model_permissions( + user_or_group, + [perm], + obj + ) + } + obj = self._obj_to_str(obj) user_or_group = self._subj_to_str(user_or_group) @@ -482,8 +501,17 @@ def check_resource_instance_permissions(self, user, resourceid, permission): try: resource = Resource(resourceinstanceid=resourceid) - result["resource"] = resource index = resource.get_index() + if (principal_users := index.get("_source", {}).get("permissions", {}).get("principal_user", [])): + if len(principal_users) >= 1 and user and user.id in principal_users: + if permission == "view_resourceinstance" and self.user_has_resource_model_permissions(user, ["models.read_nodegroup"], resource): + result["permitted"] = True + return result + elif user.groups.filter(name__in=settings.RESOURCE_EDITOR_GROUPS).exists() or self.user_can_edit_model_nodegroups( + user, resource + ): + result["permitted"] = True + return result sets = [ st.get("id") for st in index.get("_source", {}).get("sets", {}) ] @@ -669,6 +697,7 @@ def user_has_resource_model_permissions(self, user, perms, resource): if user.is_superuser: return True groups = self._enforcer.get_implicit_users_for_resource(f"gp:{resource.graph_id}") + print(groups, "GROUPS") group_ids = { group[3:] for group, _, act in groups if group.startswith("dg:") and @@ -709,6 +738,7 @@ def get_resource_types_by_perm(self, user, perms): allowed = set() subj = self._subj_to_str(user) graphs = self._enforcer.get_implicit_permissions_for_user(subj) + print(graphs, "GRAPHS", subj) permissioned_graphs = set() for _, graph, act in graphs: if not graph.startswith("gp:"): diff --git a/coral/utils/casbin.py b/coral/utils/casbin.py index e69de29bb..b503bf5cd 100644 --- a/coral/utils/casbin.py +++ b/coral/utils/casbin.py @@ -0,0 +1,130 @@ +from arches.app.search.components.base import SearchFilterFactory +from urllib.parse import parse_qs +from arches.app.views.search import build_search +from arches.app.search.elasticsearch_dsl_builder import Bool, Match, Query, Ids, Nested, Terms, MaxAgg, Aggregation, UpdateByQuery +from arches.app.search.search_engine_factory import SearchEngineFactory +from arches.app.search.mappings import RESOURCES_INDEX +from arches.app.models.resource import Resource +from arches_orm.models import Set, LogicalSet +import time + +class SetApplicator: + def __init__(self, print_statistics, wait_for_completion): + self.print_statistics = print_statistics + self.wait = wait_for_completion + + def _apply_set(self, se, set_id, set_query, resourceinstanceid=None): + results = [] + for add_not_remove in (True, False): + dsl = Query(se=se) + bool_query = Bool() + if resourceinstanceid: + bool_query.must(Ids(ids=[str(resourceinstanceid)])) + if add_not_remove: + bool_query.must_not(set_query()) + bool_query.must(Nested(path="sets", query=Terms(field="sets.id", terms=[str(set_id)]))) + sets = [str(set_id)] + source = """ + if (ctx._source.sets != null) { + for (int i=ctx._source.sets.length-1; i>=0; i--) { + if (params.logicalSets.contains(ctx._source.sets[i].id)) { + ctx._source.sets.remove(i); + } + } + } + """ + else: + bool_query.must(set_query()) + bool_query.must_not(Nested(path="sets", query=Terms(field="sets.id", terms=[str(set_id)]))) + source = "ctx._source.sets.addAll(params.logicalSets)" + sets = [{"id": str(set_id)}] + dsl.add_query(bool_query) + update_by_query = UpdateByQuery(se=se, query=dsl, script={ + "lang": "painless", + "source": source, + "params": { + "logicalSets": sets + } + }) + results.append(update_by_query.run(index=RESOURCES_INDEX, wait_for_completion=False)) + return results + + def apply_sets(self, resourceinstanceid=None): + """Apply set mappings to resources. + + Run update-by-queries to mark/unmark sets against resources in Elasticsearch. + """ + + from arches.app.search.search_engine_factory import SearchEngineInstance as _se + + logical_sets = LogicalSet.all() + results = [] + print("Print statistics?", self.print_statistics) + for logical_set in logical_sets: + if logical_set.member_definition: + # user=True is shorthand for "do not restrict by user" + parameters = parse_qs(logical_set.member_definition) + for key, value in parameters.items(): + if len(value) != 1: + raise RuntimeError("Each filter type must appear exactly once") + parameters[key] = value[0] + def _logical_set_query(): + _, _, inner_dsl = build_search( + for_export=False, + pages=False, + total=None, + resourceinstanceid=None, + load_tiles=False, + user=True, + provisional_filter=[], + parameters=parameters, + permitted_nodegroups=True # This should be ignored as user==True + ) + return inner_dsl.dsl["query"] + if self.print_statistics: + dsl = Query(se=_se) + dsl.add_query(_logical_set_query()) + count = dsl.count(index=RESOURCES_INDEX) + print("Logical Set:", logical_set.id) + print("Definition:", logical_set.member_definition) + print("Count:", count) + results = self._apply_set(_se, f"l:{logical_set.id}", _logical_set_query, resourceinstanceid=resourceinstanceid) + if self.wait: + self.wait_for_completion(_se, results) + if self.print_statistics: + dsl = Query(se=_se) + bool_query = Bool() + bool_query.must(Nested(path="sets", query=Terms(field="sets.id", terms=[f"l:{logical_set.id}"]))) + if resourceinstanceid: + bool_query.must(Ids(ids=[str(resourceinstanceid)])) + dsl.add_query(bool_query) + count = dsl.count(index=RESOURCES_INDEX) + print("Applies to by search:", count) + + sets = Set.all() + for regular_set in sets: + if regular_set.members: + # user=True is shorthand for "do not restrict by user" + members = [str(member.id) for member in regular_set.members] + if not resourceinstanceid or str(resourceinstanceid) in members: + def _regular_set_query(): + query = Query(se=_se) + bool_query = Bool() + bool_query.must(Terms(field="_id", terms=members)) + query.add_query(bool_query) + return query.dsl["query"] + results = self._apply_set(_se, f"r:{regular_set.id}", _regular_set_query, resourceinstanceid=resourceinstanceid) + if self.wait: + self.wait_for_completion(_se, results) + + def wait_for_completion(self, _se, results): + tasks_client = _se.make_tasks_client() + while results: + result = results[0] + task_id = result["task"] + status = tasks_client.get(task_id=task_id) + if status["completed"]: + results.remove(result) + else: + print(task_id, "not yet completed") + time.sleep(0.1) diff --git a/coral/wsgi.py b/coral/wsgi.py index 8ecaed269..e3f289aa7 100644 --- a/coral/wsgi.py +++ b/coral/wsgi.py @@ -16,6 +16,7 @@ along with this program. If not, see . ''' +import threading import os import sys import inspect @@ -30,13 +31,26 @@ os.environ['DJANGO_SETTINGS_MODULE'] = "coral.settings" from django.core.wsgi import get_wsgi_application +from django.dispatch import receiver +from arches.app.models.resource import resource_indexed + application = get_wsgi_application() from arches.app.models.system_settings import settings settings.update_from_db() +@receiver(resource_indexed) +def update_permissions(sender, instance, **kwargs): + from coral.utils.casbin import SetApplicator + # This may run too quickly + # Instead, it should trigger a (debounced) recalc. + # This may still require delays _between_ the upserts also. + def _exec(): + set_applicator = SetApplicator(print_statistics=True, wait_for_completion=True) + set_applicator.apply_sets(resourceinstanceid=instance.resourceinstanceid) + threading.Timer(3.0, _exec).start() + if os.getenv("CASBIN_LISTEN", False): - import threading from coral.permissions.casbin import trigger t = threading.Thread(target=trigger.listen) t.setDaemon(True)