diff --git a/backend/api/projects/resources.py b/backend/api/projects/resources.py index 31ff46c2c7..17d2c8e979 100644 --- a/backend/api/projects/resources.py +++ b/backend/api/projects/resources.py @@ -437,8 +437,9 @@ async def patch( project_dto.project_id = project_id try: - await ProjectAdminService.update_project(project_dto, user.id, db) - return JSONResponse(content={"Status": "Updated"}, status_code=200) + async with db.transaction(): + await ProjectAdminService.update_project(project_dto, user.id, db) + return JSONResponse(content={"Status": "Updated"}, status_code=200) except InvalidGeoJson as e: return JSONResponse(content={"Invalid GeoJson": str(e)}, status_code=400) except ProjectAdminServiceError as e: diff --git a/backend/models/dtos/campaign_dto.py b/backend/models/dtos/campaign_dto.py index 65e6df889f..78e9621afc 100644 --- a/backend/models/dtos/campaign_dto.py +++ b/backend/models/dtos/campaign_dto.py @@ -34,9 +34,10 @@ class CampaignDTO(BaseModel): logo: Optional[str] = None url: Optional[str] = None description: Optional[str] = None - organisations: List[OrganisationDTO] = Field( - default=None, serialization_alias="organisations" - ) + organisations: List[OrganisationDTO] = Field(default=None, alias="organisations") + + class Config: + populate_by_name = True class CampaignProjectDTO(BaseModel): diff --git a/backend/models/dtos/interests_dto.py b/backend/models/dtos/interests_dto.py index f1f492509b..c2ea00dcae 100644 --- a/backend/models/dtos/interests_dto.py +++ b/backend/models/dtos/interests_dto.py @@ -17,15 +17,18 @@ class InterestDTO(BaseModel): id: Optional[int] = None name: Optional[str] = Field(default=None, min_length=1) user_selected: Optional[bool] = Field( - serialization_alias="userSelected", default=None, none_if_default=True + alias="userSelected", default=None, none_if_default=True ) count_projects: Optional[int] = Field( - serialize=False, serialization_alias="countProjects", default=None + serialize=False, alias="countProjects", default=None ) count_users: Optional[int] = Field( - serialize=False, serialization_alias="countUsers", default=None + serialize=False, alias="countUsers", default=None ) + class Config: + populate_by_name = True + class ListInterestDTO(BaseModel): id: Optional[int] = None diff --git a/backend/models/dtos/message_dto.py b/backend/models/dtos/message_dto.py index 0fe1ad482f..2b594c8cbc 100644 --- a/backend/models/dtos/message_dto.py +++ b/backend/models/dtos/message_dto.py @@ -7,18 +7,16 @@ class MessageDTO(BaseModel): """DTO used to define a message that will be sent to a user""" - message_id: int = Field(None, serialization_alias="message_id") - subject: str = Field(None, serialization_alias="subject") - message: str = Field(None, serialization_alias="message") - from_username: Optional[str] = Field("", serialization_alias="fromUsername") - display_picture_url: Optional[str] = Field( - "", serialization_alias="displayPictureUrl" - ) - project_id: Optional[int] = Field(None, serialization_alias="projectId") - project_title: Optional[str] = Field(None, serialization_alias="projectTitle") - task_id: Optional[int] = Field(None, serialization_alias="taskId") - message_type: Optional[str] = Field(None, serialization_alias="message_type") - sent_date: datetime = Field(None, serialization_alias="sentDate") + message_id: int = Field(None, alias="message_id") + subject: str = Field(None, alias="subject") + message: str = Field(None, alias="message") + from_username: Optional[str] = Field("", alias="fromUsername") + display_picture_url: Optional[str] = Field("", alias="displayPictureUrl") + project_id: Optional[int] = Field(None, alias="projectId") + project_title: Optional[str] = Field(None, alias="projectTitle") + task_id: Optional[int] = Field(None, alias="taskId") + message_type: Optional[str] = Field(None, alias="message_type") + sent_date: datetime = Field(None, alias="sentDate") read: bool = False class Config: diff --git a/backend/models/dtos/organisation_dto.py b/backend/models/dtos/organisation_dto.py index 947690bc38..e4ea920d0c 100644 --- a/backend/models/dtos/organisation_dto.py +++ b/backend/models/dtos/organisation_dto.py @@ -19,35 +19,42 @@ def is_known_organisation_type(value): class OrganisationManagerDTO(BaseModel): username: Optional[str] = None - picture_url: Optional[str] = Field(None, serialization_alias="pictureUrl") + picture_url: Optional[str] = Field(None, alias="pictureUrl") + + class Config: + populate_by_name = True class OrganisationTeamsDTO(BaseModel): - team_id: Optional[int] = Field(None, serialization_alias="teamId") + team_id: Optional[int] = Field(None, alias="teamId") name: Optional[str] = None description: Optional[str] = None - join_method: Optional[str] = Field(None, serialization_alias="joinMethod") + join_method: Optional[str] = Field(None, alias="joinMethod") visibility: Optional[str] = None members: List[Dict[str, Optional[str]]] = Field(default=[]) + class Config: + populate_by_name = True + class OrganisationDTO(BaseModel): - organisation_id: Optional[int] = Field(None, serialization_alias="organisationId") + organisation_id: Optional[int] = Field(None, alias="organisationId") managers: Optional[List[OrganisationManagerDTO]] = None name: Optional[str] = None slug: Optional[str] = None logo: Optional[str] = None description: Optional[str] = None url: Optional[str] = None - is_manager: Optional[bool] = Field(None, serialization_alias="isManager") + is_manager: Optional[bool] = Field(None, alias="isManager") projects: Optional[List[str]] = Field(default=[], alias="projects") teams: List[OrganisationTeamsDTO] = None campaigns: Optional[List[List[str]]] = None stats: Optional[OrganizationStatsDTO] = None type: Optional[str] = Field(None) - subscription_tier: Optional[int] = Field( - None, serialization_alias="subscriptionTier" - ) + subscription_tier: Optional[int] = Field(None, alias="subscriptionTier") + + class Config: + populate_by_name = True @field_validator("type", mode="before") def validate_type(cls, value): @@ -67,7 +74,7 @@ def __init__(self): class NewOrganisationDTO(BaseModel): """Describes a JSON model to create a new organisation""" - organisation_id: Optional[int] = Field(None, serialization_alias="organisationId") + organisation_id: Optional[int] = Field(None, alias="organisationId") managers: List[str] name: str slug: Optional[str] = None @@ -75,9 +82,10 @@ class NewOrganisationDTO(BaseModel): description: Optional[str] = None url: Optional[str] = None type: str - subscription_tier: Optional[int] = Field( - None, serialization_alias="subscriptionTier" - ) + subscription_tier: Optional[int] = Field(None, alias="subscriptionTier") + + class Config: + populate_by_name = True @field_validator("type", mode="before") @classmethod @@ -94,7 +102,7 @@ def validate_type(cls, value: Optional[str]) -> Optional[str]: class UpdateOrganisationDTO(OrganisationDTO): - organisation_id: Optional[int] = Field(None, serialization_alias="organisationId") + organisation_id: Optional[int] = Field(None, alias="organisationId") managers: List[str] = Field(default=[]) name: Optional[str] = None slug: Optional[str] = None @@ -103,6 +111,9 @@ class UpdateOrganisationDTO(OrganisationDTO): url: Optional[str] = None type: Optional[str] = None + class Config: + populate_by_name = True + @field_validator("type", mode="before") @classmethod def validate_type(cls, value: Optional[str]) -> Optional[str]: diff --git a/backend/models/dtos/project_dto.py b/backend/models/dtos/project_dto.py index c226c26385..82a5ad7a17 100644 --- a/backend/models/dtos/project_dto.py +++ b/backend/models/dtos/project_dto.py @@ -1,25 +1,26 @@ # from schematics import Model # from schematics.exceptions import ValidationError -from backend.models.dtos.task_annotation_dto import TaskAnnotationDTO +from datetime import date, datetime +from typing import Dict, List, Optional, Union + +from fastapi import HTTPException +from pydantic import BaseModel, Field + +from backend.models.dtos.campaign_dto import CampaignDTO +from backend.models.dtos.interests_dto import InterestDTO from backend.models.dtos.stats_dto import Pagination +from backend.models.dtos.task_annotation_dto import TaskAnnotationDTO from backend.models.dtos.team_dto import ProjectTeamDTO -from backend.models.dtos.interests_dto import InterestDTO from backend.models.postgis.statuses import ( - ProjectStatus, - ProjectPriority, - MappingTypes, - TaskCreationMode, Editors, MappingPermission, - ValidationPermission, + MappingTypes, ProjectDifficulty, + ProjectPriority, + ProjectStatus, + TaskCreationMode, + ValidationPermission, ) -from backend.models.dtos.campaign_dto import CampaignDTO -from pydantic import BaseModel, Field -from typing import Dict, List, Optional, Union -from datetime import datetime -from datetime import date -from fastapi import HTTPException def is_known_project_status(value: str) -> str: @@ -178,6 +179,9 @@ class DraftProjectDTO(BaseModel): has_arbitrary_tasks: bool = Field(False, alias="arbitraryTasks") user_id: int = Field(None) + class Config: + populate_by_name = True + class ProjectInfoDTO(BaseModel): """Contains the localized project info""" @@ -241,7 +245,7 @@ class ProjectDTO(BaseModel): country_tag: Optional[List[str]] = Field(None, alias="countryTag") license_id: Optional[int] = Field(None, alias="licenseId") allowed_usernames: Optional[List[str]] = Field(default=[], alias="allowedUsernames") - priority_areas: Optional[Dict] = Field(None, alias="priorityAreas") + priority_areas: Optional[List[Dict]] = Field(None, alias="priorityAreas") created: Optional[datetime] = None last_updated: Optional[datetime] = Field(None, alias="lastUpdated") author: Optional[str] = None @@ -407,9 +411,10 @@ class ProjectSearchBBoxDTO(BaseModel): bbox: List[float] = Field(..., min_items=4, max_items=4) input_srid: int = Field(..., choices=[4326]) preferred_locale: Optional[str] = Field(default="en") - project_author: Optional[int] = Field( - default=None, serialization_alias="projectAuthor" - ) + project_author: Optional[int] = Field(default=None, alias="projectAuthor") + + class Config: + populate_by_name = True class ListSearchResultDTO(BaseModel): @@ -431,6 +436,9 @@ class ListSearchResultDTO(BaseModel): total_contributors: Optional[int] = Field(alias="totalContributors", default=None) country: Optional[str] = Field(default="", serialize=False) + class Config: + populate_by_name = True + # class ProjectSearchResultsDTO(BaseModel): # map_results: Optional[List] = [] @@ -468,6 +476,9 @@ class ProjectComment(BaseModel): user_name: str = Field(alias="userName") task_id: int = Field(alias="taskId") + class Config: + populate_by_name = True + class ProjectCommentsDTO(BaseModel): """Contains all comments on a project""" @@ -496,85 +507,50 @@ class ProjectContribsDTO(BaseModel): class ProjectSummary(BaseModel): - project_id: int = Field(..., serialization_alias="projectId") - default_locale: Optional[str] = Field(None, serialization_alias="defaultLocale") + project_id: int = Field(..., alias="projectId") + default_locale: Optional[str] = Field(None, alias="defaultLocale") author: Optional[str] = None created: Optional[datetime] = None - due_date: Optional[datetime] = Field(None, serialization_alias="dueDate") - last_updated: Optional[datetime] = Field(None, serialization_alias="lastUpdated") - priority: Optional[str] = Field(None, serialization_alias="projectPriority") + due_date: Optional[datetime] = Field(None, alias="dueDate") + last_updated: Optional[datetime] = Field(None, alias="lastUpdated") + priority: Optional[str] = Field(None, alias="projectPriority") campaigns: List[CampaignDTO] = Field(default_factory=list) organisation: Optional[int] = None - organisation_name: Optional[str] = Field( - None, serialization_alias="organisationName" - ) - organisation_slug: Optional[str] = Field( - None, serialization_alias="organisationSlug" - ) - organisation_logo: Optional[str] = Field( - None, serialization_alias="organisationLogo" - ) - country_tag: List[str] = Field( - default_factory=list, serialization_alias="countryTag" - ) - osmcha_filter_id: Optional[str] = Field(None, serialization_alias="osmchaFilterId") - mapping_types: List[str] = Field( - default_factory=list, serialization_alias="mappingTypes" - ) - changeset_comment: Optional[str] = Field( - None, serialization_alias="changesetComment" - ) - percent_mapped: Optional[int] = Field(None, serialization_alias="percentMapped") - percent_validated: Optional[int] = Field( - None, serialization_alias="percentValidated" - ) - percent_bad_imagery: Optional[int] = Field( - None, serialization_alias="percentBadImagery" - ) - aoi_centroid: Optional[Union[dict, None]] = Field( - None, serialization_alias="aoiCentroid" - ) - difficulty: Optional[str] = Field(None, serialization_alias="difficulty") - mapping_permission: Optional[int] = Field( - None, serialization_alias="mappingPermission" - ) - validation_permission: Optional[int] = Field( - None, serialization_alias="validationPermission" - ) - allowed_usernames: List[str] = Field( - default_factory=list, serialization_alias="allowedUsernames" - ) + organisation_name: Optional[str] = Field(None, alias="organisationName") + organisation_slug: Optional[str] = Field(None, alias="organisationSlug") + organisation_logo: Optional[str] = Field(None, alias="organisationLogo") + country_tag: List[str] = Field(default_factory=list, alias="countryTag") + osmcha_filter_id: Optional[str] = Field(None, alias="osmchaFilterId") + mapping_types: List[str] = Field(default_factory=list, alias="mappingTypes") + changeset_comment: Optional[str] = Field(None, alias="changesetComment") + percent_mapped: Optional[int] = Field(None, alias="percentMapped") + percent_validated: Optional[int] = Field(None, alias="percentValidated") + percent_bad_imagery: Optional[int] = Field(None, alias="percentBadImagery") + aoi_centroid: Optional[Union[dict, None]] = Field(None, alias="aoiCentroid") + difficulty: Optional[str] = Field(None, alias="difficulty") + mapping_permission: Optional[int] = Field(None, alias="mappingPermission") + validation_permission: Optional[int] = Field(None, alias="validationPermission") + allowed_usernames: List[str] = Field(default_factory=list, alias="allowedUsernames") random_task_selection_enforced: bool = Field( - default=False, serialization_alias="enforceRandomTaskSelection" - ) - private: Optional[bool] = Field(None, serialization_alias="private") - allowed_users: List[str] = Field( - default_factory=list, serialization_alias="allowedUsernames" - ) - project_teams: List[ProjectTeamDTO] = Field( - default_factory=list, serialization_alias="teams" - ) - project_info: Optional[ProjectInfoDTO] = Field( - None, serialization_alias="projectInfo" - ) - short_description: Optional[str] = Field( - None, serialization_alias="shortDescription" + default=False, alias="enforceRandomTaskSelection" ) + private: Optional[bool] = Field(None, alias="private") + allowed_users: List[str] = Field(default_factory=list, alias="allowedUsernames") + project_teams: List[ProjectTeamDTO] = Field(default_factory=list, alias="teams") + project_info: Optional[ProjectInfoDTO] = Field(None, alias="projectInfo") + short_description: Optional[str] = Field(None, alias="shortDescription") status: Optional[str] = None imagery: Optional[str] = None - license_id: Optional[int] = Field(None, serialization_alias="licenseId") - id_presets: List[str] = Field(default_factory=list, serialization_alias="idPresets") - extra_id_params: Optional[str] = Field(None, serialization_alias="extraIdParams") - rapid_power_user: bool = Field(default=False, serialization_alias="rapidPowerUser") - mapping_editors: List[str] = Field( - ..., min_items=1, serialization_alias="mappingEditors" - ) - validation_editors: List[str] = Field( - ..., min_items=1, serialization_alias="validationEditors" - ) - custom_editor: Optional[CustomEditorDTO] = Field( - None, serialization_alias="customEditor" - ) + license_id: Optional[int] = Field(None, alias="licenseId") + id_presets: List[str] = Field(default_factory=list, alias="idPresets") + extra_id_params: Optional[str] = Field(None, alias="extraIdParams") + rapid_power_user: bool = Field(default=False, alias="rapidPowerUser") + mapping_editors: List[str] = Field(..., min_items=1, alias="mappingEditors") + validation_editors: List[str] = Field(..., min_items=1, alias="validationEditors") + custom_editor: Optional[CustomEditorDTO] = Field(None, alias="customEditor") + + class Config: + populate_by_name = True # TODO: Make Validators work. diff --git a/backend/models/dtos/team_dto.py b/backend/models/dtos/team_dto.py index f94616fa5d..900f1dc1c5 100644 --- a/backend/models/dtos/team_dto.py +++ b/backend/models/dtos/team_dto.py @@ -1,12 +1,14 @@ +from typing import List, Optional + +from fastapi import HTTPException +from pydantic import BaseModel, Field, field_validator + from backend.models.dtos.stats_dto import Pagination from backend.models.postgis.statuses import ( + TeamJoinMethod, TeamMemberFunctions, TeamVisibility, - TeamJoinMethod, ) -from pydantic import BaseModel, Field, field_validator -from typing import List, Optional -from fastapi import HTTPException def validate_team_visibility(value: str) -> str: @@ -63,9 +65,9 @@ class TeamMembersDTO(BaseModel): function: str active: bool join_request_notifications: bool = Field( - default=False, serialization_alias="joinRequestNotifications" + default=False, alias="joinRequestNotifications" ) - picture_url: Optional[str] = Field(None, serialization_alias="pictureUrl") + picture_url: Optional[str] = Field(None, alias="pictureUrl") @field_validator("function") def validate_function(cls, value): @@ -81,24 +83,28 @@ class TeamProjectDTO(BaseModel): class ProjectTeamDTO(BaseModel): - team_id: int = Field(None, serialization_alias="teamId") - team_name: str = Field(None, serialization_alias="name") - role: int = Field(None) + """Describes a JSON model to create a project team""" + + team_id: int = Field(alias="teamId") + team_name: str = Field(default=None, alias="name") + role: str = Field() + + class Config: + populate_by_name = True + use_enum_values = True class TeamDetailsDTO(BaseModel): """Pydantic model equivalent of the original TeamDetailsDTO""" - team_id: Optional[int] = Field(None, serialization_alias="teamId") + team_id: Optional[int] = Field(None, alias="teamId") organisation_id: int organisation: str - organisation_slug: Optional[str] = Field( - None, serialization_alias="organisationSlug" - ) + organisation_slug: Optional[str] = Field(None, alias="organisationSlug") name: str logo: Optional[str] = None description: Optional[str] = None - join_method: str = Field(serialization_alias="joinMethod") + join_method: str = Field(alias="joinMethod") visibility: str is_org_admin: bool = Field(False) is_general_admin: bool = Field(False) @@ -113,21 +119,24 @@ def validate_join_method(cls, value): def validate_visibility(cls, value): return validate_team_visibility(value) + class Config: + populate_by_name = True + class TeamDTO(BaseModel): """Describes JSON model for a team""" - team_id: Optional[int] = Field(None, serialization_alias="teamId") - organisation_id: int = Field(None, serialization_alias="organisation_id") - organisation: str = Field(None, serialization_alias="organisation") - name: str = Field(None, serialization_alias="name") + team_id: Optional[int] = Field(None, alias="teamId") + organisation_id: int = Field(None, alias="organisationId") + organisation: str = Field(None, alias="organisation") + name: str = Field(None, alias="name") logo: Optional[str] = None description: Optional[str] = None - join_method: str = Field(None, serialization_alias="joinMethod") - visibility: str = Field(None, serialization_alias="visibility") + join_method: str = Field(None, alias="joinMethod") + visibility: str = Field(None, alias="visibility") members: Optional[List[TeamMembersDTO]] = None - members_count: Optional[int] = Field(None, serialization_alias="membersCount") - managers_count: Optional[int] = Field(None, serialization_alias="managersCount") + members_count: Optional[int] = Field(None, alias="membersCount") + managers_count: Optional[int] = Field(None, alias="managersCount") @field_validator("join_method") def validate_join_method(cls, value): @@ -137,6 +146,9 @@ def validate_join_method(cls, value): def validate_visibility(cls, value): return validate_team_visibility(value) + class Config: + populate_by_name = True + class TeamsListDTO(BaseModel): def __init__(self): @@ -181,6 +193,9 @@ def validate_join_method(cls, value): def validate_visibility(cls, value): return validate_team_visibility(value) + class Config: + populate_by_name = True + class UpdateTeamDTO(BaseModel): """Describes a JSON model to update a team""" @@ -204,23 +219,25 @@ def validate_join_method(cls, value): def validate_visibility(cls, value): return validate_team_visibility(value) + class Config: + populate_by_name = True + class TeamSearchDTO(BaseModel): """Describes a JSON model to search for a team""" - user_id: Optional[float] = Field(None, serialization_alias="userId") - organisation: Optional[int] = Field(None, serialization_alias="organisation") - team_name: Optional[str] = Field(None, serialization_alias="team_name") - omit_members: Optional[bool] = Field(False, serialization_alias="omitMemberList") - full_members_list: Optional[bool] = Field( - True, serialization_alias="fullMemberList" - ) - member: Optional[float] = Field(None, serialization_alias="member") - manager: Optional[float] = Field(None, serialization_alias="manager") - team_role: Optional[str] = Field(None, serialization_alias="team_role") - member_request: Optional[float] = Field( - None, aliserialization_aliasas="member_request" - ) - paginate: Optional[bool] = Field(False, serialization_alias="paginate") - page: Optional[int] = Field(1, serialization_alias="page") - per_page: Optional[int] = Field(10, serialization_alias="perPage") + user_id: Optional[float] = Field(None, alias="userId") + organisation: Optional[int] = Field(None, alias="organisation") + team_name: Optional[str] = Field(None, alias="team_name") + omit_members: Optional[bool] = Field(False, alias="omitMemberList") + full_members_list: Optional[bool] = Field(True, alias="fullMemberList") + member: Optional[float] = Field(None, alias="member") + manager: Optional[float] = Field(None, alias="manager") + team_role: Optional[str] = Field(None, alias="team_role") + member_request: Optional[float] = Field(None, alias="member_request") + paginate: Optional[bool] = Field(False, alias="paginate") + page: Optional[int] = Field(1, alias="page") + per_page: Optional[int] = Field(10, alias="perPage") + + class Config: + populate_by_name = True diff --git a/backend/models/dtos/user_dto.py b/backend/models/dtos/user_dto.py index 02ff927ec4..17d59aa160 100644 --- a/backend/models/dtos/user_dto.py +++ b/backend/models/dtos/user_dto.py @@ -220,8 +220,11 @@ class ProjectParticipantUser(BaseModel): """Describes a user who has participated in a project""" username: str - project_id: float = Field(serialization_alias="projectId") - is_participant: bool = Field(serialization_alias="isParticipant") + project_id: float = Field(alias="projectId") + is_participant: bool = Field(alias="isParticipant") + + class Config: + populate_by_name = True class UserSearchDTO(BaseModel): diff --git a/backend/models/postgis/priority_area.py b/backend/models/postgis/priority_area.py index 966542491e..249a87bbd5 100644 --- a/backend/models/postgis/priority_area.py +++ b/backend/models/postgis/priority_area.py @@ -27,12 +27,41 @@ class PriorityArea(Base): id = Column(Integer, primary_key=True) geometry = Column(Geometry("POLYGON", srid=4326)) + # @classmethod + # async def from_dict(cls, area_poly: dict, db: Database): + # """Create a new Priority Area from dictionary""" + # pa_geojson = geojson.loads(json.dumps(area_poly)) + + # if type(pa_geojson) is not geojson.Polygon: + # raise InvalidGeoJson("Priority Areas must be supplied as Polygons") + + # if not pa_geojson.is_valid: + # raise InvalidGeoJson( + # "Priority Area: Invalid Polygon - " + ", ".join(pa_geojson.errors()) + # ) + + # pa = cls() + # valid_geojson = geojson.dumps(pa_geojson) + # query = """ + # SELECT ST_AsText( + # ST_SetSRID( + # ST_GeomFromGeoJSON(:geojson), 4326 + # ) + # ) AS geometry_wkt; + # """ + # result = await db.fetch_one(query=query, values={"geojson": valid_geojson}) + # pa.geometry = result["geometry_wkt"] if result else None + # return pa + @classmethod async def from_dict(cls, area_poly: dict, db: Database): - """Create a new Priority Area from dictionary""" + """Create a new Priority Area from dictionary and insert into the database.""" + + # Load GeoJSON from the dictionary pa_geojson = geojson.loads(json.dumps(area_poly)) - if type(pa_geojson) is not geojson.Polygon: + # Ensure it's a valid Polygon + if not isinstance(pa_geojson, geojson.Polygon): raise InvalidGeoJson("Priority Areas must be supplied as Polygons") if not pa_geojson.is_valid: @@ -40,18 +69,39 @@ async def from_dict(cls, area_poly: dict, db: Database): "Priority Area: Invalid Polygon - " + ", ".join(pa_geojson.errors()) ) - pa = cls() + # Convert the GeoJSON into WKT format using a raw SQL query valid_geojson = geojson.dumps(pa_geojson) - query = """ + geo_query = """ SELECT ST_AsText( ST_SetSRID( ST_GeomFromGeoJSON(:geojson), 4326 ) ) AS geometry_wkt; """ - result = await db.fetch_one(query=query, values={"geojson": valid_geojson}) - pa.geometry = result["geometry_wkt"] if result else None - return pa + result = await db.fetch_one(query=geo_query, values={"geojson": valid_geojson}) + geometry_wkt = result["geometry_wkt"] if result else None + + if not geometry_wkt: + raise InvalidGeoJson("Failed to create geometry from the given GeoJSON") + + # Insert the new Priority Area into the database and return the inserted ID + insert_query = """ + INSERT INTO priority_areas (geometry) + VALUES (ST_GeomFromText(:geometry, 4326)) + RETURNING id; + """ + insert_result = await db.fetch_one( + query=insert_query, values={"geometry": geometry_wkt} + ) + + if insert_result: + # Assign the ID and geometry to the PriorityArea object + pa = cls() + pa.id = insert_result["id"] + pa.geometry = geometry_wkt + return pa + else: + raise Exception("Failed to insert Priority Area") def get_as_geojson(self): """Helper to translate geometry back to a GEOJson Poly""" diff --git a/backend/models/postgis/project.py b/backend/models/postgis/project.py index 78d3f6b5a0..7cdb71414d 100644 --- a/backend/models/postgis/project.py +++ b/backend/models/postgis/project.py @@ -487,7 +487,7 @@ async def update(self, project_dto: ProjectDTO, db: Database): self.private = project_dto.private self.difficulty = ProjectDifficulty[project_dto.difficulty.upper()].value self.changeset_comment = project_dto.changeset_comment - self.due_date = project_dto.due_date + self.due_date = project_dto.due_date.replace(tzinfo=None) self.imagery = project_dto.imagery self.josm_preset = project_dto.josm_preset self.id_presets = project_dto.id_presets @@ -545,14 +545,17 @@ async def update(self, project_dto: ProjectDTO, db: Database): # Update teams and projects relationship. self.teams = [] if hasattr(project_dto, "project_teams") and project_dto.project_teams: + await db.execute( + delete(ProjectTeams).where(ProjectTeams.project_id == self.id) + ) for team_dto in project_dto.project_teams: - team = Team.get(team_dto.team_id, db) - + team = await Team.get(team_dto.team_id, db) if team is None: raise NotFound(sub_code="TEAM_NOT_FOUND", team_id=team_dto.team_id) - role = TeamRoles[team_dto.role].value - project_team = ProjectTeams(project=self, team=team, role=role) + project_team = ProjectTeams( + project_id=self.id, team_id=team.id, role=role + ) await project_team.create(db) # Set Project Info for all returned locales @@ -570,11 +573,22 @@ async def update(self, project_dto: ProjectDTO, db: Database): else: await ProjectInfo.update_from_dto(ProjectInfo(**project_info), dto, db) - self.priority_areas = [] # Always clear Priority Area prior to updating + # Always clear Priority Area prior to updating + if project_dto.priority_areas: + await Project.clear_existing_priority_areas(db, self.id) for priority_area in project_dto.priority_areas: - pa = PriorityArea.from_dict(priority_area) - self.priority_areas.append(pa) + pa = await PriorityArea.from_dict(priority_area, db) + # Link project and priority area in the database + if pa and pa.id: + link_query = """ + INSERT INTO project_priority_areas (project_id, priority_area_id) + VALUES (:project_id, :priority_area_id) + """ + await db.execute( + query=link_query, + values={"project_id": self.id, "priority_area_id": pa.id}, + ) if project_dto.custom_editor: if not self.custom_editor: @@ -593,15 +607,51 @@ async def update(self, project_dto: ProjectDTO, db: Database): # handle campaign update try: new_ids = [c.id for c in project_dto.campaigns] - new_ids.sort() except TypeError: new_ids = [] - current_ids = [c.id for c in self.campaign] - current_ids.sort() - if new_ids != current_ids: - self.campaign = await db.fetch_all( - select(Campaign).filter(Campaign.id.in_(new_ids)) - ) + + query = """ + SELECT campaign_id + FROM campaign_projects + WHERE project_id = :project_id + """ + campaign_results = await db.fetch_all( + query, values={"project_id": project_dto.project_id} + ) + current_ids = [c.campaign_id for c in campaign_results] + + new_set = set(new_ids) + current_set = set(current_ids) + + if new_set != current_set: + to_add = new_set - current_set + to_remove = current_set - new_set + if to_remove: + await db.execute( + """ + DELETE FROM campaign_projects + WHERE project_id = :project_id + AND campaign_id = ANY(:to_remove) + """, + values={ + "project_id": project_dto.project_id, + "to_remove": list(to_remove), + }, + ) + + if to_add: + insert_query = """ + INSERT INTO campaign_projects (project_id, campaign_id) + VALUES (:project_id, :campaign_id) + """ + for campaign_id in to_add: + await db.execute( + insert_query, + values={ + "project_id": project_dto.project_id, + "campaign_id": campaign_id, + }, + ) if project_dto.mapping_permission: self.mapping_permission = MappingPermission[ @@ -615,16 +665,53 @@ async def update(self, project_dto: ProjectDTO, db: Database): # handle interests update try: - new_ids = [c.id for c in project_dto.interests] - new_ids.sort() + new_interest_ids = [i.id for i in project_dto.interests] except TypeError: - new_ids = [] - current_ids = [c.id for c in self.interests] - current_ids.sort() - if new_ids != current_ids: - self.interests = await db.fetch_all( - select(Interest).filter(Interest.id.in_(new_ids)) - ) + new_interest_ids = [] + + interest_query = """ + SELECT interest_id + FROM project_interests + WHERE project_id = :project_id + """ + interest_results = await db.fetch_all( + interest_query, values={"project_id": project_dto.project_id} + ) + current_interest_ids = [i.interest_id for i in interest_results] + + new_interest_set = set(new_interest_ids) + current_interest_set = set(current_interest_ids) + + if new_interest_set != current_interest_set: + to_add_interests = new_interest_set - current_interest_set + to_remove_interests = current_interest_set - new_interest_set + + if to_remove_interests: + await db.execute( + """ + DELETE FROM project_interests + WHERE project_id = :project_id + AND interest_id = ANY(:to_remove) + """, + values={ + "project_id": project_dto.project_id, + "to_remove": list(to_remove_interests), + }, + ) + + if to_add_interests: + insert_interest_query = """ + INSERT INTO project_interests (project_id, interest_id) + VALUES (:project_id, :interest_id) + """ + for interest_id in to_add_interests: + await db.execute( + insert_interest_query, + values={ + "project_id": project_dto.project_id, + "interest_id": interest_id, + }, + ) # try to update country info if that information is not present if not self.country: @@ -636,7 +723,6 @@ async def update(self, project_dto: ProjectDTO, db: Database): columns.pop("geometry", None) columns.pop("centroid", None) columns.pop("id", None) - # Update the project in the database await db.execute( self.__table__.update().where(Project.id == self.id).values(**columns) @@ -662,13 +748,6 @@ async def exists(project_id: int, db: Database) -> bool: return True - # def is_favorited(self, user_id: int) -> bool: - # user = session.get(User, user_id) - # if user not in self.favorited: - # return False - - # return True - @staticmethod async def is_favorited(project_id: int, user_id: int, db: Database) -> bool: query = """ @@ -1523,7 +1602,12 @@ async def get_project_and_base_dto(project_id: int, db: Database) -> ProjectDTO: """ teams = await db.fetch_all(teams_query, {"project_id": project_id}) project_dto.project_teams = ( - [ProjectTeamDTO(**team) for team in teams] if teams else [] + [ + ProjectTeamDTO(**{**team, "role": TeamRoles(team["role"]).name}) + for team in teams + ] + if teams + else [] ) custom_editor = await db.fetch_one( @@ -1574,7 +1658,9 @@ async def get_project_and_base_dto(project_id: int, db: Database) -> ProjectDTO: priority_areas_query, {"project_id": project_id} ) project_dto.priority_areas = ( - [area["geojson"] for area in priority_areas] if priority_areas else None + [geojson.loads(area["geojson"]) for area in priority_areas] + if priority_areas + else None ) interests_query = """ @@ -1734,6 +1820,38 @@ async def get_project_campaigns(project_id: int, db: Database): campaign_list = [ListCampaignDTO(**row) for row in rows] return campaign_list + @staticmethod + async def clear_existing_priority_areas(db: Database, project_id: int): + """Clear existing priority area links and delete the corresponding priority areas for the given project ID.""" + + existing_priority_area_ids_query = """ + SELECT priority_area_id + FROM project_priority_areas + WHERE project_id = :project_id; + """ + existing_priority_area_ids = await db.fetch_all( + query=existing_priority_area_ids_query, values={"project_id": project_id} + ) + existing_ids = [ + record["priority_area_id"] for record in existing_priority_area_ids + ] + + clear_links_query = """ + DELETE FROM project_priority_areas + WHERE project_id = :project_id; + """ + await db.execute(query=clear_links_query, values={"project_id": project_id}) + + if existing_ids: + delete_priority_areas_query = """ + DELETE FROM priority_areas + WHERE id = ANY(:ids); + """ + # Pass the list as an array using PostgreSQL's array syntax + await db.execute( + query=delete_priority_areas_query, values={"ids": existing_ids} + ) + # Add index on project geometry Index("idx_geometry", Project.geometry, postgresql_using="gist") diff --git a/backend/models/postgis/project_info.py b/backend/models/postgis/project_info.py index 3479745b2f..8b292709e0 100644 --- a/backend/models/postgis/project_info.py +++ b/backend/models/postgis/project_info.py @@ -72,7 +72,7 @@ async def create_from_dto(cls, dto: ProjectInfoDTO, project_id: int, db: Databas async def update_from_dto(self, dto: ProjectInfoDTO, db: Database): """Updates existing ProjectInfo from supplied DTO""" - self.locale = dto.locale + # self.locale = dto.locale self.name = dto.name self.project_id_str = str(self.project_id) # Allows project_id to be searched @@ -84,6 +84,8 @@ async def update_from_dto(self, dto: ProjectInfoDTO, db: Database): columns = { c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs } + columns.pop("project_id", None) + columns.pop("locale", None) query = ( update(ProjectInfo.__table__) .where(ProjectInfo.project_id == self.project_id) diff --git a/backend/models/postgis/team.py b/backend/models/postgis/team.py index 726d1cb94a..b357e66465 100644 --- a/backend/models/postgis/team.py +++ b/backend/models/postgis/team.py @@ -27,6 +27,7 @@ ) from backend.models.postgis.user import User from backend.db import Base, get_session +from sqlalchemy import select session = get_session() @@ -154,7 +155,6 @@ async def create_from_dto(cls, new_team_dto: NewTeamDTO, db: Database): new_member.user_id = new_team_dto.creator new_member.function = TeamMemberFunctions.MANAGER.value new_member.active = True - new_team.members.append(new_member) team = await Team.create(new_team, db) @@ -271,7 +271,9 @@ async def get(team_id: int, db: Database): :param team_id: team ID in scope :return: Team if found otherwise None """ - return db.fetch_one(Team.__table__, Team.id == team_id) + query = select(Team).where(Team.id == team_id) + result = await db.fetch_one(query) + return result async def get_team_by_name(team_name: str, db: Database): """