diff --git a/core/common/tasks.py b/core/common/tasks.py index 7626a68fe..fa2c4b170 100644 --- a/core/common/tasks.py +++ b/core/common/tasks.py @@ -711,7 +711,7 @@ def update_mappings_concept(concept_id): def calculate_checksums(resource_type, resource_id): model = get_resource_class_from_resource_name(resource_type) if model: - is_source_child = model.__class__.__name__ in ('Concept', 'Mapping') + is_source_child = model.__name__ in ('Concept', 'Mapping') instance = model.objects.filter(id=resource_id).first() if instance: instance.set_checksums() diff --git a/core/common/tests.py b/core/common/tests.py index e2c0fe4a5..ff517fb3d 100644 --- a/core/common/tests.py +++ b/core/common/tests.py @@ -1,3 +1,4 @@ +import datetime import os import uuid from collections import OrderedDict @@ -18,14 +19,15 @@ from core.collections.models import CollectionReference from core.common.constants import HEAD -from core.common.tasks import delete_s3_objects, bulk_import_parallel_inline, resources_report +from core.common.tasks import delete_s3_objects, bulk_import_parallel_inline, resources_report, calculate_checksums from core.common.utils import ( compact_dict_by_values, to_snake_case, flower_get, task_exists, parse_bulk_import_task_id, to_camel_case, drop_version, is_versioned_uri, separate_version, to_parent_uri, jsonify_safe, es_get, get_resource_class_from_resource_name, flatten_dict, is_csv_file, is_url_encoded_string, to_parent_uri_from_kwargs, set_current_user, get_current_user, set_request_url, get_request_url, nested_dict_values, chunks, api_get, - split_list_by_condition, is_zip_file) + split_list_by_condition, is_zip_file, get_date_range_label, get_prev_month, from_string_to_date, get_end_of_month, + get_start_of_month, es_id_in, web_url) from core.concepts.models import Concept from core.orgs.models import Organization from core.sources.models import Source @@ -37,6 +39,8 @@ from .serializers import IdentifierSerializer from .validators import URIValidator from ..code_systems.serializers import CodeSystemDetailSerializer +from ..concepts.tests.factories import ConceptFactory, ConceptNameFactory +from ..sources.tests.factories import OrganizationSourceFactory class CustomTestRunner(ColourRunnerMixin, DiscoverRunner): @@ -52,9 +56,6 @@ class SetupTestEnvironment: class BaseTestCase(SetupTestEnvironment): @staticmethod def create_lookup_concept_classes(user=None, org=None): - from core.sources.tests.factories import OrganizationSourceFactory - from core.concepts.tests.factories import ConceptNameFactory, ConceptFactory - org = org or Organization.objects.get(mnemonic='OCL') user = user or UserProfile.objects.get(username='ocladmin') @@ -416,13 +417,34 @@ def test_api_get(self, http_get_mock): headers={'Authorization': f'Token {user.get_token()}'} ) + @patch('core.common.utils.settings') @patch('core.common.utils.requests.get') - def test_es_get(self, http_get_mock): + def test_es_get(self, http_get_mock, settings_mock): + settings_mock.ES_USER = 'es-user' + settings_mock.ES_PASSWORD = 'es-password' + settings_mock.ES_HOSTS = 'es:9200' + settings_mock.ES_SCHEME = 'http' http_get_mock.return_value = 'dummy-response' self.assertEqual(es_get('some-url', timeout=1), 'dummy-response') - http_get_mock.assert_called_once_with('http://es:9200/some-url', auth=None, timeout=1) + http_get_mock.assert_called_with( + 'http://es:9200/some-url', + auth=HTTPBasicAuth('es-user', 'es-password'), + timeout=1 + ) + + settings_mock.ES_HOSTS = None + settings_mock.ES_HOST = 'es' + settings_mock.ES_PORT = '9201' + + self.assertEqual(es_get('some-url', timeout=1), 'dummy-response') + + http_get_mock.assert_called_with( + 'http://es:9201/some-url', + auth=HTTPBasicAuth('es-user', 'es-password'), + timeout=1 + ) @patch('core.common.utils.flower_get') def test_task_exists(self, flower_get_mock): @@ -632,11 +654,14 @@ def test_jsonify_safe(self): self.assertEqual(jsonify_safe('{"foo": "bar"}'), {'foo': 'bar'}) def test_get_resource_class_from_resource_name(self): + self.assertEqual(get_resource_class_from_resource_name(None), None) self.assertEqual(get_resource_class_from_resource_name('mappings').__name__, 'Mapping') self.assertEqual(get_resource_class_from_resource_name('sources').__name__, 'Source') self.assertEqual(get_resource_class_from_resource_name('source').__name__, 'Source') self.assertEqual(get_resource_class_from_resource_name('collections').__name__, 'Collection') self.assertEqual(get_resource_class_from_resource_name('collection').__name__, 'Collection') + self.assertEqual(get_resource_class_from_resource_name('expansion').__name__, 'Expansion') + self.assertEqual(get_resource_class_from_resource_name('reference').__name__, 'CollectionReference') for name in ['orgs', 'organizations', 'org', 'ORG']: self.assertEqual(get_resource_class_from_resource_name(name).__name__, 'Organization') for name in ['user', 'USer', 'user_profile', 'USERS']: @@ -743,6 +768,7 @@ def test_is_url_encoded_string(self): self.assertTrue(is_url_encoded_string('foo')) self.assertFalse(is_url_encoded_string('foo/bar')) self.assertTrue(is_url_encoded_string('foo%2Fbar')) + self.assertTrue(is_url_encoded_string('foo%2Fbar', False)) def test_to_parent_uri_from_kwargs(self): self.assertEqual( @@ -822,6 +848,86 @@ def test_split_list_by_condition(self): self.assertEqual(include, [ref1, ref4]) self.assertEqual(exclude, [ref2, ref3]) + def test_get_date_range_label(self): + self.assertEqual( + get_date_range_label('2019-01-01', '2019-01-31'), + '01 - 31 January 2019' + ) + self.assertEqual( + get_date_range_label('2019-01-01 10:00:00', '2019-01-01 11:00:00'), + '01 - 01 January 2019' + ) + self.assertEqual( + get_date_range_label('2019-01-02 10:10:00', '2019-01-01'), + '02 - 01 January 2019' + ) + self.assertEqual( + get_date_range_label('2019-02-01 10:10:00', '2019-01-01'), + '01 February - 01 January 2019' + ) + self.assertEqual( + get_date_range_label('2019-01-01', '2020-01-01'), + '01 January 2019 - 01 January 2020' + ) + + def test_get_prev_month(self): + self.assertEqual(get_prev_month(from_string_to_date('2023-01-01')), from_string_to_date('2022-12-31')) + self.assertEqual(get_prev_month(from_string_to_date('2024-12-01')), from_string_to_date('2024-11-30')) + self.assertEqual(get_prev_month(from_string_to_date('2024-12-05')), from_string_to_date('2024-11-30')) + + def test_get_end_of_month(self): + self.assertEqual(get_end_of_month(from_string_to_date('2023-01-01')), from_string_to_date('2023-01-31')) + self.assertEqual(get_end_of_month(from_string_to_date('2024-12-01')), from_string_to_date('2024-12-31')) + self.assertEqual(get_end_of_month(from_string_to_date('2024-12-05')), from_string_to_date('2024-12-31')) + self.assertEqual(get_end_of_month(from_string_to_date('2024-11-05')), from_string_to_date('2024-11-30')) + self.assertEqual(get_end_of_month(from_string_to_date('2024-02-05')), from_string_to_date('2024-02-29')) + self.assertEqual( + get_end_of_month(from_string_to_date('2024-11-30 11:00')), from_string_to_date('2024-11-30 11:00')) + self.assertEqual( + get_end_of_month(from_string_to_date('2024-11-15 11:00')), from_string_to_date('2024-11-30 11:00')) + + def test_get_start_of_month(self): + self.assertEqual(get_start_of_month(from_string_to_date('2023-01-01')), from_string_to_date('2023-01-01')) + self.assertEqual(get_start_of_month(from_string_to_date('2024-12-31')), from_string_to_date('2024-12-01')) + self.assertEqual(get_start_of_month(from_string_to_date('2024-02-05')), from_string_to_date('2024-02-01')) + self.assertEqual(get_start_of_month(from_string_to_date('2023-02-28')), from_string_to_date('2023-02-01')) + + def test_es_id_in(self): + search = Mock(query=Mock(return_value='search')) + + self.assertEqual(es_id_in(search, []), search) + + self.assertEqual(es_id_in(search, [1, 2, 3]), 'search') + search.query.assert_called_once_with("terms", _id=[1, 2, 3]) + + @patch('core.common.utils.settings') + def test_web_url(self, settings_mock): + settings_mock.WEB_URL = 'https://ocl.org' + self.assertEqual(web_url(), 'https://ocl.org') + + settings_mock.WEB_URL = None + + for env in [None, 'development', 'ci']: + settings_mock.ENV = env + self.assertEqual(web_url(), 'http://localhost:4000') + + settings_mock.ENV = 'production' + self.assertEqual(web_url(), 'https://app.openconceptlab.org') + + settings_mock.ENV = 'staging' + self.assertEqual(web_url(), 'https://app.staging.openconceptlab.org') + + settings_mock.ENV = 'foo' + self.assertEqual(web_url(), 'https://app.foo.openconceptlab.org') + + def test_from_string_to_date(self): + self.assertEqual( + from_string_to_date('2023-02-28'), datetime.datetime(2023, 2, 28)) + self.assertEqual( + from_string_to_date('2023-02-28 10:00:00'), datetime.datetime(2023, 2, 28, 10)) + self.assertEqual( + from_string_to_date('2023-02-29'), None) + class BaseModelTest(OCLTestCase): def test_model_name(self): @@ -894,6 +1000,39 @@ def test_resources_report(self, email_message_mock): self.assertTrue('Please find attached resources report of' in call_args['body']) self.assertTrue('for the period of' in call_args['body']) + def test_calculate_checksums(self): + concept = ConceptFactory() + concept_prev_latest = concept.get_latest_version() + Concept.create_new_version_for( + instance=concept.clone(), + data={ + 'names': [{'locale': 'en', 'name': 'English', 'locale_preferred': True}] + }, + user=concept.created_by, + create_parent_version=False + ) + concept_latest = concept.get_latest_version() + + Concept.objects.filter(id__in=[concept.id, concept_latest.id, concept_prev_latest.id]).update(checksums={}) + + concept.refresh_from_db() + concept_prev_latest.refresh_from_db() + concept_latest.refresh_from_db() + + self.assertEqual(concept.checksums, {}) + self.assertEqual(concept_prev_latest.checksums, {}) + self.assertEqual(concept_latest.checksums, {}) + + calculate_checksums('concepts', concept_prev_latest.id) + + concept.refresh_from_db() + concept_prev_latest.refresh_from_db() + concept_latest.refresh_from_db() + + self.assertEqual(concept_prev_latest.checksums, {'smart': ANY, 'standard': ANY}) + self.assertEqual(concept_latest.checksums, {'smart': ANY, 'standard': ANY}) + self.assertEqual(concept.checksums, {'smart': ANY, 'standard': ANY}) + class URIValidatorTest(OCLTestCase): validator = URIValidator() diff --git a/core/common/utils.py b/core/common/utils.py index d2c5ab1ee..8d7a07a3d 100644 --- a/core/common/utils.py +++ b/core/common/utils.py @@ -22,7 +22,6 @@ from django.urls import NoReverseMatch, reverse, get_resolver from django.utils import timezone from djqscsv import csv_file_for -from elasticsearch_dsl import Q as es_Q from pydash import flatten, compact, get from requests import ConnectTimeout from requests.auth import HTTPBasicAuth @@ -264,8 +263,6 @@ def write_export_file( batch_queryset = concepts_qs.order_by('-concept_id')[start:end] logger.info('Done serializing concepts.') - else: - logger.info(f'{resource_name} has no concepts to serialize.') if is_collection: references_qs = version.references @@ -290,8 +287,6 @@ def write_export_file( if end != total_references: out.write(', ') logger.info('Done serializing references.') - else: - logger.info(f'{resource_name} has no references to serialize.') with open('export.json', 'a') as out: out.write('], "mappings": [') @@ -323,8 +318,6 @@ def write_export_file( batch_queryset = mappings_qs.order_by('-mapping_id')[start:end] logger.info('Done serializing mappings.') - else: - logger.info(f'{resource_name} has no mappings to serialize.') with open('export.json', 'a') as out: end_time = str(round((time.time() - start_time) + 2, 2)) + 'secs' @@ -362,10 +355,6 @@ def get_api_base_url(): return settings.API_BASE_URL -def get_api_internal_base_url(): - return settings.API_INTERNAL_BASE_URL - - def to_snake_case(string): # from https://www.geeksforgeeks.org/python-program-to-convert-camel-case-string-to-snake-case/ return ''.join(['_' + i.lower() if i.isupper() else i for i in string]).lstrip('_') @@ -621,12 +610,14 @@ def get_resource_class_from_resource_name(resource): # pylint: disable=too-many def get_content_type_from_resource_name(resource): + content_type = None + model = get_resource_class_from_resource_name(resource) if model: from django.contrib.contenttypes.models import ContentType - return ContentType.objects.get_for_model(model) + content_type = ContentType.objects.get_for_model(model) - return None + return content_type def flatten_dict(dikt, parent_key='', sep='__'): @@ -663,14 +654,14 @@ def get_celery_once_lock_key(name, args): def guess_extension(file=None, name=None): - if not file and not name: - return None - if file: - name = file.name - _, extension = os.path.splitext(name) - - if not extension: - extension = mimetypes.guess_extension(name) + extension = None + if file or name: + if file: + name = file.name + _, extension = os.path.splitext(name) + + if not extension: + extension = mimetypes.guess_extension(name) return extension @@ -766,14 +757,6 @@ def get_request_url(): return request_url -def named_tuple_fetchall(cursor): - """Return all rows from a cursor as a namedtuple""" - from collections import namedtuple - desc = cursor.description - nt_result = namedtuple('Result', [col[0] for col in desc]) - return [nt_result(*row) for row in cursor.fetchall()] - - def nested_dict_values(_dict): for value in _dict.values(): if isinstance(value, dict): @@ -794,27 +777,6 @@ def es_id_in(search, ids): return search -def get_es_wildcard_search_criterion(search_str, name_attr='name'): - def get_query(_str): - return es_Q( - "wildcard", id={'value': _str, 'boost': 2} - ) | es_Q( - "wildcard", **{name_attr: {'value': _str, 'boost': 5}} - ) | es_Q( - "query_string", query=f"*{_str}*" - ) - - if search_str: - words = search_str.split() - criterion = get_query(words[0]) - for word in words[1:]: - criterion |= get_query(word) - else: - criterion = get_query(search_str) - - return criterion - - def es_to_pks(search): # doesn't care about the order default_limit = 25