fix spec updates #617

Merged 2 commits on Nov 11, 2023
7 changes: 6 additions & 1 deletion compose/neurosynth_compose/models/analysis.py
@@ -48,7 +48,11 @@ class SpecificationCondition(BaseMixin, db.Model):
db.Text, db.ForeignKey("conditions.id"), index=True, primary_key=True
)
condition = relationship("Condition", backref=backref("specification_conditions"))
specification = relationship("Specification", backref=backref("specification_conditions"))
specification = relationship(
"Specification", backref=backref("specification_conditions")
)
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
user = relationship("User", backref=backref("specification_conditions"))


class Specification(BaseMixin, db.Model):
@@ -59,6 +63,7 @@ class Specification(BaseMixin, db.Model):
filter = db.Column(db.Text)
weights = association_proxy("specification_conditions", "weight")
conditions = association_proxy("specification_conditions", "condition")
database_studyset = db.Column(db.Text)
corrector = db.Column(db.JSON)
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
user = relationship("User", backref=backref("specifications"))
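(Aside, not part of the diff: the new database_studyset column is a plain text field on Specification. A minimal sketch of setting it when building a specification row, assuming the model above and an active SQLAlchemy session from this app; only columns visible in the diff are used.)

# Illustrative sketch only: record which database studyset a specification came from.
spec = Specification(
    filter="eyes",                        # existing text column
    corrector={"type": "FDRCorrector"},   # existing JSON column
    database_studyset="neurostore",       # new column added in this PR
)
db.session.add(spec)
db.session.commit()
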
32 changes: 26 additions & 6 deletions compose/neurosynth_compose/resources/analysis.py
@@ -1,4 +1,3 @@

from collections import ChainMap
import pathlib
from operator import itemgetter
@@ -119,7 +118,10 @@ def update_or_create(cls, data, id=None, commit=True):
only_ids = set(data.keys()) - set(["id"]) == set()

if cls._model is Condition:
record = cls._model.query.filter_by(name=data.get('name')).first() or cls._model()
record = (
cls._model.query.filter_by(name=data.get("name")).first()
or cls._model()
)
if id is None:
record = cls._model()
record.user = current_user
@@ -149,7 +151,8 @@ def update_or_create(cls, data, id=None, commit=True):

# get nested attributes
nested_keys = [
item for key in cls._nested.keys()
item
for key in cls._nested.keys()
for item in (key if isinstance(key, tuple) else (key,))
]

@@ -164,25 +167,42 @@
# Update nested attributes recursively
for field, res_name in cls._nested.items():
field = (field,) if not isinstance(field, tuple) else field
if set(data.keys()).issubset(field):
field = (list(data.keys())[0],)

try:
rec_data = itemgetter(*field)(data)
except KeyError:
rec_data = None

ResCls = globals()[res_name]

if rec_data is not None:
if isinstance(rec_data, tuple):
rec_data = [dict(ChainMap(*rc)) for rc in zip(*rec_data)]
# get ids of existing nested attributes
existing_nested = None
if cls._attribute_name:
existing_nested = getattr(record, cls._attribute_name, None)

if existing_nested and len(existing_nested) == len(rec_data):
_ = [
rd.update({"id": ns.id})
for rd, ns in zip(
rec_data, getattr(record, cls._attribute_name)
)
]
if isinstance(rec_data, list):
nested = [
ResCls.update_or_create(rec, commit=False)
for rec in rec_data
ResCls.update_or_create(rec, commit=False) for rec in rec_data
]
to_commit.extend(nested)
else:
nested = ResCls.update_or_create(rec_data, commit=False)
to_commit.append(nested)
update_field = field if len(field) == 1 else (cls._attribute_name,)
update_field = (
field if not cls._attribute_name else (cls._attribute_name,)
)
for f in update_field:
setattr(record, f, nested)

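(Aside, not part of the diff: the ChainMap/zip step above merges parallel nested payloads, e.g. plucked conditions alongside weights, into one dict per nested record, and then re-attaches existing ids so the same rows get updated rather than recreated. A standalone sketch of that merge with assumed example data:)

from collections import ChainMap

# rec_data arrives as a tuple of parallel lists of partial dicts
rec_data = (
    [{"condition": "open"}, {"condition": "closed"}],  # plucked conditions
    [{"weight": 1.0}, {"weight": -1.0}],               # matching weights
)
merged = [dict(ChainMap(*rc)) for rc in zip(*rec_data)]
# merged == [{"condition": "open", "weight": 1.0},
#            {"condition": "closed", "weight": -1.0}]

# if the record already has the same number of nested rows, their ids are
# copied onto the payload so update_or_create targets the existing rows
existing_ids = ["sc-1", "sc-2"]  # hypothetical SpecificationCondition ids
for rd, existing_id in zip(merged, existing_ids):
    rd["id"] = existing_id
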
22 changes: 8 additions & 14 deletions compose/neurosynth_compose/schemas/analysis.py
@@ -130,10 +130,7 @@ class ConditionSchema(Schema):
description = PGSQLString()


class SpecificationConditionSchema(Schema):
id = PGSQLString()
created_at = fields.DateTime()
updated_at = fields.DateTime(allow_none=True)
class SpecificationConditionSchema(BaseSchema):
condition = fields.Pluck(ConditionSchema, "name")
weight = fields.Number()

@@ -152,7 +149,7 @@ class StudysetReferenceSchema(Schema):
exclude=("snapshot",),
metadata={"pluck": "id"},
many=True,
dump_only=True
dump_only=True,
)


@@ -165,6 +162,7 @@ class SpecificationSchema(BaseSchema):
mask = PGSQLString(allow_none=True)
transformer = PGSQLString(allow_none=True)
estimator = fields.Nested("EstimatorSchema")
database_studyset = PGSQLString(allow_none=True)
contrast = PGSQLString(allow_none=True)
filter = PGSQLString(allow_none=True)
corrector = fields.Dict(allow_none=True)
@@ -178,11 +176,7 @@ class SpecificationSchema(BaseSchema):
data_key="conditions",
)
conditions = fields.Pluck(
ConditionSchema,
"name",
many=True,
allow_none=True,
dump_only=True
ConditionSchema, "name", many=True, allow_none=True, dump_only=True
)
weights = fields.List(
fields.Float(),
@@ -213,7 +207,7 @@ def to_bool(self, data, **kwargs):
output_conditions[i] = True
elif cond.lower() == "false":
output_conditions[i] = False
data['conditions'] = conditions
data["conditions"] = conditions

return data

@@ -224,10 +218,10 @@ def to_string(self, data, **kwargs):
output_conditions = conditions[:]
for i, cond in enumerate(conditions):
if cond is True:
output_conditions[i] = 'true'
output_conditions[i] = "true"
elif cond is False:
output_conditions[i] = 'false'
data['conditions'] = output_conditions
output_conditions[i] = "false"
data["conditions"] = output_conditions

return data

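(Aside, not part of the diff: the to_bool and to_string hooks above only round-trip boolean-like entries in the conditions list between their string and Python bool forms; the quoting changes in this PR do not alter that behaviour. A minimal standalone sketch of the conversion with assumed example values:)

def conditions_to_bool(conditions):
    # "true"/"false" strings coming in from the API become real booleans
    out = conditions[:]
    for i, cond in enumerate(conditions):
        if isinstance(cond, str) and cond.lower() == "true":
            out[i] = True
        elif isinstance(cond, str) and cond.lower() == "false":
            out[i] = False
    return out

def conditions_to_string(conditions):
    # booleans going back out are serialized as "true"/"false" strings
    out = conditions[:]
    for i, cond in enumerate(conditions):
        if cond is True:
            out[i] = "true"
        elif cond is False:
            out[i] = "false"
    return out

assert conditions_to_bool(["true", "open"]) == [True, "open"]
assert conditions_to_string([True, "open"]) == ["true", "open"]
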
43 changes: 42 additions & 1 deletion compose/neurosynth_compose/tests/api/test_specification.py
@@ -25,7 +25,7 @@ def test_get_specification(session, app, auth_client, user_data):
"corrector": {"type": "FDRCorrector"},
"filter": "eyes",
},
]
],
)
def test_create_and_get_spec(session, app, auth_client, user_data, specification_data):
create_spec = auth_client.post("/api/specifications", data=specification_data)
@@ -35,3 +35,44 @@ def test_create_and_get_spec(session, app, auth_client, user_data, specification
view_spec = auth_client.get(f"/api/specifications/{create_spec.json['id']}")

assert create_spec.json == view_spec.json


@pytest.mark.parametrize(
"attribute,value",
[
("estimator", {"type": "MKDA"}),
("type", "ibma"),
("conditions", ["yes", "no"]),
("weights", [1, 1]),
("corrector", {"type": "FWECorrector"}),
("filter", "bunny"),
("database_studyset", "neurostore"),
],
)
def test_update_spec(session, app, auth_client, user_data, attribute, value):
specification_data = {
"estimator": {"type": "ALE"},
"type": "cbma",
"conditions": ["open", "closed"],
"weights": [1, -1],
"corrector": {"type": "FDRCorrector"},
"filter": "eyes",
}
create_spec = auth_client.post("/api/specifications", data=specification_data)

assert create_spec.status_code == 200

spec_id = create_spec.json["id"]

update_spec = auth_client.put(
f"/api/specifications/{spec_id}", data={attribute: value}
)
assert update_spec.status_code == 200

get_spec = auth_client.get(f"/api/specifications/{spec_id}")
assert get_spec.status_code == 200

if isinstance(value, list):
assert set(get_spec.json[attribute]) == set(value)
else:
assert get_spec.json[attribute] == value
@@ -1,8 +1,7 @@

def test_studyset_references(session, app, auth_client, user_data):
nonnested = auth_client.get("/api/studyset-references?nested=false")
nested = auth_client.get("/api/studyset-references?nested=true")

assert nonnested.status_code == nested.status_code == 200
assert isinstance(nonnested.json['results'][0]['studysets'][0], str)
assert isinstance(nested.json['results'][0]['studysets'][0], dict)
assert isinstance(nonnested.json["results"][0]["studysets"][0], str)
assert isinstance(nested.json["results"][0]["studysets"][0], dict)
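(Aside, not part of the diff: the nested query parameter switches the studysets field between plucked ids and fully nested objects, which is what the two assertions above check. Rough response shapes, with the id value assumed:)

# nested=false: studysets are plucked down to their ids
nonnested_shape = {"results": [{"studysets": ["abc123"]}]}

# nested=true: studysets come back as full objects (other fields omitted here)
nested_shape = {"results": [{"studysets": [{"id": "abc123"}]}]}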