Skip to content

Commit

Permalink
validate group type params
Browse files Browse the repository at this point in the history
  • Loading branch information
chrismeyersfsu committed Mar 21, 2018
1 parent 17795f8 commit 1c578cd
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 10 deletions.
7 changes: 5 additions & 2 deletions awx/conf/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _preload_cache(self):
settings_to_cache['_awx_conf_preload_expires'] = self._awx_conf_preload_expires
self.cache.set_many(settings_to_cache, timeout=SETTING_CACHE_TIMEOUT)

def _get_local(self, name):
def _get_local(self, name, validate=True):
self._preload_cache()
cache_key = Setting.get_cache_key(name)
try:
Expand Down Expand Up @@ -368,7 +368,10 @@ def _get_local(self, name):
field.run_validators(internal_value)
return internal_value
else:
return field.run_validation(value)
if validate:
return field.run_validation(value)
else:
return value
except Exception:
logger.warning(
'The current value "%r" for setting "%s" is invalid.',
Expand Down
61 changes: 53 additions & 8 deletions awx/sso/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

# Django Auth LDAP
import django_auth_ldap.config
from django_auth_ldap.config import LDAPSearch, LDAPSearchUnion
from django_auth_ldap.config import (
LDAPSearch,
LDAPSearchUnion,
LDAPGroupType,
)

# This must be imported so get_subclasses picks it up
from awx.sso.ldap_group_types import PosixUIDGroupType # noqa
Expand All @@ -28,6 +32,25 @@ def get_subclasses(cls):
yield subclass


class DependsOnMixin():
def get_depends_on(self):
"""
Get the value of the dependent field.
First try to find the value in the request.
Then fall back to the raw value from the setting in the DB.
"""
from django.conf import settings
dependent_key = iter(self.depends_on).next()

if self.context:
request = self.context.get('request', None)
if request and request.data and \
request.data.get(dependent_key, None):
return request.data.get(dependent_key)
res = settings._get_local(dependent_key, validate=False)
return res


class AuthenticationBackendsField(fields.StringListField):

# Mapping of settings that must be set in order to enable each
Expand Down Expand Up @@ -326,7 +349,15 @@ def to_internal_value(self, data):
return data


class LDAPGroupTypeField(fields.ChoiceField):
VALID_GROUP_TYPE_PARAMS_MAP = {
'LDAPGroupType': ['name_attr'],
'MemberDNGroupType': ['name_attr', 'member_attr'],
'PosixUIDGroupType': ['name_attr', 'ldap_group_user_attr'],
}



class LDAPGroupTypeField(fields.ChoiceField, DependsOnMixin):

default_error_messages = {
'type_error': _('Expected an instance of LDAPGroupType but got {input_type} instead.'),
Expand Down Expand Up @@ -357,8 +388,7 @@ def find_class_in_modules(class_name):
if not data:
return None

from django.conf import settings
params = getattr(settings, iter(self.depends_on).next(), None) or {}
params = self.get_depends_on() or {}
cls = find_class_in_modules(data)
if not cls:
return None
Expand All @@ -370,8 +400,9 @@ def find_class_in_modules(class_name):
# took a parameter.
params_sanitized = dict()
if isinstance(cls, LDAPGroupType):
if 'name_attr' in params:
params_sanitized['name_attr'] = params['name_attr']
for k in VALID_GROUP_TYPE_PARAMS_MAP['LDAPGroupType']:
if k in params:
params_sanitized['name_attr'] = params['name_attr']

if data.endswith('MemberDNGroupType'):
params.setdefault('member_attr', 'member')
Expand All @@ -383,8 +414,22 @@ def find_class_in_modules(class_name):
return cls(**params_sanitized)


class LDAPGroupTypeParamsField(fields.DictField):
pass
class LDAPGroupTypeParamsField(fields.DictField, DependsOnMixin):
default_error_messages = {
'invalid_keys': _('Invalid key(s): {invalid_keys}.'),
}

def to_internal_value(self, value):
value = super(LDAPGroupTypeParamsField, self).to_internal_value(value)
if not value:
return value
group_type_str = self.get_depends_on()
group_type_str = group_type_str or ''
invalid_keys = (set(value.keys()) - set(VALID_GROUP_TYPE_PARAMS_MAP.get(group_type_str, 'LDAPGroupType')))
if invalid_keys:
keys_display = json.dumps(list(invalid_keys)).lstrip('[').rstrip(']')
self.fail('invalid_keys', invalid_keys=keys_display)
return value


class LDAPUserFlagsField(fields.DictField):
Expand Down
21 changes: 21 additions & 0 deletions awx/sso/tests/unit/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@

import pytest
import mock

from rest_framework.exceptions import ValidationError

from awx.sso.fields import (
SAMLOrgAttrField,
SAMLTeamAttrField,
LDAPGroupTypeParamsField,
)


Expand Down Expand Up @@ -80,3 +82,22 @@ def test_internal_value_invalid(self, data, expected):
field.to_internal_value(data)
assert str(e.value) == str(expected)


class TestLDAPGroupTypeParamsField():

@pytest.mark.parametrize("group_type, data, expected", [
('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')),
('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')),
('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing',
'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "member_attr", "scooter".')),
])
def test_internal_value_invalid(self, group_type, data, expected):
field = LDAPGroupTypeParamsField()
field.get_depends_on = mock.MagicMock(return_value=group_type)

with pytest.raises(type(expected)) as e:
field.to_internal_value(data)
assert str(e.value) == str(expected)

0 comments on commit 1c578cd

Please sign in to comment.