Skip to content

Commit

Permalink
WIP full verifciation
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jun 26, 2024
1 parent 98e569d commit 8197e72
Show file tree
Hide file tree
Showing 27 changed files with 29,980 additions and 394 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ aiida-sssp-workflow = "aiida_sssp_workflow.cli:cmd_root"
"sssp_workflow.convergence.phonon_frequencies" = "aiida_sssp_workflow.workflows.convergence.phonon_frequencies:ConvergencePhononFrequenciesWorkChain"
"sssp_workflow.convergence.pressure" = "aiida_sssp_workflow.workflows.convergence.pressure:ConvergencePressureWorkChain"
"sssp_workflow.convergence.bands" = "aiida_sssp_workflow.workflows.convergence.bands:ConvergenceBandsWorkChain"
"sssp_workflow.verification" = "aiida_sssp_workflow.workflows.verifications:FullVerificationWorkChain"
"sssp_workflow.verification" = "aiida_sssp_workflow.workflows.verification:FullVerificationWorkChain"

[project.urls]
Documentation = "https://aiida-sssp-workflow.readthedocs.io/"
Expand Down
6 changes: 3 additions & 3 deletions src/aiida_sssp_workflow/protocol/control.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ quick:
test:
description: test only

wfc_scan: [30, 35] # at fix dual
nc_dual_scan: [] # at fix wfc
nonnc_dual_scan: [] # at fix rho
wfc_scan: [20, 30] # at fix dual
nc_dual_scan: [3.0, 4.0] # at fix wfc
nonnc_dual_scan: [6.0, 8.0] # at fix rho
nonnc_high_dual_scan: [8.0, 12.0]
56 changes: 56 additions & 0 deletions src/aiida_sssp_workflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,59 @@ def get_default_mpi_options(
"max_wallclock_seconds": int(max_wallclock_seconds),
"withmpi": with_mpi,
}

def serialize_data(data):
from aiida.orm import (
AbstractCode,
BaseType,
Data,
Dict,
KpointsData,
List,
RemoteData,
SinglefileData,
)
from aiida.plugins import DataFactory

StructureData = DataFactory("core.structure")
UpfData = DataFactory("pseudo.upf")

if isinstance(data, dict):
return {key: serialize_data(value) for key, value in data.items()}

if isinstance(data, BaseType):
return data.value

if isinstance(data, AbstractCode):
return data.full_label

if isinstance(data, Dict):
return data.get_dict()

if isinstance(data, List):
return data.get_list()

if isinstance(data, StructureData):
return data.get_formula()

if isinstance(data, UpfData):
return f"{data.element}<md5={data.md5}>"

if isinstance(data, RemoteData):
# For `RemoteData` we compute the hash of the repository. The value returned by `Node._get_hash` is not
# useful since it includes the hash of the absolute filepath and the computer UUID which vary between tests
return data.base.repository.hash()

if isinstance(data, KpointsData):
try:
return data.get_kpoints().tolist()
except AttributeError:
return data.get_kpoints_mesh()

if isinstance(data, SinglefileData):
return data.get_content()

if isinstance(data, Data):
return data.base.caching._get_hash()

return data
27 changes: 26 additions & 1 deletion src/aiida_sssp_workflow/utils/protocol.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List, Tuple
import yaml
from importlib import resources

from aiida_sssp_workflow.utils.pseudo import DualType, get_dual_type

def get_protocol(category, name=None):

def get_protocol(category: str, name: str | None=None):
"""Load and read protocol from faml file to a verbose dict
if name not set, return whole protocol."""
import_path = resources.path("aiida_sssp_workflow.protocol", f"{category}.yml")
Expand All @@ -13,3 +16,25 @@ def get_protocol(category, name=None):
return protocol_dict[name]
else:
return protocol_dict

def generate_cutoff_list(protocol_name: str, element: str, pp_type: str) -> List[Tuple[int, int]]:
"""From the control protocol name, get the cutoff list
"""
match get_dual_type(pp_type, element):
case DualType.NC:
dual_type = 'nc_dual_scan'
case DualType.AUGLOW:
dual_type = 'nonnc_dual_scan'
case DualType.AUGHIGH:
dual_type = 'nonnc_high_dual_scan'

dual_scan_list = get_protocol('control', protocol_name)[dual_type]
if len(dual_scan_list) > 0:
max_dual = int(max(dual_scan_list))
else:
max_dual = 8

ecutwfc_list = get_protocol('control', protocol_name)['wfc_scan']

return [(e, e*max_dual) for e in ecutwfc_list]

12 changes: 12 additions & 0 deletions src/aiida_sssp_workflow/utils/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ class PseudoInfo(BaseModel):
# source_lib: str
# ...

class DualType(Enum):
NC = "nc"
AUGLOW = "charge augmentation low"
AUGHIGH = "charge augmentation high"

def get_dual_type(pp_type: str, element: str) -> DualType:
if element in HIGH_DUAL_ELEMENTS and pp_type != 'nc':
return DualType.AUGHIGH
elif pp_type == 'nc':
return DualType.NC
else:
return DualType.AUGLOW

def extract_pseudo_info(pseudo_text: str) -> PseudoInfo:
"""Giving a pseudo, extract the pseudo info and return as a `PseudoInfo` object"""
Expand Down
16 changes: 11 additions & 5 deletions src/aiida_sssp_workflow/workflows/convergence/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,30 @@ def get_builder(
pseudo: Union[Path, UpfData],
protocol: str,
cutoff_list: list,
configuration: str,
configuration: str | None = None,
clean_workdir: bool = True,
) -> ProcessBuilder:
"""Generate builder for the generic convergence workflow"""
builder = super().get_builder()
builder.protocol = orm.Str(protocol)

if configuration is not None:
builder.configuration = orm.Str(configuration)
configuration_name = configuration

if ret := is_valid_convergence_configuration(configuration):
raise ValueError(ret)
else:
configuration_name = "default"

# Set the default label and description
# The default label is set to be the base file name of PP
# The description include which configuration and which protocol is using.
builder.metadata.label = (
pseudo.filename if isinstance(pseudo, UpfData) else pseudo.name
)
builder.metadata.description = (
f"Run on protocol '{protocol}' and configuration '{configuration}'"
f"Run on protocol '{protocol}' and configuration '{configuration_name}'"
)

if isinstance(pseudo, Path):
Expand All @@ -259,11 +268,8 @@ def get_builder(
if ret := is_valid_cutoff_list(cutoff_list):
raise ValueError(ret)

if ret := is_valid_convergence_configuration(configuration):
raise ValueError(ret)

builder.cutoff_list = orm.List(list=cutoff_list)
builder.configuration = orm.Str(configuration)
builder.clean_workdir = orm.Bool(clean_workdir)

return builder
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_sssp_workflow/workflows/convergence/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def get_builder(
pseudo: Union[Path, UpfData],
protocol: str,
cutoff_list: list,
configuration: str,
code: orm.AbstractCode,
configuration: str | None = None,
parallelization: dict | None = None,
mpi_options: dict | None = None,
clean_workdir: bool = True, # default to clean workdir
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_sssp_workflow/workflows/convergence/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def get_builder(
pseudo: Path,
protocol: str,
cutoff_list: list,
configuration: str,
code: orm.AbstractCode,
configuration: str | None = None,
parallelization: dict | None = None,
mpi_options: dict | None = None,
clean_workdir: bool = False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def get_builder(
pseudo: Union[Path, UpfData],
protocol: str,
cutoff_list: list,
configuration: str,
code: orm.AbstractCode,
configuration: str | None = None,
bulk_parallelization: dict | None = None,
bulk_mpi_options: dict | None = None,
atom_parallelization: dict | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_sssp_workflow/workflows/convergence/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def get_builder(
pseudo: Union[Path, UpfData],
protocol: str,
cutoff_list: list,
configuration: str,
code: orm.AbstractCode,
configuration: str | None = None,
parallelization: dict | None = None,
mpi_options: dict | None = None,
clean_workdir: bool = True, # default to clean workdir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ def get_builder(
pseudo: Union[Path, UpfData],
protocol: str,
cutoff_list: list,
configuration: str,
pw_code: orm.AbstractCode,
ph_code: orm.AbstractCode,
configuration: str | None = None,
pw_parallelization: dict | None = None,
ph_settings: dict | None = None,
pw_mpi_options: dict | None = None,
ph_settings: dict | None = None,
ph_mpi_options: dict | None = None,
clean_workdir: bool = True, # default to clean workdir
) -> ProcessBuilder:
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_sssp_workflow/workflows/convergence/pressure.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def define(cls, spec):
@classmethod
def get_builder(
cls,
code: orm.AbstractCode,
pseudo: Union[Path, UpfData],
protocol: str,
cutoff_list: list,
configuration: str,
code: orm.AbstractCode,
configuration: str | None = None,
parallelization: dict | None = None,
mpi_options: dict | None = None,
clean_workdir: bool = True, # clean workdir by default
Expand Down
14 changes: 6 additions & 8 deletions src/aiida_sssp_workflow/workflows/transferability/eos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
"""Workchain to calculate delta factor of specific psp"""

from typing import Tuple
from pathlib import Path

Expand Down Expand Up @@ -41,7 +40,6 @@ class TransferabilityEOSWorkChain(_BaseMeasureWorkChain):
_NBANDS_FACTOR_FOR_LAN = 2.0

_EVALUATE_WORKCHAIN = MetricWorkChain
_DEFAULT_WAVEFUNCTION_CUTOFF = 200

@classmethod
def define(cls, spec):
Expand Down Expand Up @@ -175,7 +173,7 @@ def get_pseudos(self, configuration) -> dict:
def _setup_protocol(self):
"""unzip and parse protocol parameters to context"""
protocol = get_protocol(
category="transferability", name=self.inputs.protocol.value
category="eos", name=self.inputs.protocol.value
)
self.ctx.protocol = protocol

Expand Down Expand Up @@ -207,12 +205,12 @@ def get_builder(
) -> ProcessBuilder:
"""Return a builder to run this EOS convergence workchain"""
builder = super().get_builder(
code,
code,
pseudo,
protocol,
cutoffs,
mpi_options,
parallelization,
protocol,
cutoffs,
mpi_options,
parallelization,
clean_workdir,
)

Expand Down
Loading

0 comments on commit 8197e72

Please sign in to comment.