Skip to content

Commit

Permalink
cover case where a user creates a resource of their own, introducing …
Browse files Browse the repository at this point in the history
…a 'principal user' (naive implementation)
  • Loading branch information
philtweir committed Jan 23, 2024
1 parent 6d047bf commit ebc567a
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 136 deletions.
125 changes: 6 additions & 119 deletions coral/management/commands/apply_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,10 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

from urllib.parse import parse_qs
import logging
import time
import readline
from django.core.management.base import BaseCommand
from django.contrib.auth.models import User, Group as DjangoGroup
from arches.app.utils.permission_backend import assign_perm
from arches.app.models.system_settings import settings
from arches.app.models.resource import Resource
from arches_orm.models import Set, LogicalSet
from coral.permissions.casbin import CasbinPermissionFramework
from arches.app.search.components.base import SearchFilterFactory
from arches.app.views.search import build_search
from arches.app.search.elasticsearch_dsl_builder import Bool, Match, Query, Nested, Terms, MaxAgg, Aggregation, UpdateByQuery
from arches.app.search.search_engine_factory import SearchEngineFactory
from arches.app.search.mappings import RESOURCES_INDEX

from coral import settings

Expand All @@ -55,115 +43,14 @@ def add_arguments(self, parser):


def handle(self, *args, **options):
self.print_statistics = True if options["print_statistics"] else False
# Cannot be imported until Django ready
from coral.permissions.casbin import CasbinPermissionFramework
from coral.utils.casbin import SetApplicator

table = self.apply_sets()
print_statistics = True if options["print_statistics"] else False

def _apply_set(self, se, set_id, set_query):
results = []
for add_not_remove in (True, False):
dsl = Query(se=se)
bool_query = Bool()
if add_not_remove:
bool_query.must_not(set_query())
bool_query.must(Nested(path="sets", query=Terms(field="sets.id", terms=[str(set_id)])))
sets = [str(set_id)]
source = """
if (ctx._source.sets != null) {
for (int i=ctx._source.sets.length-1; i>=0; i--) {
if (params.logicalSets.contains(ctx._source.sets[i].id)) {
ctx._source.sets.remove(i);
}
}
}
"""
else:
bool_query.must(set_query())
bool_query.must_not(Nested(path="sets", query=Terms(field="sets.id", terms=[str(set_id)])))
source = "ctx._source.sets.addAll(params.logicalSets)"
sets = [{"id": str(set_id)}]
dsl.add_query(bool_query)
update_by_query = UpdateByQuery(se=se, query=dsl, script={
"lang": "painless",
"source": source,
"params": {
"logicalSets": sets
}
})
results.append(update_by_query.run(index=RESOURCES_INDEX, wait_for_completion=False))
return results

def apply_sets(self, resourceinstanceid=None):
"""Apply set mappings to resources.
Run update-by-queries to mark/unmark sets against resources in Elasticsearch.
"""

from arches.app.search.search_engine_factory import SearchEngineInstance as _se

logical_sets = LogicalSet.all()
results = []
print("Print statistics?", self.print_statistics)
for logical_set in logical_sets:
if logical_set.member_definition:
# user=True is shorthand for "do not restrict by user"
parameters = parse_qs(logical_set.member_definition)
for key, value in parameters.items():
if len(value) != 1:
raise RuntimeError("Each filter type must appear exactly once")
parameters[key] = value[0]
def _logical_set_query():
_, _, inner_dsl = build_search(
for_export=False,
pages=False,
total=None,
resourceinstanceid=None,
load_tiles=False,
user=True,
provisional_filter=[],
parameters=parameters,
permitted_nodegroups=True # This should be ignored as user==True
)
return inner_dsl.dsl["query"]
if self.print_statistics:
dsl = Query(se=_se)
dsl.add_query(_logical_set_query())
count = dsl.count(index=RESOURCES_INDEX)
print("Logical Set:", logical_set.id)
print("Definition:", logical_set.member_definition)
print("Count:", count)
results = self._apply_set(_se, f"l:{logical_set.id}", _logical_set_query)
self.wait_for_completion(_se, results)
if self.print_statistics:
dsl = Query(se=_se)
dsl.add_query(Nested(path="sets", query=Terms(field="sets.id", terms=[f"l:{logical_set.id}"])))
count = dsl.count(index=RESOURCES_INDEX)
print("Applies to by search:", count)

sets = Set.all()
for regular_set in sets:
if regular_set.members:
# user=True is shorthand for "do not restrict by user"
def _regular_set_query():
query = Query(se=_se)
bool_query = Bool()
bool_query.must(Terms(field="_id", terms=[str(member.id) for member in regular_set.members]))
query.add_query(bool_query)
return query.dsl["query"]
results = self._apply_set(_se, f"r:{regular_set.id}", _regular_set_query)
self.wait_for_completion(_se, results)
set_applicator = SetApplicator(print_statistics=print_statistics, wait_for_completion=True)
set_applicator.apply_sets()

framework = CasbinPermissionFramework()
framework.recalculate_table()

def wait_for_completion(self, _se, results):
tasks_client = _se.make_tasks_client()
while results:
result = results[0]
task_id = result["task"]
status = tasks_client.get(task_id=task_id)
if status["completed"]:
results.remove(result)
else:
print(task_id, "not yet completed")
time.sleep(0.5)
1 change: 1 addition & 0 deletions coral/management/commands/print_permissions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get_table(self):
enforcer.model.logger.setLevel(logging.INFO)
framework.recalculate_table()
enforcer.model.print_policy()
print(enforcer.get_implicit_permissions_for_user("u:310"))
#group_tree = {}
#set_tree = {}
#group_x_set = []
Expand Down
62 changes: 46 additions & 16 deletions coral/permissions/casbin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from arches.app.search.elasticsearch_dsl_builder import Query
from arches.app.search.mappings import RESOURCES_INDEX
from arches.app.utils.permission_backend import PermissionFramework, NotUserNorGroup as ArchesNotUserNorGroup
from arches.app.permissions.arches_standard import get_nodegroups_by_perm_for_user_or_group
from arches.app.permissions.arches_standard import get_nodegroups_by_perm_for_user_or_group, assign_perm

from arches_orm.models import Person, Organization, Set, LogicalSet, Group
from arches_orm.wrapper import ResourceWrapper
Expand All @@ -38,6 +38,14 @@
"70415d03-b11b-48a6-b989-933d788ffc88": ["view_resourceinstance"],
"45d54859-bf3c-48f2-a387-55a0050ff572": ["execute_resourceinstance"],
}
GRAPH_REMAPPINGS = {
"809598ac-6dc5-498e-a7af-52b1381942a4": "models.write_nodegroup",
"33a0218b-b1cc-42d8-9a79-31a6b2147893": "models.write_nodegroup",
"70415d03-b11b-48a6-b989-933d788ffc88": "models.read_nodegroup",
"45d54859-bf3c-48f2-a387-55a0050ff572": "models.write_nodegroup",
}
REV_GRAPH_REMAPPINGS = {v: k for k, v in GRAPH_REMAPPINGS.items()}
RESOURCE_TO_GRAPH_REMAPPINGS = {v[0]: GRAPH_REMAPPINGS[k] for k, v in REMAPPINGS.items()}


class NoSubjectError(RuntimeError):
Expand Down Expand Up @@ -300,28 +308,39 @@ def remove_perm(self, perm, user_or_group=None, obj=None):

# This is slow and should be avoided where possible.
def get_perms(self, user_or_group, obj):
return {
act for sub, tobj, act in
self._get_perms(user_or_group, obj)
}
perms = set()
for sub, tobj, act in self._get_perms(user_or_group, obj):
perms |= set(REMAPPINGS.get(act, act))
return perms

def get_group_perms(self, user_or_group, obj):
# FIXME: what should this do if a group is passed?
return {
act for sub, tobj, act in
self._get_perms(user_or_group, obj)
if sub != f"u:{user_or_group.pk}"
}
perms = set()
for sub, tobj, act in self._get_perms(user_or_group, obj):
if sub != f"u:{user_or_group.pk}":
perms |= set(REMAPPINGS.get(act, act))
return perms

def get_user_perms(self, user, obj):
return {
act for sub, tobj, act in
self._get_perms(user, obj)
if sub == f"u:{user.pk}"
}
perms = set()
for sub, tobj, act in self._get_perms(user, obj):
if sub == f"u:{user.pk}":
perms |= set(REMAPPINGS.get(act, act))
return perms

def _get_perms(self, user_or_group, obj):
if obj is not None:
if isinstance(obj, ResourceInstance):
if isinstance(user_or_group, User) and user_or_group.id and obj.principaluser_id and int(user_or_group.id) == int(obj.principaluser_id):
return {
resource_perm for perm, resource_perm in REV_GRAPH_REMAPPINGS.items()
if self.user_has_resource_model_permissions(
user_or_group,
[perm],
obj
)
}

obj = self._obj_to_str(obj)

user_or_group = self._subj_to_str(user_or_group)
Expand Down Expand Up @@ -482,8 +501,17 @@ def check_resource_instance_permissions(self, user, resourceid, permission):

try:
resource = Resource(resourceinstanceid=resourceid)
result["resource"] = resource
index = resource.get_index()
if (principal_users := index.get("_source", {}).get("permissions", {}).get("principal_user", [])):
if len(principal_users) >= 1 and user and user.id in principal_users:
if permission == "view_resourceinstance" and self.user_has_resource_model_permissions(user, ["models.read_nodegroup"], resource):
result["permitted"] = True
return result
elif user.groups.filter(name__in=settings.RESOURCE_EDITOR_GROUPS).exists() or self.user_can_edit_model_nodegroups(
user, resource
):
result["permitted"] = True
return result
sets = [
st.get("id") for st in index.get("_source", {}).get("sets", {})
]
Expand Down Expand Up @@ -669,6 +697,7 @@ def user_has_resource_model_permissions(self, user, perms, resource):
if user.is_superuser:
return True
groups = self._enforcer.get_implicit_users_for_resource(f"gp:{resource.graph_id}")
print(groups, "GROUPS")
group_ids = {
group[3:] for group, _, act in groups
if group.startswith("dg:") and
Expand Down Expand Up @@ -709,6 +738,7 @@ def get_resource_types_by_perm(self, user, perms):
allowed = set()
subj = self._subj_to_str(user)
graphs = self._enforcer.get_implicit_permissions_for_user(subj)
print(graphs, "GRAPHS", subj)
permissioned_graphs = set()
for _, graph, act in graphs:
if not graph.startswith("gp:"):
Expand Down
130 changes: 130 additions & 0 deletions coral/utils/casbin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from arches.app.search.components.base import SearchFilterFactory
from urllib.parse import parse_qs
from arches.app.views.search import build_search
from arches.app.search.elasticsearch_dsl_builder import Bool, Match, Query, Ids, Nested, Terms, MaxAgg, Aggregation, UpdateByQuery
from arches.app.search.search_engine_factory import SearchEngineFactory
from arches.app.search.mappings import RESOURCES_INDEX
from arches.app.models.resource import Resource
from arches_orm.models import Set, LogicalSet
import time

class SetApplicator:
def __init__(self, print_statistics, wait_for_completion):
self.print_statistics = print_statistics
self.wait = wait_for_completion

def _apply_set(self, se, set_id, set_query, resourceinstanceid=None):
results = []
for add_not_remove in (True, False):
dsl = Query(se=se)
bool_query = Bool()
if resourceinstanceid:
bool_query.must(Ids(ids=[str(resourceinstanceid)]))
if add_not_remove:
bool_query.must_not(set_query())
bool_query.must(Nested(path="sets", query=Terms(field="sets.id", terms=[str(set_id)])))
sets = [str(set_id)]
source = """
if (ctx._source.sets != null) {
for (int i=ctx._source.sets.length-1; i>=0; i--) {
if (params.logicalSets.contains(ctx._source.sets[i].id)) {
ctx._source.sets.remove(i);
}
}
}
"""
else:
bool_query.must(set_query())
bool_query.must_not(Nested(path="sets", query=Terms(field="sets.id", terms=[str(set_id)])))
source = "ctx._source.sets.addAll(params.logicalSets)"
sets = [{"id": str(set_id)}]
dsl.add_query(bool_query)
update_by_query = UpdateByQuery(se=se, query=dsl, script={
"lang": "painless",
"source": source,
"params": {
"logicalSets": sets
}
})
results.append(update_by_query.run(index=RESOURCES_INDEX, wait_for_completion=False))
return results

def apply_sets(self, resourceinstanceid=None):
"""Apply set mappings to resources.
Run update-by-queries to mark/unmark sets against resources in Elasticsearch.
"""

from arches.app.search.search_engine_factory import SearchEngineInstance as _se

logical_sets = LogicalSet.all()
results = []
print("Print statistics?", self.print_statistics)
for logical_set in logical_sets:
if logical_set.member_definition:
# user=True is shorthand for "do not restrict by user"
parameters = parse_qs(logical_set.member_definition)
for key, value in parameters.items():
if len(value) != 1:
raise RuntimeError("Each filter type must appear exactly once")
parameters[key] = value[0]
def _logical_set_query():
_, _, inner_dsl = build_search(
for_export=False,
pages=False,
total=None,
resourceinstanceid=None,
load_tiles=False,
user=True,
provisional_filter=[],
parameters=parameters,
permitted_nodegroups=True # This should be ignored as user==True
)
return inner_dsl.dsl["query"]
if self.print_statistics:
dsl = Query(se=_se)
dsl.add_query(_logical_set_query())
count = dsl.count(index=RESOURCES_INDEX)
print("Logical Set:", logical_set.id)
print("Definition:", logical_set.member_definition)
print("Count:", count)
results = self._apply_set(_se, f"l:{logical_set.id}", _logical_set_query, resourceinstanceid=resourceinstanceid)
if self.wait:
self.wait_for_completion(_se, results)
if self.print_statistics:
dsl = Query(se=_se)
bool_query = Bool()
bool_query.must(Nested(path="sets", query=Terms(field="sets.id", terms=[f"l:{logical_set.id}"])))
if resourceinstanceid:
bool_query.must(Ids(ids=[str(resourceinstanceid)]))
dsl.add_query(bool_query)
count = dsl.count(index=RESOURCES_INDEX)
print("Applies to by search:", count)

sets = Set.all()
for regular_set in sets:
if regular_set.members:
# user=True is shorthand for "do not restrict by user"
members = [str(member.id) for member in regular_set.members]
if not resourceinstanceid or str(resourceinstanceid) in members:
def _regular_set_query():
query = Query(se=_se)
bool_query = Bool()
bool_query.must(Terms(field="_id", terms=members))
query.add_query(bool_query)
return query.dsl["query"]
results = self._apply_set(_se, f"r:{regular_set.id}", _regular_set_query, resourceinstanceid=resourceinstanceid)
if self.wait:
self.wait_for_completion(_se, results)

def wait_for_completion(self, _se, results):
tasks_client = _se.make_tasks_client()
while results:
result = results[0]
task_id = result["task"]
status = tasks_client.get(task_id=task_id)
if status["completed"]:
results.remove(result)
else:
print(task_id, "not yet completed")
time.sleep(0.1)
Loading

0 comments on commit ebc567a

Please sign in to comment.