Skip to content

Commit

Permalink
DynamicEntryPointCommandGroup: Use pydantic to define config model
Browse files Browse the repository at this point in the history
The `DynamicEntryPointCommandGroup` depends on the entry point classes
to implement the `get_cli_options` method to return a dictionary with a
specification of the options to create. The schema of this dictionary
was a custom ad-hoc solution for this purpose. Here we switch to using
pydantic's `BaseModel` to define the `Config` class attribute which
defines the schema for the configuration necessary to construct an
instance of the entry points class.
  • Loading branch information
sphuber committed Nov 9, 2023
1 parent ceefc1e commit d68edcb
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 112 deletions.
36 changes: 26 additions & 10 deletions aiida/cmdline/groups/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,30 @@ def list_options(self, entry_point: str) -> list:
:param entry_point: The entry point.
"""
return [
self.create_option(*item)
for item in self.factory(entry_point).get_cli_options().items() # type: ignore[union-attr]
]
cls = self.factory(entry_point)

if not hasattr(cls, 'Configuration'):
from aiida.common.warnings import warn_deprecation
warn_deprecation(
'Relying on `_get_cli_options` is deprecated. The options should be defined through a '
'`pydantic.BaseModel` that should be assigned to the `Config` class attribute.',
version=3
)
options_spec = self.factory(entry_point).get_cli_options() # type: ignore[union-attr]
else:

options_spec = {}

for key, field_info in cls.Configuration.model_fields.items():
options_spec[key] = {
'required': field_info.is_required(),
'type': field_info.annotation,
'prompt': field_info.title,
'default': field_info.default if field_info.default is not None else None,
'help': field_info.description,
}

return [self.create_option(*item) for item in options_spec.items()]

@staticmethod
def create_option(name, spec: dict) -> t.Callable[[t.Any], t.Any]:
Expand All @@ -136,6 +156,7 @@ def create_option(name, spec: dict) -> t.Callable[[t.Any], t.Any]:
name_dashed = name.replace('_', '-')
option_name = f'--{name_dashed}/--no-{name_dashed}' if is_flag else f'--{name_dashed}'
option_short_name = spec.pop('short_name', None)
option_names = (option_short_name, option_name) if option_short_name else (option_name,)

kwargs = {'cls': spec.pop('cls', InteractiveOption), 'show_default': True, 'is_flag': is_flag, **spec}

Expand All @@ -144,9 +165,4 @@ def create_option(name, spec: dict) -> t.Callable[[t.Any], t.Any]:
if kwargs['cls'] is InteractiveOption and is_flag and default is None:
kwargs['cls'] = functools.partial(InteractiveOption, prompt_fn=lambda ctx: False)

if option_short_name:
option = click.option(option_short_name, option_name, **kwargs)
else:
option = click.option(option_name, **kwargs)

return option
return click.option(*(option_names), **kwargs)
2 changes: 1 addition & 1 deletion aiida/manage/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def create_profile(
"""
from aiida.orm import User

storage_config = {key: kwargs[key] for key in storage_cls.get_cli_options().keys() if key in kwargs}
storage_config = storage_cls.Configuration(**{k: v for k, v in kwargs.items() if v is not None}).dict()
profile: Profile = config.create_profile(name=name, storage_cls=storage_cls, storage_config=storage_config)

with profile_context(profile.name, allow_switch=True):
Expand Down
7 changes: 2 additions & 5 deletions aiida/manage/configuration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import io
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import uuid

from pydantic import ( # pylint: disable=no-name-in-module
Expand All @@ -32,11 +32,8 @@
field_validator,
)

from aiida.common.exceptions import ConfigurationError
from aiida.common.log import LogLevels

from aiida.common.exceptions import ConfigurationError, EntryPointError, StorageMigrationError
from aiida.common.log import AIIDA_LOGGER
from aiida.common.log import AIIDA_LOGGER, LogLevels

from .options import Option, get_option, get_option_names, parse_option
from .profile import Profile
Expand Down
11 changes: 0 additions & 11 deletions aiida/orm/implementation/storage_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from __future__ import annotations

import abc
import collections
from typing import TYPE_CHECKING, Any, ContextManager, List, Optional, Sequence, TypeVar, Union

if TYPE_CHECKING:
Expand Down Expand Up @@ -92,16 +91,6 @@ def migrate(cls, profile: 'Profile') -> None:
:raises: :class:`~aiida.common.exceptions.StorageMigrationError` if the storage is not initialised.
"""

@classmethod
def get_cli_options(cls) -> collections.OrderedDict:
"""Return the CLI options that would allow to create an instance of this class."""
return collections.OrderedDict(cls._get_cli_options())

@classmethod
@abc.abstractmethod
def _get_cli_options(cls) -> dict[str, Any]:
"""Return the CLI options that would allow to create an instance of this class."""

@abc.abstractmethod
def __init__(self, profile: 'Profile') -> None:
"""Initialize the backend, for this profile.
Expand Down
3 changes: 2 additions & 1 deletion aiida/repository/backend/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import contextlib
import os
import pathlib
import shutil
import typing as t
import uuid
Expand Down Expand Up @@ -65,7 +66,7 @@ def is_initialised(self) -> bool:
def sandbox(self):
"""Return the sandbox instance of this repository."""
if self._sandbox is None:
self._sandbox = SandboxFolder(filepath=self._filepath)
self._sandbox = SandboxFolder(filepath=pathlib.Path(self._filepath) if self._filepath is not None else None)

return self._sandbox

Expand Down
75 changes: 25 additions & 50 deletions aiida/storage/psql_dos/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pathlib
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union

from pydantic import BaseModel, Field
from sqlalchemy import column, insert, update
from sqlalchemy.orm import Session, scoped_session, sessionmaker

Expand Down Expand Up @@ -72,6 +73,30 @@ class PsqlDosBackend(StorageBackend): # pylint: disable=too-many-public-methods
The `django` backend was removed, to consolidate access to this storage.
"""

class Configuration(BaseModel):
"""Model describing required information to configure an instance of the storage."""

database_engine: str = Field(
title='PostgreSQL engine',
description='The engine to use to connect to the database.',
default='postgresql_psycopg2'
)
database_hostname: str = Field(
title='PostgreSQL hostname', description='The hostname of the PostgreSQL server.', default='localhost'
)
database_port: int = Field(
title='PostgreSQL port', description='The port of the PostgreSQL server.', default=5432
)
database_username: str = Field(
title='PostgreSQL username', description='The username with which to connect to the PostgreSQL server.'
)
database_password: str = Field(
title='PostgreSQL password', description='The password with which to connect to the PostgreSQL server.'
)
database_name: Union[str, None] = Field(
title='PostgreSQL database name', description='The name of the database in the PostgreSQL server.'
)

migrator = PsqlDosMigrator

@classmethod
Expand Down Expand Up @@ -102,56 +127,6 @@ def migrator_context(cls, profile: Profile):
finally:
migrator.close()

@classmethod
def create_config(cls, **kwargs):
"""Create a configuration dictionary based on the CLI options that can be used to initialize an instance."""
return {key: kwargs[key] for key in cls._get_cli_options()}

@classmethod
def _get_cli_options(cls) -> dict:
"""Return the CLI options that would allow to create an instance of this class."""
return {
'database_engine': {
'required': True,
'type': str,
'prompt': 'Postgresql engine',
'default': 'postgresql_psycopg2',
'help': 'The engine to use to connect to the database.',
},
'database_hostname': {
'required': True,
'type': str,
'prompt': 'Postgresql hostname',
'default': 'localhost',
'help': 'The hostname of the PostgreSQL server.',
},
'database_port': {
'required': True,
'type': int,
'prompt': 'Postgresql port',
'default': '5432',
'help': 'The port of the PostgreSQL server.',
},
'database_username': {
'required': True,
'type': str,
'prompt': 'Postgresql username',
'help': 'The username with which to connect to the PostgreSQL server.',
},
'database_password': {
'required': True,
'type': str,
'prompt': 'Postgresql password',
'help': 'The password with which to connect to the PostgreSQL server.',
},
'database_name': {
'required': True,
'type': str,
'prompt': 'Postgresql database name',
'help': 'The name of the database in the PostgreSQL server.',
}
}

def __init__(self, profile: Profile) -> None:
super().__init__(profile)

Expand Down
27 changes: 10 additions & 17 deletions aiida/storage/sqlite_temp/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tempfile import mkdtemp
from typing import Any, BinaryIO, Iterator, Sequence

from pydantic import BaseModel, Field
from sqlalchemy import column, insert, update
from sqlalchemy.orm import Session

Expand All @@ -41,6 +42,15 @@ class SqliteTempBackend(StorageBackend): # pylint: disable=too-many-public-meth
Whenever it is instantiated, it creates a fresh storage backend,
and destroys it when it is garbage collected.
"""

class Configuration(BaseModel):

filepath: str = Field(
title='Temporary directory',
description='Temporary directory in which to store data for this backend.',
default_factory=mkdtemp
)

_read_only = False

@staticmethod
Expand Down Expand Up @@ -70,23 +80,6 @@ def create_profile(
}
)

@classmethod
def create_config(cls, filepath: str | None = None):
"""Create a configuration dictionary based on the CLI options that can be used to initialize an instance."""
return {'filepath': filepath or mkdtemp()}

@classmethod
def _get_cli_options(cls) -> dict:
"""Return the CLI options that would allow to create an instance of this class."""
return {
'filepath': {
'required': False,
'type': str,
'prompt': 'Temporary directory',
'help': 'Temporary directory in which to store data for this backend.',
}
}

@classmethod
def version_head(cls) -> str:
return get_schema_version_head()
Expand Down
26 changes: 9 additions & 17 deletions aiida/storage/sqlite_zip/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from zipfile import ZipFile, is_zipfile

from archive_path import ZipPath, extract_file_in_zip
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from aiida import __version__
Expand Down Expand Up @@ -63,6 +64,14 @@ class SqliteZipBackend(StorageBackend): # pylint: disable=too-many-public-metho
...
"""

class Configuration(BaseModel):

filepath: str = Field(
title='Filepath of the archive',
description='Filepath of the archive in which to store data for this backend.'
)

_read_only = True

@classmethod
Expand All @@ -89,23 +98,6 @@ def create_profile(path: str | Path, options: dict | None = None) -> Profile:
}
)

@classmethod
def create_config(cls, filepath: str):
"""Create a configuration dictionary based on the CLI options that can be used to initialize an instance."""
return {'path': filepath}

@classmethod
def _get_cli_options(cls) -> dict:
"""Return the CLI options that would allow to create an instance of this class."""
return {
'filepath': {
'required': True,
'type': str,
'prompt': 'Filepath of the archive',
'help': 'Filepath of the archive in which to store data for this backend.',
}
}

@classmethod
def version_profile(cls, profile: Profile) -> Optional[str]:
return read_version(profile.storage_config['path'], search_limit=None)
Expand Down
5 changes: 5 additions & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ py:obj aiida.storage.psql_dos.orm.ModelType
py:obj aiida.storage.psql_dos.orm.SelfType
py:obj aiida.storage.psql_dos.orm.entities.ModelType
py:obj aiida.storage.psql_dos.orm.entities.SelfType
py:class aiida.storage.psql_dos.backend.Configuration
py:class aiida.storage.sqlite_temp.backend.Configuration
py:class aiida.storage.sqlite_zip.backend.Configuration
py:obj aiida.tools.archive.SelfType
py:obj aiida.tools.archive.EntityType
py:func QueryBuilder._get_ormclass
Expand Down Expand Up @@ -132,6 +135,8 @@ py:func click.shell_completion._start_of_option
py:meth click.Option.get_default
py:meth fail

py:class pydantic.main.BaseModel

py:class requests.models.Response
py:class requests.Response

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ disable = [
"use-dict-literal",
"unnecessary-dunder-call",
]
extension-pkg-whitelist = "pydantic"

[tool.pylint.basic]
good-names = [
Expand Down

0 comments on commit d68edcb

Please sign in to comment.