From feda937dacf30e1f6e39c23457d20ab0d74e8cf5 Mon Sep 17 00:00:00 2001
From: James Kent <jamesdkent21@gmail.com>
Date: Wed, 13 Mar 2024 02:15:20 -0500
Subject: [PATCH] [ENH] add annotation-analysis endpoint (#737)

* add annotation-analysis endpoint

* add additional info to annotationanalyses

* run black

* fix loading procedure

* do eager loading

* do not run update for annotationanalysis as well

* switch to main branch
---
 store/neurostore/database.py                  | 20 +++--
 store/neurostore/models/data.py               |  9 ++-
 store/neurostore/openapi                      |  2 +-
 store/neurostore/resources/__init__.py        |  2 +
 store/neurostore/resources/base.py            |  2 +-
 store/neurostore/resources/data.py            | 78 +++++++++++++------
 store/neurostore/resources/utils.py           |  8 +-
 store/neurostore/schemas/data.py              |  3 +-
 .../neurostore/tests/api/test_base_studies.py |  2 +-
 store/neurostore/tests/api/test_crud.py       | 13 +++-
 .../neurostore/tests/api/test_performance.py  |  4 +-
 store/neurostore/tests/conftest.py            |  1 +
 12 files changed, 100 insertions(+), 44 deletions(-)

diff --git a/store/neurostore/database.py b/store/neurostore/database.py
index 9d633d77c..21b5836b6 100644
--- a/store/neurostore/database.py
+++ b/store/neurostore/database.py
@@ -5,17 +5,21 @@
 
 def orjson_serializer(obj):
     """
-        Note that `orjson.dumps()` return byte array,
-        while sqlalchemy expects string, thus `decode()` call.
+    Note that `orjson.dumps()` return byte array,
+    while sqlalchemy expects string, thus `decode()` call.
     """
-    return orjson.dumps(obj, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NAIVE_UTC).decode()
+    return orjson.dumps(
+        obj, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NAIVE_UTC
+    ).decode()
 
 
-db = SQLAlchemy(engine_options={
-    "future": True,
-    "json_serializer": orjson_serializer,
-    "json_deserializer": orjson.loads,
-    })
+db = SQLAlchemy(
+    engine_options={
+        "future": True,
+        "json_serializer": orjson_serializer,
+        "json_deserializer": orjson.loads,
+    }
+)
 Base = declarative_base()
 
 
diff --git a/store/neurostore/models/data.py b/store/neurostore/models/data.py
index 471faa37d..9806a2a85 100644
--- a/store/neurostore/models/data.py
+++ b/store/neurostore/models/data.py
@@ -115,7 +115,7 @@ class Annotation(BaseMixin, db.Model):
     )
 
 
-class AnnotationAnalysis(db.Model):
+class AnnotationAnalysis(BaseMixin, db.Model):
     __tablename__ = "annotation_analyses"
     __table_args__ = (
         ForeignKeyConstraint(
@@ -126,22 +126,25 @@ class AnnotationAnalysis(db.Model):
     )
     __mapper_args__ = {"confirm_deleted_rows": False}
 
+    user_id = db.Column(db.Text, db.ForeignKey("users.external_id"), index=True)
     study_id = db.Column(db.Text, nullable=False)
     studyset_id = db.Column(db.Text, nullable=False)
     annotation_id = db.Column(
         db.Text,
         db.ForeignKey("annotations.id", ondelete="CASCADE"),
         index=True,
-        primary_key=True,
     )
     analysis_id = db.Column(
         db.Text,
         db.ForeignKey("analyses.id", ondelete="CASCADE"),
         index=True,
-        primary_key=True,
     )
     note = db.Column(MutableDict.as_mutable(JSONB))
 
+    user = relationship(
+        "User", backref=backref("annotation_analyses", passive_deletes=True)
+    )
+
 
 class BaseStudy(BaseMixin, db.Model):
     __tablename__ = "base_studies"
diff --git a/store/neurostore/openapi b/store/neurostore/openapi
index 2493f75b3..31b93b541 160000
--- a/store/neurostore/openapi
+++ b/store/neurostore/openapi
@@ -1 +1 @@
-Subproject commit 2493f75b3911aa0ede0717b579dfe39126da36b2
+Subproject commit 31b93b5414361124cb66f22c2e6dd9e8fb6cccde
diff --git a/store/neurostore/resources/__init__.py b/store/neurostore/resources/__init__.py
index 8c1f1ce9d..900be4626 100644
--- a/store/neurostore/resources/__init__.py
+++ b/store/neurostore/resources/__init__.py
@@ -1,6 +1,7 @@
 from .data import (
     StudysetsView,
     AnnotationsView,
+    AnnotationAnalysesView,
     BaseStudiesView,
     StudiesView,
     AnalysesView,
@@ -17,6 +18,7 @@
 __all__ = [
     "StudysetsView",
     "AnnotationsView",
+    "AnnotationAnalysesView",
     "BaseStudiesView",
     "StudiesView",
     "AnalysesView",
diff --git a/store/neurostore/resources/base.py b/store/neurostore/resources/base.py
index fb4333852..65e7b1fcf 100644
--- a/store/neurostore/resources/base.py
+++ b/store/neurostore/resources/base.py
@@ -472,7 +472,7 @@ def put(self, id):
 
         try:
             self.update_base_studies(unique_ids.get("base-studies"))
-            if self._model is not Annotation:
+            if self._model is not Annotation and self._model is not AnnotationAnalysis:
                 self.update_annotations(unique_ids.get("annotations"))
         except SQLAlchemyError as e:
             db.session.rollback()
diff --git a/store/neurostore/resources/data.py b/store/neurostore/resources/data.py
index b576e4491..ea0b42291 100644
--- a/store/neurostore/resources/data.py
+++ b/store/neurostore/resources/data.py
@@ -35,7 +35,6 @@
 from ..schemas import (
     BooleanOrString,
     AnalysisConditionSchema,
-    AnnotationAnalysisSchema,
     StudysetStudySchema,
     EntitySchema,
 )
@@ -44,6 +43,7 @@
 __all__ = [
     "StudysetsView",
     "AnnotationsView",
+    "AnnotationAnalysesView",
     "BaseStudiesView",
     "StudiesView",
     "AnalysesView",
@@ -200,10 +200,10 @@ def serialize_records(self, records, args):
 @view_maker
 class AnnotationsView(ObjectView, ListView):
     _view_fields = {**LIST_CLONE_ARGS, "studyset_id": fields.String(load_default=None)}
-    _o2m = {"annotation_analyses": "AnnotationAnalysesResource"}
+    _o2m = {"annotation_analyses": "AnnotationAnalysesView"}
     _m2o = {"studyset": "StudysetsView"}
 
-    _nested = {"annotation_analyses": "AnnotationAnalysesResource"}
+    _nested = {"annotation_analyses": "AnnotationAnalysesView"}
     _linked = {
         "studyset": "StudysetsView",
     }
@@ -255,7 +255,16 @@ def eager_load(self, q, args=None):
             selectinload(Annotation.user)
             .load_only(User.name, User.external_id)
             .options(raiseload("*", sql_only=True)),
-            selectinload(Annotation.annotation_analyses).options(
+            selectinload(Annotation.annotation_analyses)
+            .load_only(
+                AnnotationAnalysis.id,
+                AnnotationAnalysis.analysis_id,
+                AnnotationAnalysis.created_at,
+                AnnotationAnalysis.study_id,
+                AnnotationAnalysis.studyset_id,
+                AnnotationAnalysis.annotation_id,
+            )
+            .options(
                 joinedload(AnnotationAnalysis.analysis)
                 .load_only(Analysis.id, Analysis.name)
                 .options(raiseload("*", sql_only=True)),
@@ -339,8 +348,13 @@ def join_tables(self, q, args):
     def db_validation(self, record, data):
         db_analysis_ids = {aa.analysis_id for aa in record.annotation_analyses}
         data_analysis_ids = {
-            aa["analysis"]["id"] for aa in data.get("annotation_analyses")
+            aa.get("analysis", {}).get("id", "")
+            for aa in data.get("annotation_analyses", [])
         }
+
+        if not data_analysis_ids:
+            return
+
         if db_analysis_ids != data_analysis_ids:
             abort(
                 400,
@@ -779,7 +793,7 @@ class AnalysesView(ObjectView, ListView):
         "images": "ImagesView",
         "points": "PointsView",
         "analysis_conditions": "AnalysisConditionsResource",
-        "annotation_analyses": "AnnotationAnalysesResource",
+        "annotation_analyses": "AnnotationAnalysesView",
     }
     _m2o = {
         "study": "StudiesView",
@@ -794,7 +808,7 @@ class AnalysesView(ObjectView, ListView):
         "study": "StudiesView",
     }
     _linked = {
-        "annotation_analyses": "AnnotationAnalysesResource",
+        "annotation_analyses": "AnnotationAnalysesView",
     }
     _search_fields = ("name", "description")
 
@@ -1087,20 +1101,8 @@ class PointValuesView(ObjectView, ListView):
     }
 
 
-# Utility resources for updating data
-class AnalysisConditionsResource(BaseView):
-    _m2o = {
-        "analysis": "AnalysesView",
-        "condition": "ConditionsView",
-    }
-    _nested = {"condition": "ConditionsView"}
-    _parent = {"analysis": "AnalysesView"}
-    _model = AnalysisConditions
-    _schema = AnalysisConditionSchema
-    _composite_key = {}
-
-
-class AnnotationAnalysesResource(BaseView):
+@view_maker
+class AnnotationAnalysesView(ObjectView, ListView):
     _m2o = {
         "annotation": "AnnotationsView",
         "analysis": "AnalysesView",
@@ -1114,8 +1116,38 @@ class AnnotationAnalysesResource(BaseView):
         "analysis": "AnalysesView",
         "studyset_study": "StudysetStudiesResource",
     }
-    _model = AnnotationAnalysis
-    _schema = AnnotationAnalysisSchema
+
+    def eager_load(self, q, args=None):
+        q = q.options(
+            joinedload(AnnotationAnalysis.analysis)
+            .load_only(Analysis.id, Analysis.name)
+            .options(raiseload("*", sql_only=True)),
+            joinedload(AnnotationAnalysis.studyset_study).options(
+                joinedload(StudysetStudy.study)
+                .load_only(
+                    Study.id,
+                    Study.name,
+                    Study.year,
+                    Study.authors,
+                    Study.publication,
+                )
+                .options(raiseload("*", sql_only=True))
+            ),
+        )
+
+        return q
+
+
+# Utility resources for updating data
+class AnalysisConditionsResource(BaseView):
+    _m2o = {
+        "analysis": "AnalysesView",
+        "condition": "ConditionsView",
+    }
+    _nested = {"condition": "ConditionsView"}
+    _parent = {"analysis": "AnalysesView"}
+    _model = AnalysisConditions
+    _schema = AnalysisConditionSchema
     _composite_key = {}
 
 
diff --git a/store/neurostore/resources/utils.py b/store/neurostore/resources/utils.py
index d6f0ef7c9..9fa0fc689 100644
--- a/store/neurostore/resources/utils.py
+++ b/store/neurostore/resources/utils.py
@@ -29,7 +29,13 @@ def get_current_user():
 
 def view_maker(cls):
     proc_name = cls.__name__.removesuffix("View").removesuffix("Resource")
-    basename = singularize(proc_name, custom={"MetaAnalyses": "MetaAnalysis"})
+    basename = singularize(
+        proc_name,
+        custom={
+            "MetaAnalyses": "MetaAnalysis",
+            "AnnotationAnalyses": "AnnotationAnalysis",
+        },
+    )
 
     class ClassView(cls):
         _model = getattr(models, basename)
diff --git a/store/neurostore/schemas/data.py b/store/neurostore/schemas/data.py
index 87bba963d..96fa8020d 100644
--- a/store/neurostore/schemas/data.py
+++ b/store/neurostore/schemas/data.py
@@ -410,6 +410,7 @@ class Meta:
 
 
 class AnnotationAnalysisSchema(BaseSchema):
+    id = fields.String(metadata={"info_field": True, "id_field": True})
     note = fields.Dict()
     annotation = StringOrNested("AnnotationSchema", load_only=True)
     analysis_id = fields.String(
@@ -436,7 +437,7 @@ class AnnotationAnalysisSchema(BaseSchema):
 
     @post_load
     def add_id(self, data, **kwargs):
-        if isinstance(data["analysis_id"], str):
+        if isinstance(data.get("analysis_id"), str):
             data["analysis"] = {"id": data.pop("analysis_id")}
         if isinstance(data.get("study_id"), str) and isinstance(
             data.get("studyset_id"), str
diff --git a/store/neurostore/tests/api/test_base_studies.py b/store/neurostore/tests/api/test_base_studies.py
index a14267e52..1e998ca4e 100644
--- a/store/neurostore/tests/api/test_base_studies.py
+++ b/store/neurostore/tests/api/test_base_studies.py
@@ -28,7 +28,7 @@ def test_post_list_of_studies(auth_client, ingest_neuroquery):
             "doi": "",
             "pmid": "",
             "name": "no ids",
-        }
+        },
     ]
 
     result = auth_client.post("/api/base-studies/", data=test_input)
diff --git a/store/neurostore/tests/api/test_crud.py b/store/neurostore/tests/api/test_crud.py
index 85da704a1..65997e7ad 100644
--- a/store/neurostore/tests/api/test_crud.py
+++ b/store/neurostore/tests/api/test_crud.py
@@ -6,6 +6,7 @@
     BaseStudy,
     Study,
     Annotation,
+    AnnotationAnalysis,
     Analysis,
     Condition,
     Image,
@@ -16,6 +17,7 @@
     BaseStudySchema,
     StudySchema,
     AnnotationSchema,
+    AnnotationAnalysisSchema,
     AnalysisSchema,
     ConditionSchema,
     ImageSchema,
@@ -28,7 +30,7 @@
     "endpoint,model,schema",
     [
         ("studysets", Studyset, StudysetSchema),
-        # ("annotations", Annotation, AnnotationSchema), FIX
+        ("annotations", Annotation, AnnotationSchema),
         ("base-studies", BaseStudy, BaseStudySchema),
         ("studies", Study, StudySchema),
         ("analyses", Analysis, AnalysisSchema),
@@ -74,6 +76,7 @@ def test_create(auth_client, user_data, endpoint, model, schema, session):
     [
         ("studysets", Studyset, StudysetSchema),
         ("annotations", Annotation, AnnotationSchema),
+        ("annotation-analyses", AnnotationAnalysis, AnnotationAnalysisSchema),
         ("base-studies", BaseStudy, BaseStudySchema),
         ("studies", Study, StudySchema),
         ("analyses", Analysis, AnalysisSchema),
@@ -114,7 +117,13 @@ def test_read(auth_client, user_data, endpoint, model, schema, session):
     "endpoint,model,schema,update",
     [
         ("studysets", Studyset, StudysetSchema, {"description": "mine"}),
-        # ("annotations", Annotation, AnnotationSchema, {'description': 'mine'}), FIX
+        ("annotations", Annotation, AnnotationSchema, {"description": "mine"}),
+        (
+            "annotation-analyses",
+            AnnotationAnalysis,
+            AnnotationAnalysisSchema,
+            {"note": {"new": "note"}},
+        ),
         ("base-studies", BaseStudy, BaseStudySchema, {"description": "mine"}),
         ("studies", Study, StudySchema, {"description": "mine"}),
         ("analyses", Analysis, AnalysisSchema, {"description": "mine"}),
diff --git a/store/neurostore/tests/api/test_performance.py b/store/neurostore/tests/api/test_performance.py
index b6ed32893..1232f151a 100644
--- a/store/neurostore/tests/api/test_performance.py
+++ b/store/neurostore/tests/api/test_performance.py
@@ -49,9 +49,7 @@ def test_mass_creation(auth_client, session):
             "analyses": [
                 {
                     "name": f"analysis{i}",
-                    "points": [
-                        {"x": 0, "y": 0, "z": 0, "space": "mni", "order": 1}
-                    ],
+                    "points": [{"x": 0, "y": 0, "z": 0, "space": "mni", "order": 1}],
                 }
             ],
         }
diff --git a/store/neurostore/tests/conftest.py b/store/neurostore/tests/conftest.py
index 45048c308..58f77bfa2 100644
--- a/store/neurostore/tests/conftest.py
+++ b/store/neurostore/tests/conftest.py
@@ -517,6 +517,7 @@ def user_data(session, mock_add_users):
                         annotation=annotation,
                         analysis=analysis,
                         note={"food": "bar"},
+                        user=user,
                     )
                     annotation.annotation_analyses.append(aa)