Skip to content

Commit

Permalink
ORM: Use pydantic to specify a schema for each ORM entity
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Aug 9, 2024
1 parent fb36862 commit d4c4d94
Show file tree
Hide file tree
Showing 94 changed files with 1,967 additions and 1,380 deletions.
3 changes: 3 additions & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ py:meth click.Option.get_default
py:meth fail

py:class ComputedFieldInfo
py:class BaseModel
py:class pydantic.fields.Field
py:class pydantic.fields.FieldInfo
py:class pydantic.main.BaseModel
py:class PluggableSchemaValidator

Expand All @@ -157,6 +159,7 @@ py:class frozenset

py:class numpy.bool_
py:class numpy.ndarray
py:class np.ndarray
py:class ndarray

py:class paramiko.proxy.ProxyCommand
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ requires-python = '>=3.9'
'process.workflow.workchain' = 'aiida.orm.nodes.process.workflow.workchain:WorkChainNode'
'process.workflow.workfunction' = 'aiida.orm.nodes.process.workflow.workfunction:WorkFunctionNode'

[project.entry-points.'aiida.orm']
'core.auth_info' = 'aiida.orm.authinfos:AuthInfo'
'core.comment' = 'aiida.orm.comments:Comment'
'core.computer' = 'aiida.orm.computers:Computer'
'core.data' = 'aiida.orm.nodes.data.data:Data'
'core.entity' = 'aiida.orm.entities:Entity'
'core.group' = 'aiida.orm.groups:Group'
'core.log' = 'aiida.orm.logs:Log'
'core.node' = 'aiida.orm.nodes.node:Node'
'core.user' = 'aiida.orm.users:User'

[project.entry-points.'aiida.parsers']
'core.arithmetic.add' = 'aiida.parsers.plugins.arithmetic.add:ArithmeticAddParser'
'core.templatereplacer' = 'aiida.parsers.plugins.templatereplacer.parser:TemplatereplacerParser'
Expand Down
20 changes: 9 additions & 11 deletions src/aiida/cmdline/commands/cmd_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def verdi_code():
"""Setup and manage codes."""


def create_code(ctx: click.Context, cls, non_interactive: bool, **kwargs):
def create_code(ctx: click.Context, cls, **kwargs):
"""Create a new `Code` instance."""
try:
instance = cls(**kwargs)
instance = cls.from_model(cls.Model(**kwargs))
except (TypeError, ValueError) as exception:
echo.echo_critical(f'Failed to create instance `{cls}`: {exception}')

Expand Down Expand Up @@ -245,24 +245,22 @@ def export(code, output_file, overwrite, sort):

import yaml

code_data = {}
from aiida.common.pydantic import get_metadata

for key in code.Model.model_fields.keys():
value = getattr(code, key).label if key == 'computer' else getattr(code, key)
data = code.serialize()

# If the attribute is not set, for example ``with_mpi`` do not export it, because the YAML won't be valid for
# use in ``verdi code create`` since ``None`` is not a valid value on the CLI.
if value is not None:
code_data[key] = str(value)
for key, field in code.Model.model_fields.items():
if get_metadata(field, 'exclude_from_cli'):
data.pop(key)

try:
output_file = generate_validate_output_file(
output_file=output_file, entity_label=code.label, overwrite=overwrite, appendix=f'@{code_data["computer"]}'
output_file=output_file, entity_label=code.label, overwrite=overwrite, appendix=f'@{data["computer"]}'
)
except (FileExistsError, IsADirectoryError) as exception:
raise click.BadParameter(str(exception), param_hint='OUTPUT_FILE') from exception

output_file.write_text(yaml.dump(code_data, sort_keys=sort))
output_file.write_text(yaml.dump(data, sort_keys=sort))

echo.echo_success(f'Code<{code.pk}> {code.label} exported to file `{output_file}`.')

Expand Down
1 change: 0 additions & 1 deletion src/aiida/cmdline/commands/cmd_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def verdi_profile():
def command_create_profile(
ctx: click.Context,
storage_cls,
non_interactive: bool,
profile: Profile,
set_as_default: bool = True,
email: str | None = None,
Expand Down
20 changes: 11 additions & 9 deletions src/aiida/cmdline/groups/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,29 +88,25 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None
command = super().get_command(ctx, cmd_name)
return command

def call_command(self, ctx, cls, **kwargs):
def call_command(self, ctx, cls, non_interactive, **kwargs):
"""Call the ``command`` after validating the provided inputs."""
from pydantic import ValidationError

if hasattr(cls, 'Model'):
# The plugin defines a pydantic model: use it to validate the provided arguments
try:
model = cls.Model(**kwargs)
cls.Model(**kwargs)
except ValidationError as exception:
param_hint = [
f'--{loc.replace("_", "-")}' # type: ignore[union-attr]
for loc in exception.errors()[0]['loc']
]
message = '\n'.join([str(e['ctx']['error']) for e in exception.errors()])
message = '\n'.join([str(e['msg']) for e in exception.errors()])

Check warning on line 104 in src/aiida/cmdline/groups/dynamic.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/cmdline/groups/dynamic.py#L104

Added line #L104 was not covered by tests
raise click.BadParameter(
message,
param_hint=param_hint or 'multiple parameters', # type: ignore[arg-type]
param_hint=param_hint or 'one or more parameters', # type: ignore[arg-type]
) from exception

# Update the arguments with the dictionary representation of the model. This will include any type coercions
# that may have been applied with validators defined for the model.
kwargs.update(**model.model_dump())

return self._command(ctx, cls, **kwargs)

def create_command(self, ctx: click.Context, entry_point: str) -> click.Command:
Expand Down Expand Up @@ -154,6 +150,8 @@ def list_options(self, entry_point: str) -> list:
"""
from pydantic_core import PydanticUndefined

from aiida.common.pydantic import get_metadata

cls = self.factory(entry_point)

if not hasattr(cls, 'Model'):
Expand All @@ -170,6 +168,9 @@ def list_options(self, entry_point: str) -> list:
options_spec = {}

for key, field_info in cls.Model.model_fields.items():
if get_metadata(field_info, 'exclude_from_cli'):
continue

default = field_info.default_factory if field_info.default is PydanticUndefined else field_info.default

# If the annotation has the ``__args__`` attribute it is an instance of a type from ``typing`` and the real
Expand All @@ -194,7 +195,8 @@ def list_options(self, entry_point: str) -> list:
}
for metadata in field_info.metadata:
for metadata_key, metadata_value in metadata.items():
options_spec[key][metadata_key] = metadata_value
if metadata_key in ('priority', 'short_name', 'option_cls'):
options_spec[key][metadata_key] = metadata_value

options_ordered = []

Expand Down
56 changes: 54 additions & 2 deletions src/aiida/common/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,41 @@
import typing as t

from pydantic import Field
from pydantic_core import PydanticUndefined

if t.TYPE_CHECKING:
from pydantic import BaseModel

Check warning on line 11 in src/aiida/common/pydantic.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/common/pydantic.py#L11

Added line #L11 was not covered by tests

from aiida.orm import Entity

Check warning on line 13 in src/aiida/common/pydantic.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/common/pydantic.py#L13

Added line #L13 was not covered by tests


def get_metadata(field_info, key: str, default: t.Any | None = None):
"""Return a the metadata of the given field for a particular key.
:param field_info: The field from which to retrieve the metadata.
:param key: The metadata name.
:param default: Optional default value to return in case the metadata is not defined on the field.
:returns: The metadata if defined, otherwise the default.
"""
for element in field_info.metadata:
if key in element:
return element[key]
return default


def MetadataField( # noqa: N802
default: t.Any | None = None,
default: t.Any = PydanticUndefined,
*,
priority: int = 0,
short_name: str | None = None,
option_cls: t.Any | None = None,
orm_class: type['Entity'] | str | None = None,
orm_to_model: t.Callable[['Entity'], t.Any] | None = None,
model_to_orm: t.Callable[['BaseModel'], t.Any] | None = None,
exclude_to_orm: bool = False,
exclude_from_cli: bool = False,
is_attribute: bool = True,
is_subscriptable: bool = False,
**kwargs,
):
"""Return a :class:`pydantic.fields.Field` instance with additional metadata.
Expand All @@ -37,10 +64,35 @@ class Model(BaseModel):
:param priority: Used to order the list of all fields in the model. Ordering is done from small to large priority.
:param short_name: Optional short name to use for an option on a command line interface.
:param option_cls: The :class:`click.Option` class to use to construct the option.
:param orm_class: The class, or entry point name thereof, to which the field should be converted. If this field is
defined, the value of this field should acccept an integer which will automatically be converted to an instance
of said ORM class using ``orm_class.collection.get(id={field_value})``. This is useful, for example, where a
field represents an instance of a different entity, such as an instance of ``User``. The serialized data would
store the ``pk`` of the user, but the ORM entity instance would receive the actual ``User`` instance with that
primary key.
:param orm_to_model: Optional callable to convert the value of a field from an ORM instance to a model instance.
:param model_to_orm: Optional callable to convert the value of a field from a model instance to an ORM instance.
:param exclude_to_orm: When set to ``True``, this field value will not be passed to the ORM entity constructor
through ``Entity.from_model``.
:param exclude_to_orm: When set to ``True``, this field value will not be exposed on the CLI command that is
dynamically generated to create a new instance.
:param is_attribute: Whether the field is stored as an attribute.
:param is_subscriptable: Whether the field can be indexed like a list or dictionary.
"""
field_info = Field(default, **kwargs)

for key, value in (('priority', priority), ('short_name', short_name), ('option_cls', option_cls)):
for key, value in (
('priority', priority),
('short_name', short_name),
('option_cls', option_cls),
('orm_class', orm_class),
('orm_to_model', orm_to_model),
('model_to_orm', model_to_orm),
('exclude_to_orm', exclude_to_orm),
('exclude_from_cli', exclude_from_cli),
('is_attribute', is_attribute),
('is_subscriptable', is_subscriptable),
):
if value is not None:
field_info.metadata.append({key: value})

Expand Down
100 changes: 66 additions & 34 deletions src/aiida/orm/authinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@
###########################################################################
"""Module for the `AuthInfo` ORM class."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from aiida.common import exceptions
from aiida.common.pydantic import MetadataField
from aiida.manage import get_manager
from aiida.plugins import TransportFactory

from . import entities, users
from .fields import add_field
from .computers import Computer
from .users import User

if TYPE_CHECKING:
from aiida.orm import Computer, User
from aiida.orm.implementation import StorageBackend
from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401
from aiida.transports import Transport
Expand All @@ -45,51 +48,60 @@ class AuthInfo(entities.Entity['BackendAuthInfo', AuthInfoCollection]):
"""ORM class that models the authorization information that allows a `User` to connect to a `Computer`."""

_CLS_COLLECTION = AuthInfoCollection
PROPERTY_WORKDIR = 'workdir'

__qb_fields__ = [
add_field(
'enabled',
dtype=bool,
class Model(entities.Entity.Model):
computer: int = MetadataField(
description='The PK of the computer',
is_attribute=False,
doc='Whether the instance is enabled',
),
add_field(
'auth_params',
dtype=Dict[str, Any],
orm_class=Computer,
orm_to_model=lambda auth_info: auth_info.computer.pk, # type: ignore[attr-defined]
)
user: int = MetadataField(
description='The PK of the user',
is_attribute=False,
doc='Dictionary of authentication parameters',
),
add_field(
'metadata',
dtype=Dict[str, Any],
orm_class=User,
orm_to_model=lambda auth_info: auth_info.user.pk, # type: ignore[attr-defined]
)
enabled: bool = MetadataField(
True,
description='Whether the instance is enabled',
is_attribute=False,
doc='Dictionary of metadata',
),
add_field(
'computer_pk',
dtype=int,
)
auth_params: Dict[str, Any] = MetadataField(
default_factory=dict,
description='Dictionary of authentication parameters',
is_attribute=False,
doc='The PK of the computer',
),
add_field(
'user_pk',
dtype=int,
)
metadata: Dict[str, Any] = MetadataField(
default_factory=dict,
description='Dictionary of metadata',
is_attribute=False,
doc='The PK of the user',
),
]

PROPERTY_WORKDIR = 'workdir'

def __init__(self, computer: 'Computer', user: 'User', backend: Optional['StorageBackend'] = None) -> None:
)

def __init__(
self,
computer: 'Computer',
user: 'User',
enabled: bool = True,
auth_params: Dict[str, Any] | None = None,
metadata: Dict[str, Any] | None = None,
backend: Optional['StorageBackend'] = None,
) -> None:
"""Create an `AuthInfo` instance for the given computer and user.
:param computer: a `Computer` instance
:param user: a `User` instance
:param backend: the backend to use for the instance, or use the default backend if None
"""
backend = backend or get_manager().get_profile_storage()
model = backend.authinfos.create(computer=computer.backend_entity, user=user.backend_entity)
model = backend.authinfos.create(
computer=computer.backend_entity,
user=user.backend_entity,
enabled=enabled,
auth_params=auth_params or {},
metadata=metadata or {},
)
super().__init__(model)

def __str__(self) -> str:
Expand All @@ -98,6 +110,18 @@ def __str__(self) -> str:

return f'AuthInfo for {self.user.email} on {self.computer.label} [DISABLED]'

def __eq__(self, other) -> bool:
if not isinstance(other, AuthInfo):
return False

Check warning on line 115 in src/aiida/orm/authinfos.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/authinfos.py#L114-L115

Added lines #L114 - L115 were not covered by tests

return (

Check warning on line 117 in src/aiida/orm/authinfos.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/authinfos.py#L117

Added line #L117 was not covered by tests
self.user.pk == other.user.pk
and self.computer.pk == other.computer.pk
and self.enabled == other.enabled
and self.auth_params == other.auth_params
and self.metadata == other.metadata
)

@property
def enabled(self) -> bool:
"""Return whether this instance is enabled.
Expand Down Expand Up @@ -126,6 +150,14 @@ def user(self) -> 'User':
"""Return the user associated with this instance."""
return entities.from_backend_entity(users.User, self._backend_entity.user)

@property
def auth_params(self) -> Dict[str, Any]:
return self._backend_entity.get_auth_params()

@property
def metadata(self) -> Dict[str, Any]:
return self._backend_entity.get_metadata()

def get_auth_params(self) -> Dict[str, Any]:
"""Return the dictionary of authentication parameters
Expand Down
Loading

0 comments on commit d4c4d94

Please sign in to comment.