From bd2833a2bfebb20ea0e3f04bfaffa9b7ff33baf8 Mon Sep 17 00:00:00 2001 From: Brandon Zhu <38299259+brandonzhu09@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:40:47 -0400 Subject: [PATCH] Feature/models testing (#18) * added schema testing with django models * added back env variables in actions * fixed actions --- ...django_views.yaml => test_django_api.yaml} | 8 +- pycodestyle.cfg | 1 + requirements.txt | 2 +- src/evagram/database/dataset.json | 24 +- src/evagram/database/input_tool.py | 238 ---------------- .../backend/api/migrations/0001_initial.py | 47 ++-- src/evagram/website/backend/api/models.py | 6 +- .../website/backend/api/test_models.py | 74 +++++ .../backend/api/{tests.py => test_views.py} | 1 - .../website/backend/backend/settings.py | 16 +- .../test_database.yaml => test_database.yaml | 0 tests/test_input_tool.py | 261 ------------------ 12 files changed, 120 insertions(+), 558 deletions(-) rename .github/workflows/{test_django_views.yaml => test_django_api.yaml} (80%) delete mode 100644 src/evagram/database/input_tool.py create mode 100644 src/evagram/website/backend/api/test_models.py rename src/evagram/website/backend/api/{tests.py => test_views.py} (99%) rename .github/workflows/test_database.yaml => test_database.yaml (100%) delete mode 100644 tests/test_input_tool.py diff --git a/.github/workflows/test_django_views.yaml b/.github/workflows/test_django_api.yaml similarity index 80% rename from .github/workflows/test_django_views.yaml rename to .github/workflows/test_django_api.yaml index 59e802f..db407da 100644 --- a/.github/workflows/test_django_views.yaml +++ b/.github/workflows/test_django_api.yaml @@ -1,4 +1,4 @@ -name: Test Django API Views +name: Test Database Tool on: pull_request: types: @@ -35,11 +35,7 @@ jobs: python -m pip install --upgrade pip pip install . -r requirements.txt - - name: Run Test Django API Views + - name: Run Django Model Tests env: - DB_HOST: ${{secrets.DB_HOST}} - DB_PORT: ${{secrets.DB_PORT}} - DB_NAME: ${{secrets.DB_NAME}} - DB_USER: ${{secrets.DB_USER}} DB_PASSWORD: ${{secrets.DB_PASSWORD}} run: python ./src/evagram/website/backend/manage.py test api diff --git a/pycodestyle.cfg b/pycodestyle.cfg index 722fa7d..db4875f 100644 --- a/pycodestyle.cfg +++ b/pycodestyle.cfg @@ -9,3 +9,4 @@ max-line-length = 100 indent-size = 4 statistics = True ignore = W503, W504 +exclude = __pycache__, src/evagram/website/backend/api/migrations diff --git a/requirements.txt b/requirements.txt index 17ca94a..c9eb4d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ autopep8==2.0.4 Django==4.2.10 django-cors-headers==4.3.1 djangorestframework==3.14.0 -psycopg2==2.9.9 +psycopg2-binary==2.9.9 pycodestyle>=2.8.0 python-dotenv==1.0.0 pytz==2023.3.post1 diff --git a/src/evagram/database/dataset.json b/src/evagram/database/dataset.json index 455c2a2..2bfef9c 100644 --- a/src/evagram/database/dataset.json +++ b/src/evagram/database/dataset.json @@ -61,38 +61,38 @@ { "plot_id": 121, "plot_file": "brightnessTemperature_4_effectiveerror-vs-gsifinalerror.pkl", - "experiment_id": 12 + "experiment_id": 12, + "observation_name": "amsua_aqua" }, { "plot_id": 322, "plot_file": "brightnessTemperature_10_hofx-vs-gsihofxbc.pkl", - "experiment_id": 12 + "experiment_id": 12, + "observation_name": "amsua_aqua" }, { "plot_id": 323, "plot_file": "brightnessTemperature_8_hofx-vs-gsihofxbc.pkl", - "experiment_id": 3 + "experiment_id": 3, + "observation_name": "amsua_n18" }, { "plot_id": 114, "plot_file": "brightnessTemperature_8_effectiveerror-vs-gsifinalerror.pkl", - "experiment_id": 3 + "experiment_id": 3, + "observation_name": "amsua_n18" }, { "plot_id": 165, "plot_file": "brightnessTemperature_86_effectiveerror-vs-gsifinalerror.pkl", - "experiment_id": 1 + "experiment_id": 1, + "observation_name": "cris-fsr_n20" }, { "plot_id": 966, "plot_file": "windEastward__effectiveerrordiff-vs-gsifinalerror.pkl", - "experiment_id": 96 + "experiment_id": 96, + "observation_name": "satwind" } - ], - "observation_dirs": [ - "amsua_aqua", - "amsua_n18", - "cris-fsr_n20", - "satwind" ] } \ No newline at end of file diff --git a/src/evagram/database/input_tool.py b/src/evagram/database/input_tool.py deleted file mode 100644 index b0b60d5..0000000 --- a/src/evagram/database/input_tool.py +++ /dev/null @@ -1,238 +0,0 @@ -import pickle -import json -import os -import sys -import argparse -from pathlib import Path -import psycopg2 -from dotenv import load_dotenv - -load_dotenv() - -# environment variables for connecting to database -db_host = os.environ.get('DB_HOST') -db_port = os.environ.get('DB_PORT') -db_name = os.environ.get('DB_NAME') -db_user = os.environ.get('DB_USER') -db_password = os.environ.get('DB_PASSWORD') - -# default path configurations -EXPERIMENT_DATA_PATH = './tests/eva' -DATASET_PATH = './src/evagram/database/' -PROCEDURES_PATH = './src/evagram/database/sql' - - -def main(args): - global EXPERIMENT_DATA_PATH - parser = argparse.ArgumentParser() - parser.add_argument("experiment_path") - args = parser.parse_args(args) - experiment_path = Path(args.experiment_path) - if experiment_path.exists(): - EXPERIMENT_DATA_PATH = args.experiment_path - - conn = psycopg2.connect( - host=db_host, - port=db_port, - dbname=db_name, - user=db_user, - password=db_password - ) - cur = conn.cursor() - - create_procedures(cur) - drop_tables(cur) - create_tables(cur) - load_dataset_to_db(cur) - - conn.commit() - cur.close() - conn.close() - - -def create_procedures(cur): - for proc in os.listdir(PROCEDURES_PATH): - proc_file = os.path.join(PROCEDURES_PATH, proc) - if os.path.isfile(proc_file) and proc.startswith("proc_") and proc.endswith(".sql"): - contents = open(proc_file, 'r') - cur.execute(contents.read()) - contents.close() - - -def create_tables(cur): - # Users table - cur.execute("CALL public.create_owners();") - # Experiments table - cur.execute("CALL public.create_experiments();") - # Groups table - cur.execute("CALL public.create_groups();") - # Variables table - cur.execute("CALL public.create_variables();") - # Observations table - cur.execute("CALL public.create_observations();") - # Plots table - cur.execute("CALL public.create_plots();") - - -def drop_tables(cur): - cur.execute("DROP TABLE IF EXISTS owners CASCADE") - cur.execute("DROP TABLE IF EXISTS experiments CASCADE") - cur.execute("DROP TABLE IF EXISTS plots CASCADE") - cur.execute("DROP TABLE IF EXISTS groups CASCADE") - cur.execute("DROP TABLE IF EXISTS observations CASCADE") - cur.execute("DROP TABLE IF EXISTS variables CASCADE") - - -def load_dataset_to_db(cur): - with open(os.path.join(DATASET_PATH, "dataset.json"), 'rb') as dataset: - sample_dataset = json.load(dataset) - for owner in sample_dataset['owners']: - add_user(cur, owner) - for experiment in sample_dataset['experiments']: - add_experiment(cur, experiment) - for plot in (sample_dataset['plots']): - add_plot(cur, plot, sample_dataset['observation_dirs']) - - -def get_observation_name(obs_dirs, filename): - # TODO: resolve duplicate filenames from different observations - for observation in obs_dirs: - obs_dir_path = os.path.join(EXPERIMENT_DATA_PATH, observation) - for plot in os.listdir(obs_dir_path): - if filename == plot: - return observation - return None - - -def insert_table_record(cur, data, table): - cur.execute("SELECT * FROM {} LIMIT 0".format(table)) - colnames = [desc[0] for desc in cur.description] - # filter data to contain only existing columns in table - data = {k: v for (k, v) in data.items() if k in colnames} - - query = "INSERT INTO {} (".format(table) - query += ', '.join(data) - query += ") VALUES (" - query += ', '.join(["%s" for _ in range(len(data))]) - query += ")" - - cur.execute(query, tuple(data.values())) - - -def add_user(cur, user_obj): - # check not null constraints - # TODO: owner_id made optional for user to provide - required = {'owner_id', 'username'} - difference = required.difference(user_obj) - if len(difference) > 0: - print("Missing required columns for owners table: {}".format(difference)) - return 1 - else: - insert_table_record(cur, user_obj, "owners") - return 0 - - -def delete_user(cur, username): - cur.execute("DELETE FROM owners WHERE username=%s", (username,)) - - -def add_experiment(cur, experiment_obj): - # check not null constraints - # TODO: experiment_id made optional for user to provide - required = {'experiment_id', 'owner_id'} - difference = required.difference(experiment_obj) - if len(difference) > 0: - print("Missing required columns for experiment table: {}".format(difference)) - return 1 - else: - insert_table_record(cur, experiment_obj, "experiments") - return 0 - - -def add_plot(cur, plot_obj, observation_dirs): - # check not null constraints - # TODO: plot_id made optional for user to provide - required = {'plot_id', 'plot_file', 'experiment_id'} - difference = required.difference(plot_obj) - if len(difference) > 0: - print("Missing required columns for plots table: {}".format(difference)) - return 1 - - plot_filename = plot_obj['plot_file'] - observation_name = get_observation_name(observation_dirs, plot_filename) - plot_file_path = os.path.join( - EXPERIMENT_DATA_PATH, observation_name, plot_filename) - - with open(plot_file_path, 'rb') as file: - dictionary = pickle.load(file) - - # extract the div and script components - div = dictionary['div'] - script = dictionary['script'] - - # parse filename for components variable name, channel, and group name - filename_no_extension = os.path.splitext(plot_filename)[0] - plot_components = filename_no_extension.split("_") - - var_name = plot_components[0] - channel = plot_components[1] if plot_components[1] != '' else None - group_name = plot_components[2] - - # insert observation, variable, group dynamically if not exist in database - cur.execute("SELECT observation_id FROM observations WHERE observation_name=%s", - (observation_name,)) - new_observation = len(cur.fetchall()) == 0 - cur.execute( - """SELECT variable_id FROM variables WHERE variable_name=%s - AND (channel=%s OR channel IS NULL)""", - (var_name, channel)) - new_variable = len(cur.fetchall()) == 0 - cur.execute("SELECT group_id FROM groups WHERE group_name=%s", (group_name,)) - new_group = len(cur.fetchall()) == 0 - - if new_observation: - observation_obj = { - "observation_name": observation_name, - } - insert_table_record(cur, observation_obj, "observations") - - if new_variable: - variable_obj = { - "variable_name": var_name, - "channel": channel - } - insert_table_record(cur, variable_obj, "variables") - - if new_group: - group_obj = { - "group_name": group_name - } - insert_table_record(cur, group_obj, "groups") - - # get the observation, variable, group ids - cur.execute("SELECT observation_id FROM observations WHERE observation_name=%s", - (observation_name,)) - observation_id = cur.fetchone()[0] - cur.execute( - """SELECT variable_id FROM variables WHERE variable_name=%s - AND (channel=%s OR channel IS NULL)""", - (var_name, channel)) - variable_id = cur.fetchone()[0] - cur.execute("SELECT group_id FROM groups WHERE group_name=%s", (group_name,)) - group_id = cur.fetchone()[0] - - # modify plot object - plot_obj["div"] = div - plot_obj["script"] = script - plot_obj["observation_id"] = observation_id - plot_obj["group_id"] = group_id - plot_obj["variable_id"] = variable_id - - # insert plot to database - insert_table_record(cur, plot_obj, "plots") - - return 0 - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/src/evagram/website/backend/api/migrations/0001_initial.py b/src/evagram/website/backend/api/migrations/0001_initial.py index 8d8b702..1a67d5a 100644 --- a/src/evagram/website/backend/api/migrations/0001_initial.py +++ b/src/evagram/website/backend/api/migrations/0001_initial.py @@ -1,7 +1,7 @@ -# Generated by Django 5.0.1 on 2024-02-07 14:30 +# Generated by Django 4.2.10 on 2024-03-24 02:17 -import django.db.models.deletion from django.db import migrations, models +import django.db.models.deletion class Migration(migrations.Migration): @@ -36,27 +36,14 @@ class Migration(migrations.Migration): name='Owners', fields=[ ('owner_id', models.AutoField(primary_key=True, serialize=False)), - ('first_name', models.CharField(blank=True, null=True)), - ('last_name', models.CharField(blank=True, null=True)), - ('username', models.CharField(unique=True)), + ('first_name', models.CharField(null=True)), + ('last_name', models.CharField(null=True)), + ('username', models.CharField(default='null', unique=True)), ], options={ 'db_table': 'owners', }, ), - migrations.CreateModel( - name='Experiments', - fields=[ - ('experiment_id', models.AutoField(primary_key=True, serialize=False)), - ('experiment_name', models.CharField(default='null')), - ('owner', models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, to='api.owners')), - ], - options={ - 'db_table': 'experiments', - 'unique_together': {('experiment_name', 'owner')}, - }, - ), migrations.CreateModel( name='Variables', fields=[ @@ -69,20 +56,28 @@ class Migration(migrations.Migration): 'unique_together': {('variable_name', 'channel')}, }, ), + migrations.CreateModel( + name='Experiments', + fields=[ + ('experiment_id', models.AutoField(primary_key=True, serialize=False)), + ('experiment_name', models.CharField(default='null')), + ('owner', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='api.owners')), + ], + options={ + 'db_table': 'experiments', + 'unique_together': {('experiment_name', 'owner')}, + }, + ), migrations.CreateModel( name='Plots', fields=[ ('plot_id', models.AutoField(primary_key=True, serialize=False)), ('div', models.CharField(blank=True, null=True)), ('script', models.CharField(blank=True, null=True)), - ('experiment', models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, to='api.experiments')), - ('group', models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, to='api.groups')), - ('observation', models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, to='api.observations')), - ('variable', models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, to='api.variables')), + ('experiment', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='api.experiments')), + ('group', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='api.groups')), + ('observation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='api.observations')), + ('variable', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='api.variables')), ], options={ 'db_table': 'plots', diff --git a/src/evagram/website/backend/api/models.py b/src/evagram/website/backend/api/models.py index 9f86f42..137d9f1 100644 --- a/src/evagram/website/backend/api/models.py +++ b/src/evagram/website/backend/api/models.py @@ -39,9 +39,9 @@ class Meta: class Owners(models.Model): owner_id = models.AutoField(primary_key=True) - first_name = models.CharField(blank=True, null=True) - last_name = models.CharField(blank=True, null=True) - username = models.CharField(unique=True) + first_name = models.CharField(null=True) + last_name = models.CharField(null=True) + username = models.CharField(null=False, unique=True, default="null") class Meta: db_table = 'owners' diff --git a/src/evagram/website/backend/api/test_models.py b/src/evagram/website/backend/api/test_models.py new file mode 100644 index 0000000..5e9fdba --- /dev/null +++ b/src/evagram/website/backend/api/test_models.py @@ -0,0 +1,74 @@ +from django.test import TestCase +import django.db +from api.models import * + + +class TestModels(TestCase): + fixtures = ["test_data.json"] + + def test_insert_duplicate_owner(self): + with self.assertRaises(django.db.IntegrityError): + Owners.objects.create(username='jdoe') + Owners.objects.create(username='jdoe') + + def test_delete_owner_cascade(self): + owner = Owners.objects.get(pk=1) + owner.delete() + self.assertEqual(0, len(Owners.objects.filter(pk=1))) + self.assertEqual(0, len(Experiments.objects.filter(owner=1))) + + def test_insert_experiment_insufficient_fields(self): + with self.assertRaises(django.db.IntegrityError): + Experiments.objects.create(experiment_name="experiment1") + + def test_insert_experiment_invalid_owner(self): + with self.assertRaises(ValueError): + Experiments.objects.create(experiment_name="experiment1", owner=-1) + + def test_insert_duplicate_experiment(self): + with self.assertRaises(django.db.IntegrityError): + owner = Owners.objects.create(username='jdoe') + experiment1 = Experiments.objects.create(experiment_name="experiment1", owner=owner) + experiment2 = Experiments.objects.create(experiment_name="experiment1", owner=owner) + + def test_delete_experiment_cascade(self): + experiment = Experiments.objects.get(pk=1) + experiment.delete() + self.assertEqual(0, len(Experiments.objects.filter(pk=1))) + self.assertEqual(0, len(Plots.objects.filter(experiment=1))) + + def test_insert_plot_insufficient_fields(self): + # Missing fields: group, variable + with self.assertRaises(django.db.IntegrityError): + experiment = Experiments.objects.get(pk=1) + observation = Observations.objects.get(pk=1) + Plots.objects.create(experiment=experiment, observation=observation) + + def test_insert_plot_invalid_fields(self): + # Invalid fields: experiment, observation + with self.assertRaises(ValueError): + group = Groups.objects.get(pk=1) + variable = Variables.objects.get(pk=1) + Plots.objects.create(experiment=-1, observation=-1, group=group, variable=variable) + + def test_insert_duplicate_observation(self): + with self.assertRaises(django.db.IntegrityError): + Observations.objects.create(observation_name="observation1") + Observations.objects.create(observation_name="observation1") + + def test_query_existing_plots(self): + # get all amsua_n18 plots in experiment "experiment_iv_1" where the user is thamzey + owner = Owners.objects.get(username="thamzey") + experiment = Experiments.objects.get(experiment_name="experiment_iv_1", owner=owner) + observation = Observations.objects.get(observation_name="amsua_n18") + queryset = Plots.objects.filter(experiment=experiment, observation=observation) + self.assertEqual(2, len(queryset)) + self.assertTrue(Plots.objects.get(pk=114) in queryset) + self.assertTrue(Plots.objects.get(pk=323) in queryset) + + def test_query_nonexistent_plots(self): + # get all satwind plots in experiment "experiment_iv_1" + experiment = Experiments.objects.get(experiment_name="experiment_iv_1") + observation = Observations.objects.get(observation_name="satwind") + queryset = Plots.objects.filter(experiment=experiment, observation=observation) + self.assertEqual(0, len(queryset)) diff --git a/src/evagram/website/backend/api/tests.py b/src/evagram/website/backend/api/test_views.py similarity index 99% rename from src/evagram/website/backend/api/tests.py rename to src/evagram/website/backend/api/test_views.py index 5c0ec98..2091d2c 100644 --- a/src/evagram/website/backend/api/tests.py +++ b/src/evagram/website/backend/api/test_views.py @@ -1,5 +1,4 @@ from django.test import TestCase -# from evagram.database import input_tool from api.models import Owners, Plots diff --git a/src/evagram/website/backend/backend/settings.py b/src/evagram/website/backend/backend/settings.py index 16a9547..34d8343 100644 --- a/src/evagram/website/backend/backend/settings.py +++ b/src/evagram/website/backend/backend/settings.py @@ -15,11 +15,7 @@ import os load_dotenv() -db_host = os.environ.get('DB_HOST') -db_port = os.environ.get('DB_PORT') -db_name = os.environ.get('DB_NAME') -db_user = os.environ.get('DB_USER') -db_password = os.environ.get('DB_PASSWORD') +pg_password = os.environ.get('DB_PASSWORD') # Build paths inside the project like this: BASE_DIR / 'subdir'. BASE_DIR = Path(__file__).resolve().parent.parent @@ -92,11 +88,11 @@ DATABASES = { 'default': { 'ENGINE': 'django.db.backends.postgresql', - 'NAME': db_name, - 'USER': db_user, - 'PASSWORD': db_password, - 'HOST': db_host, - 'PORT': db_port + 'NAME': 'plots', + 'USER': 'postgres', + 'HOST': 'localhost', + 'PORT': 5432, + 'PASSWORD': pg_password } } diff --git a/.github/workflows/test_database.yaml b/test_database.yaml similarity index 100% rename from .github/workflows/test_database.yaml rename to test_database.yaml diff --git a/tests/test_input_tool.py b/tests/test_input_tool.py deleted file mode 100644 index bae0c06..0000000 --- a/tests/test_input_tool.py +++ /dev/null @@ -1,261 +0,0 @@ -from evagram.database import input_tool -import sys -import unittest -import psycopg2 -from dotenv import load_dotenv -import os - -load_dotenv() - -# environment variables for connecting to database -db_host = os.environ.get('DB_HOST') -db_port = os.environ.get('DB_PORT') -db_name = os.environ.get('DB_NAME') -db_user = os.environ.get('DB_USER') -db_password = os.environ.get('DB_PASSWORD') - -conn = psycopg2.connect( - host=db_host, - port=db_port, - dbname=db_name, - user=db_user, - password=db_password -) - - -class TestDatabaseInputTool(unittest.TestCase): - def setUp(self): - self.cur = conn.cursor() - input_tool.main(['tests/eva']) - - self.cur.execute( - """SELECT setval('owners_owner_id_seq', - (SELECT MAX(owner_id) FROM owners)+1)""") - self.cur.execute( - """SELECT setval('experiments_experiment_id_seq', - (SELECT MAX(experiment_id) FROM experiments)+1)""") - self.cur.execute( - """SELECT setval('plots_plot_id_seq', - (SELECT MAX(plot_id) FROM plots)+1)""") - self.cur.execute( - """SELECT setval('observations_observation_id_seq', - (SELECT MAX(observation_id) FROM observations)+1)""") - - def tearDown(self): - conn.rollback() - self.cur.close() - - def test_InsertOwnerExpected(self): - user_obj = { - "username": "jdoe", - "first_name": "John", - "last_name": "Doe" - } - input_tool.insert_table_record(self.cur, user_obj, "owners") - self.cur.execute( - "SELECT (username) FROM owners WHERE username=%s", ("jdoe",)) - assert len(self.cur.fetchall()) == 1 - - def test_InsertSameOwner(self): - with self.assertRaises(psycopg2.errors.UniqueViolation): - user_obj = { - "username": "bzhu", - "first_name": "Brandon", - "last_name": "Zhu" - } - input_tool.insert_table_record(self.cur, user_obj, "owners") - - def test_InsertOwnerNoUsername(self): - with self.assertRaises(psycopg2.errors.NotNullViolation): - user_obj = { - "first_name": "John", - "last_name": "Doe" - } - input_tool.insert_table_record(self.cur, user_obj, "owners") - - def test_DeleteOwnerAndExperiments(self): - self.cur.execute("DELETE FROM owners WHERE owner_id=%s", (1,)) - self.cur.execute( - "SELECT (username) FROM owners WHERE owner_id=%s", (1,)) - assert len(self.cur.fetchall()) == 0 - self.cur.execute( - "SELECT (experiment_id) FROM experiments WHERE owner_id=%s", (1,)) - assert len(self.cur.fetchall()) == 0 - - def test_InsertExperimentExpected(self): - experiment_obj = { - "experiment_name": "control", - "owner_id": 1 - } - input_tool.insert_table_record(self.cur, experiment_obj, "experiments") - self.cur.execute( - "SELECT (experiment_name) FROM experiments WHERE experiment_name=%s AND owner_id=%s", - ("control", 1) - ) - assert len(self.cur.fetchall()) == 1 - - def test_InsertExperimentWithoutOwner(self): - with self.assertRaises(psycopg2.errors.NotNullViolation): - experiment_obj = { - "experiment_name": "control" - } - input_tool.insert_table_record(self.cur, experiment_obj, "experiments") - - def test_InsertExperimentWithOwnerNotFound(self): - with self.assertRaises(psycopg2.errors.ForeignKeyViolation): - experiment_obj = { - "experiment_name": "control", - "owner_id": -1 - } - input_tool.insert_table_record(self.cur, experiment_obj, "experiments") - - def test_InsertExperimentWithSameNameAndOwner(self): - with self.assertRaises(psycopg2.errors.UniqueViolation): - experiment_obj = { - "experiment_name": "experiment_control", - "owner_id": 1 - } - input_tool.insert_table_record(self.cur, experiment_obj, "experiments") - - def test_DeleteExperimentCascades(self): - self.cur.execute( - "DELETE FROM experiments WHERE experiment_id=%s", (12,)) - # find any instance of experiment in 'experiments' - self.cur.execute( - "SELECT (experiment_id) FROM experiments WHERE experiment_id=%s", (12,)) - assert len(self.cur.fetchall()) == 0 - # find any instance of experiment in 'plots' - self.cur.execute( - "SELECT (plot_id) FROM plots WHERE experiment_id=%s", (12,)) - assert len(self.cur.fetchall()) == 0 - - def test_InsertPlotExpected(self): - plot_obj = { - "plot_id": 115, - "experiment_id": 1, - "group_id": 3, - "observation_id": 1, - "variable_id": 2 - } - input_tool.insert_table_record(self.cur, plot_obj, "plots") - self.cur.execute( - "SELECT (plot_id) FROM plots WHERE plot_id=%s", (115,)) - assert len(self.cur.fetchall()) == 1 - - def test_InsertPlotMissingFields(self): - with self.assertRaises(psycopg2.errors.NotNullViolation): - plot_obj = { - "group_id": 1, - "observation_id": 1 - } - input_tool.insert_table_record(self.cur, plot_obj, "plots") - self.tearDown() - self.setUp() - with self.assertRaises(psycopg2.errors.NotNullViolation): - plot_obj = { - "experiment_id": 12, - "observation_id": 1 - } - input_tool.insert_table_record(self.cur, plot_obj, "plots") - self.tearDown() - self.setUp() - with self.assertRaises(psycopg2.errors.NotNullViolation): - plot_obj = { - "experiment_id": 12, - "group_id": 1 - } - input_tool.insert_table_record(self.cur, plot_obj, "plots") - - def test_InsertPlotInvalidFields(self): - plot_obj = { - "experiment_id": 12, - "group_id": 1, - "observation_id": 1, - "variable_id": 1 - } - with self.assertRaises(psycopg2.errors.ForeignKeyViolation): - plot_obj["experiment_id"] = -12 - input_tool.insert_table_record(self.cur, plot_obj, "plots") - self.tearDown() - self.setUp() - with self.assertRaises(psycopg2.errors.ForeignKeyViolation): - plot_obj["experiment_id"] = 12 - plot_obj["group_id"] = -1 - input_tool.insert_table_record(self.cur, plot_obj, "plots") - self.tearDown() - self.setUp() - with self.assertRaises(psycopg2.errors.ForeignKeyViolation): - plot_obj["group_id"] = 1 - plot_obj["observation_id"] = -1 - input_tool.insert_table_record(self.cur, plot_obj, "plots") - self.tearDown() - self.setUp() - with self.assertRaises(psycopg2.errors.ForeignKeyViolation): - plot_obj["observation_id"] = 1 - plot_obj["variable_id"] = -1 - input_tool.insert_table_record(self.cur, plot_obj, "plots") - - def test_InsertObservationExpected(self): - observation_obj = { - "observation_name": "aircraft" - } - input_tool.insert_table_record(self.cur, observation_obj, "observations") - self.cur.execute( - """SELECT (observation_id) FROM observations - WHERE observation_name=%s""", - ("satwind",)) - assert len(self.cur.fetchall()) == 1 - - def test_InsertObservationWithSameName(self): - with self.assertRaises(psycopg2.errors.UniqueViolation): - observation_obj = { - "observation_name": "satwind", - } - input_tool.insert_table_record(self.cur, observation_obj, "observations") - - def test_FetchExistingPlots(self): - # get all amsua_n18 plots in experiment "experiment_iv_1" where the user is thamzey - self.cur.execute("""SELECT plot_id, plots.experiment_id FROM plots - JOIN experiments ON plots.experiment_id = experiments.experiment_id - JOIN observations ON plots.observation_id = observations.observation_id - JOIN variables ON plots.variable_id = variables.variable_id - JOIN owners ON owners.owner_id = experiments.owner_id - WHERE experiments.experiment_name = %s - AND owners.username = %s AND observations.observation_name = %s; """, - ("experiment_iv_1", "thamzey", "amsua_n18")) - plots = self.cur.fetchall() - self.assertTrue(len(plots) == 2) - # checks if plot with plot_id=114 and experiment_id=3 - # and plot_id=323 and experiment_id=3 was found - self.assertTrue((114, 3) in plots) - self.assertTrue((323, 3) in plots) - - def test_FetchNonExistingPlots(self): - # get all satwind plots in experiment "experiment_iv_1" - self.cur.execute("""SELECT plot_id FROM plots - JOIN experiments ON plots.experiment_id = experiments.experiment_id - JOIN observations ON plots.observation_id = observations.observation_id - JOIN variables ON plots.variable_id = variables.variable_id - WHERE experiments.experiment_name = %s - AND observations.observation_name = %s;""", - ("experiment_iv_1", "satwind")) - self.assertTrue(len(self.cur.fetchall()) == 0) - - def test_GroupsRelationWithPlotsAndExperiments(self): - self.cur.execute("""SELECT group_name, plots.plot_id, experiments.experiment_id FROM groups - JOIN plots ON plots.group_id = groups.group_id - JOIN experiments ON experiments.experiment_id = plots.experiment_id - WHERE group_name = %s;""", ("effectiveerror-vs-gsifinalerror",)) - plots, experiments = [], [] - for item in self.cur.fetchall(): - assert len(item) == 3 # sanity checks - assert item[0] == "effectiveerror-vs-gsifinalerror" - plots.append(item[1]) - experiments.append(item[2]) - # check that plots associated with group have unique plot ids (relation) - self.assertEqual(plots, list(set(plots))) - # check that experiments associated with group can be the same (no relation) - self.assertGreaterEqual(len(experiments), len(set(experiments))) - - -unittest.main()