diff --git a/basemodels/manifest/data/taskdata.py b/basemodels/manifest/data/taskdata.py index 515dc5a..d3c5ee7 100644 --- a/basemodels/manifest/data/taskdata.py +++ b/basemodels/manifest/data/taskdata.py @@ -2,7 +2,7 @@ from uuid import UUID import requests -from pydantic import BaseModel, HttpUrl, validate_model, ValidationError, validator +from pydantic import BaseModel, HttpUrl, validate_model, ValidationError, validator, root_validator from requests import RequestException from basemodels.constants import SUPPORTED_CONTENT_TYPES @@ -21,23 +21,27 @@ class TaskDataEntry(BaseModel): { "task_key": "407fdd93-687a-46bb-b578-89eb96b4109d", "datapoint_uri": "https://domain.com/file1.jpg", + "datapoint_text": {}, "datapoint_hash": "f4acbe8562907183a484498ba901bfe5c5503aaa" }, { "task_key": "20bd4f3e-4518-4602-b67a-1d8dfabcce0c", "datapoint_uri": "https://domain.com/file2.jpg", + "datapoint_text": {}, "datapoint_hash": "f4acbe8562907183a484498ba901bfe5c5503aaa" } ] """ task_key: Optional[UUID] - datapoint_uri: HttpUrl + datapoint_uri: Optional[HttpUrl] + datapoint_text: Optional[Dict[str, str]] @validator("datapoint_uri", always=True) def validate_datapoint_uri(cls, value): - if len(value) < 10: + if value and len(value) < 10: raise ValidationError("datapoint_uri need to be at least 10 char length.") + return value @validator("metadata") def validate_metadata(cls, value): @@ -55,6 +59,17 @@ def validate_metadata(cls, value): datapoint_hash: Optional[str] metadata: Optional[Dict[str, Optional[Union[str, int, float, Dict[str, Any]]]]] + @root_validator + def validate_datapoint_text(cls, values): + """ + Validate datapoint_uri. + + Raise error if no datapoint_text and no value for URI. + """ + if not values.get("datapoint_uri") and not values.get("datapoint_text"): + raise ValueError("datapoint_uri is missing.") + return values + def validate_content_type(uri: str) -> None: """Validate uri content type""" diff --git a/basemodels/manifest/manifest.py b/basemodels/manifest/manifest.py index 1d4a756..b885077 100644 --- a/basemodels/manifest/manifest.py +++ b/basemodels/manifest/manifest.py @@ -10,7 +10,7 @@ from .data.requester_question_example import validate_requester_example_image from .data.requester_restricted_answer_set import validate_requester_restricted_answer_set_uris from .data.taskdata import validate_taskdata_entry -from pydantic import BaseModel, validator, ValidationError, validate_model, HttpUrl, AnyHttpUrl +from pydantic import BaseModel, validator, ValidationError, validate_model, HttpUrl, AnyHttpUrl, root_validator from pydantic.fields import Field from decimal import Decimal from basemodels.manifest.restricted_audience import RestrictedAudience @@ -109,9 +109,27 @@ class TaskData(BaseModel): """objects within taskdata list in Manifest""" task_key: UUID - datapoint_uri: AnyHttpUrl + datapoint_uri: Optional[AnyHttpUrl] + datapoint_text: Optional[Dict[str, str]] datapoint_hash: str = Field(..., min_length=10, strip_whitespace=True) + @validator("datapoint_uri", always=True) + def validate_datapoint_uri(cls, value): + if value and len(value) < 10: + raise ValidationError("datapoint_uri need to be at least 10 char length.") + return value + + @root_validator + def validate_datapoint_text(cls, values): + """ + Validate datapoint_uri. + + Raise error if no datapoint_text and no value for URI. + """ + if not values.get("datapoint_uri") and not values.get("datapoint_text"): + raise ValueError("datapoint_uri is missing.") + return values + class RequestConfig(Model): """definition of the request_config object in manifest""" diff --git a/pyproject.toml b/pyproject.toml index 8d96d5d..0464607 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "hmt-basemodels" -version = "0.2.1" +version = "0.2.2" description = "" authors = ["Intuition Machines, Inc "] packages = [ diff --git a/setup.py b/setup.py index cc8f8d0..58f83e9 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="hmt-basemodels", - version="0.2.1", + version="0.2.2", author="HUMAN Protocol", description="Common data models shared by various components of the Human Protocol stack", url="https://github.com/hCaptcha/hmt-basemodels", diff --git a/tests/test_manifest_validation.py b/tests/test_manifest_validation.py index 89031f2..9f297f9 100755 --- a/tests/test_manifest_validation.py +++ b/tests/test_manifest_validation.py @@ -912,7 +912,9 @@ def test_valid_entry_is_true(self): TaskDataEntry(**taskdata) taskdata["datapoint_text"] = {"en": "Question to test with"} - TaskDataEntry(**taskdata) + taskdata.pop("datapoint_uri") + taskdata_cons = TaskDataEntry(**taskdata) + self.assertEqual(taskdata_cons.datapoint_text, {"en": "Question to test with"}) with self.assertRaises(ValidationError): taskdata["datapoint_text"] = {}