diff --git a/pyproject.toml b/pyproject.toml index a2cee82b..4cb067f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ aiida-sssp-workflow = "aiida_sssp_workflow.cli:cmd_root" [project.entry-points."aiida.workflows"] "sssp_workflow.measure.transferability" = "aiida_sssp_workflow.workflows.measure.transferability:EOSTransferabilityWorkChain" -"sssp_workflow.measure_bands" = "aiida_sssp_workflow.workflows.measure.bands:BandsMeasureWorkChain" +"sssp_workflow.measure.bands" = "aiida_sssp_workflow.workflows.measure.bands:BandStructureWorkChain" "sssp_workflow.convergence.caching" = "aiida_sssp_workflow.workflows.convergence.caching:_CachingConvergenceWorkChain" "sssp_workflow.convergence.eos" = "aiida_sssp_workflow.workflows.convergence.eos:ConvergenceEOSWorkChain" "sssp_workflow.convergence.cohesive_energy" = "aiida_sssp_workflow.workflows.convergence.cohesive_energy:ConvergenceCohesiveEnergyWorkChain" diff --git a/src/aiida_sssp_workflow/protocol/bands.yml b/src/aiida_sssp_workflow/protocol/bands.yml index 5434715c..b3339765 100644 --- a/src/aiida_sssp_workflow/protocol/bands.yml +++ b/src/aiida_sssp_workflow/protocol/bands.yml @@ -6,10 +6,10 @@ balanced: occupations: smearing degauss: 0.02 smearing: cold - conv_thr_per_atom: 1.0e-10 - kpoints_distance_scf: 0.2 - kpoints_distance_bands: 0.2 - kpoints_distance_band_structure: 0.025 + mixing_beta: 0.4 + conv_thr_per_atom: 1.0e-8 + kpoints_distance: 0.2 + kpoints_distance_bs: 0.025 init_nbands_factor: 5 fermi_shift: 15.0 @@ -21,10 +21,10 @@ test: occupations: smearing degauss: 0.0045 smearing: fd + mixing_beta: 0.4 conv_thr_per_atom: 1.0e-6 - kpoints_distance_scf: 0.25 - kpoints_distance_bands: 0.25 - kpoints_distance_band_structure: 0.10 + kpoints_distance: 0.25 + kpoints_distance_bs: 0.10 init_nbands_factor: 3 fermi_shift: 10.0 diff --git a/src/aiida_sssp_workflow/statics/generatemapping.py b/src/aiida_sssp_workflow/statics/generatemapping.py index bf4b906b..b0acd628 100644 --- a/src/aiida_sssp_workflow/statics/generatemapping.py +++ b/src/aiida_sssp_workflow/statics/generatemapping.py @@ -8,103 +8,12 @@ ACTINIDE_ELEMENTS, LANTHANIDE_ELEMENTS, NO_GS_CONF_ELEMENTS, + ALL_ELEMENTS, + UNSUPPORTED_ELEMENTS, ) MAPPING_FILE_PATH = Path(__file__).parent / "structures" / "mapping.json" -ALL_ELEMENTS = ( - [ - "Ag", - "Al", - "Ar", - "As", - "At", - "Au", - "Ba", - "Be", - "B", - "Bi", - "Br", - "Ca", - "Cd", - "Ce", - "C", - "Cl", - "Co", - "Cr", - "Cs", - "Cu", - "Dy", - "Er", - "Eu", - "Fe", - "F", - "Ga", - "Gd", - "Ge", - "He", - "Hf", - "H", - "Hg", - "Ho", - "I", - "In", - "Ir", - "K", - "Kr", - "La", - "Li", - "Lu", - "Mg", - "Mn", - "Mo", - "Na", - "Nb", - "Nd", - "Ne", - "N", - "Ni", - "O", - "Os", - "Pb", - "Pd", - "P", - "Pm", - "Po", - "Pr", - "Pt", - "Rb", - "Re", - "Rh", - "Rn", - "Ru", - "Sb", - "Sc", - "Se", - "S", - "Si", - "Sm", - "Sn", - "Sr", - "Ta", - "Tb", - "Tc", - "Te", - "Ti", - "Tl", - "Tm", - "V", - "W", - "Xe", - "Yb", - "Y", - "Zn", - "Zr", - ] - + ACTINIDE_ELEMENTS - + NO_GS_CONF_ELEMENTS -) - def run(): d = {} @@ -115,7 +24,10 @@ def run(): # We use DC (Diamond Cubic) for the convergence test for all elements. # Because it usually give the lagrest cutoff energy, which is the most strict test. for e in ALL_ELEMENTS: - if e in NO_GS_CONF_ELEMENTS: + if e in UNSUPPORTED_ELEMENTS: + band_configuration = "N/A" + convergence_configuration = "N/A" + elif e in NO_GS_CONF_ELEMENTS: # we don't have At in typical band_configuration = "DC" convergence_configuration = "DC" diff --git a/src/aiida_sssp_workflow/statics/structures/mapping.json b/src/aiida_sssp_workflow/statics/structures/mapping.json index e73d77fc..10aac52f 100644 --- a/src/aiida_sssp_workflow/statics/structures/mapping.json +++ b/src/aiida_sssp_workflow/statics/structures/mapping.json @@ -1,62 +1,66 @@ { "_commont": "The mapping from element to configuration for convergence and bands verification.", - "Ag": { + "H": { "band": "GS", "convergence": "DC" }, - "Al": { + "He": { "band": "GS", "convergence": "DC" }, - "Ar": { + "Li": { "band": "GS", "convergence": "DC" }, - "As": { + "Be": { "band": "GS", "convergence": "DC" }, - "At": { - "band": "DC", + "B": { + "band": "GS", "convergence": "DC" }, - "Au": { + "C": { "band": "GS", "convergence": "DC" }, - "Ba": { + "N": { "band": "GS", "convergence": "DC" }, - "Be": { + "O": { "band": "GS", "convergence": "DC" }, - "B": { + "F": { "band": "GS", "convergence": "DC" }, - "Bi": { + "Ne": { "band": "GS", "convergence": "DC" }, - "Br": { + "Na": { "band": "GS", "convergence": "DC" }, - "Ca": { + "Mg": { "band": "GS", "convergence": "DC" }, - "Cd": { + "Al": { "band": "GS", "convergence": "DC" }, - "Ce": { - "band": "LAN", + "Si": { + "band": "GS", "convergence": "DC" }, - "C": { + "P": { + "band": "GS", + "convergence": "DC" + }, + "S": { "band": "GS", "convergence": "DC" }, @@ -64,284 +68,288 @@ "band": "GS", "convergence": "DC" }, - "Co": { + "Ar": { "band": "GS", "convergence": "DC" }, - "Cr": { + "K": { "band": "GS", "convergence": "DC" }, - "Cs": { + "Ca": { "band": "GS", "convergence": "DC" }, - "Cu": { + "Sc": { "band": "GS", "convergence": "DC" }, - "Dy": { - "band": "LAN", + "Ti": { + "band": "GS", "convergence": "DC" }, - "Er": { - "band": "LAN", + "V": { + "band": "GS", "convergence": "DC" }, - "Eu": { - "band": "LAN", + "Cr": { + "band": "GS", "convergence": "DC" }, - "Fe": { + "Mn": { "band": "GS", "convergence": "DC" }, - "F": { + "Fe": { "band": "GS", "convergence": "DC" }, - "Ga": { + "Co": { "band": "GS", "convergence": "DC" }, - "Gd": { - "band": "LAN", + "Ni": { + "band": "GS", "convergence": "DC" }, - "Ge": { + "Cu": { "band": "GS", "convergence": "DC" }, - "He": { + "Zn": { "band": "GS", "convergence": "DC" }, - "Hf": { + "Ga": { "band": "GS", "convergence": "DC" }, - "H": { + "Ge": { "band": "GS", "convergence": "DC" }, - "Hg": { + "As": { "band": "GS", "convergence": "DC" }, - "Ho": { - "band": "LAN", + "Se": { + "band": "GS", "convergence": "DC" }, - "I": { + "Br": { "band": "GS", "convergence": "DC" }, - "In": { + "Kr": { "band": "GS", "convergence": "DC" }, - "Ir": { + "Rb": { "band": "GS", "convergence": "DC" }, - "K": { + "Sr": { "band": "GS", "convergence": "DC" }, - "Kr": { + "Y": { "band": "GS", "convergence": "DC" }, - "La": { - "band": "LAN", + "Zr": { + "band": "GS", "convergence": "DC" }, - "Li": { + "Nb": { "band": "GS", "convergence": "DC" }, - "Lu": { - "band": "LAN", + "Mo": { + "band": "GS", "convergence": "DC" }, - "Mg": { + "Tc": { "band": "GS", "convergence": "DC" }, - "Mn": { + "Ru": { "band": "GS", "convergence": "DC" }, - "Mo": { + "Rh": { "band": "GS", "convergence": "DC" }, - "Na": { + "Pd": { "band": "GS", "convergence": "DC" }, - "Nb": { + "Ag": { "band": "GS", "convergence": "DC" }, - "Nd": { - "band": "LAN", + "Cd": { + "band": "GS", "convergence": "DC" }, - "Ne": { + "In": { "band": "GS", "convergence": "DC" }, - "N": { + "Sn": { "band": "GS", "convergence": "DC" }, - "Ni": { + "Sb": { "band": "GS", "convergence": "DC" }, - "O": { + "Te": { "band": "GS", "convergence": "DC" }, - "Os": { + "I": { "band": "GS", "convergence": "DC" }, - "Pb": { + "Xe": { "band": "GS", "convergence": "DC" }, - "Pd": { + "Cs": { "band": "GS", "convergence": "DC" }, - "P": { + "Ba": { "band": "GS", "convergence": "DC" }, - "Pm": { + "La": { "band": "LAN", "convergence": "DC" }, - "Po": { - "band": "GS", + "Ce": { + "band": "LAN", "convergence": "DC" }, "Pr": { "band": "LAN", "convergence": "DC" }, - "Pt": { - "band": "GS", + "Nd": { + "band": "LAN", "convergence": "DC" }, - "Rb": { - "band": "GS", + "Pm": { + "band": "LAN", "convergence": "DC" }, - "Re": { - "band": "GS", + "Sm": { + "band": "LAN", "convergence": "DC" }, - "Rh": { - "band": "GS", + "Eu": { + "band": "LAN", "convergence": "DC" }, - "Rn": { - "band": "GS", + "Gd": { + "band": "LAN", "convergence": "DC" }, - "Ru": { - "band": "GS", + "Tb": { + "band": "LAN", "convergence": "DC" }, - "Sb": { - "band": "GS", + "Dy": { + "band": "LAN", "convergence": "DC" }, - "Sc": { - "band": "GS", + "Ho": { + "band": "LAN", "convergence": "DC" }, - "Se": { - "band": "GS", + "Er": { + "band": "LAN", "convergence": "DC" }, - "S": { - "band": "GS", + "Tm": { + "band": "LAN", "convergence": "DC" }, - "Si": { - "band": "GS", + "Yb": { + "band": "LAN", "convergence": "DC" }, - "Sm": { + "Lu": { "band": "LAN", "convergence": "DC" }, - "Sn": { + "Hf": { "band": "GS", "convergence": "DC" }, - "Sr": { + "Ta": { "band": "GS", "convergence": "DC" }, - "Ta": { + "W": { "band": "GS", "convergence": "DC" }, - "Tb": { - "band": "LAN", + "Re": { + "band": "GS", "convergence": "DC" }, - "Tc": { + "Os": { "band": "GS", "convergence": "DC" }, - "Te": { + "Ir": { "band": "GS", "convergence": "DC" }, - "Ti": { + "Pt": { "band": "GS", "convergence": "DC" }, - "Tl": { + "Au": { "band": "GS", "convergence": "DC" }, - "Tm": { - "band": "LAN", + "Hg": { + "band": "GS", "convergence": "DC" }, - "V": { + "Tl": { "band": "GS", "convergence": "DC" }, - "W": { + "Pb": { "band": "GS", "convergence": "DC" }, - "Xe": { + "Bi": { "band": "GS", "convergence": "DC" }, - "Yb": { - "band": "LAN", + "Po": { + "band": "GS", "convergence": "DC" }, - "Y": { - "band": "GS", + "At": { + "band": "DC", "convergence": "DC" }, - "Zn": { + "Rn": { "band": "GS", "convergence": "DC" }, - "Zr": { - "band": "GS", + "Fr": { + "band": "DC", + "convergence": "DC" + }, + "Ra": { + "band": "DC", "convergence": "DC" }, "Ac": { @@ -377,39 +385,51 @@ "convergence": "DC" }, "Bk": { - "band": "DC", - "convergence": "DC" + "band": "N/A", + "convergence": "N/A" }, "Cf": { - "band": "DC", - "convergence": "DC" + "band": "N/A", + "convergence": "N/A" }, "Es": { - "band": "DC", - "convergence": "DC" + "band": "N/A", + "convergence": "N/A" }, "Fm": { - "band": "DC", - "convergence": "DC" + "band": "N/A", + "convergence": "N/A" }, "Md": { - "band": "DC", - "convergence": "DC" + "band": "N/A", + "convergence": "N/A" }, "No": { - "band": "DC", - "convergence": "DC" + "band": "N/A", + "convergence": "N/A" }, "Lr": { - "band": "DC", - "convergence": "DC" + "band": "N/A", + "convergence": "N/A" }, - "Fr": { - "band": "DC", - "convergence": "DC" + "Rf": { + "band": "N/A", + "convergence": "N/A" }, - "Ra": { - "band": "DC", - "convergence": "DC" + "Db": { + "band": "N/A", + "convergence": "N/A" + }, + "Sg": { + "band": "N/A", + "convergence": "N/A" + }, + "Bh": { + "band": "N/A", + "convergence": "N/A" + }, + "Hs": { + "band": "N/A", + "convergence": "N/A" } } diff --git a/src/aiida_sssp_workflow/utils/element.py b/src/aiida_sssp_workflow/utils/element.py index 60bec342..2528f8df 100644 --- a/src/aiida_sssp_workflow/utils/element.py +++ b/src/aiida_sssp_workflow/utils/element.py @@ -66,6 +66,7 @@ HIGH_DUAL_ELEMENTS = ["O", "Fe", "Mn", "Hf", "Co", "Ni", "Cr"] # The element that not has unaries structure from nat.phys.rev (maxium to Z=108) +# TODO: should called AE_REF_DATA_UNSUPPORTED_ELEMENTS, also the elements that we don't have structure data UNSUPPORTED_ELEMENTS = [ "Bk", "Cf", @@ -80,3 +81,115 @@ "Bh", "Hs", ] + +# All elements that coverd (in mapping.json 1-108) +ALL_ELEMENTS = [ + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", +] diff --git a/src/aiida_sssp_workflow/workflows/convergence/bands.py b/src/aiida_sssp_workflow/workflows/convergence/bands.py index 04a23627..af342a0f 100644 --- a/src/aiida_sssp_workflow/workflows/convergence/bands.py +++ b/src/aiida_sssp_workflow/workflows/convergence/bands.py @@ -142,7 +142,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho): # Generic builder.kpoints_distance_bands = orm.Float(protocol["kpoints_distance"]) - builder.init_nbands_factor = orm.Float(protocol["init_nbands_factor"]) + builder.init_nbands_factor = orm.Int(protocol["init_nbands_factor"]) builder.fermi_shift = orm.Float(protocol["fermi_shift"]) builder.run_band_structure = orm.Bool(False) diff --git a/src/aiida_sssp_workflow/workflows/evaluate/_bands.py b/src/aiida_sssp_workflow/workflows/evaluate/_bands.py index f6210bcc..6a657e06 100644 --- a/src/aiida_sssp_workflow/workflows/evaluate/_bands.py +++ b/src/aiida_sssp_workflow/workflows/evaluate/_bands.py @@ -66,7 +66,7 @@ class BandsWorkChain(_BaseEvaluateWorkChain): # to prevent the infinite loop in bands evaluation # If start from 1.5, and increase 2.0 every time, it will reach 11.5 at most. _MAX_INCREMENT_BANDS_FACTOR = 10 - _BANDS_FACTOR_INCREASE_STEP = 2.0 + _BANDS_FACTOR_INCREASE_STEP = 2 @classmethod def define(cls, spec): @@ -76,7 +76,7 @@ def define(cls, spec): spec.expose_inputs(PwBandsWorkChain, include=['scf', 'bands', 'structure']) spec.input('kpoints_distance_bands', valid_type=orm.Float, help='Kpoints distance setting for bulk energy bands calculation.') - spec.input('init_nbands_factor', valid_type=orm.Float, default=lambda: orm.Float(1.5), + spec.input('init_nbands_factor', valid_type=orm.Int, default=lambda: orm.Int(2), help='initial nbands factor.') spec.input('fermi_shift', valid_type=orm.Float, default=lambda: orm.Float(10.0), help='The uplimit of energy to check the bands diff, control the number of bands.') @@ -117,7 +117,7 @@ def define(cls, spec): def setup(self): """Input validation""" # set initial lowest highest bands eigenvalue - fermi_energy equals to 0.0 - self.ctx.nbands_factor = self.inputs.init_nbands_factor.value + self.ctx.nbands_factor = int(self.inputs.init_nbands_factor.value) # For qe PwBandsWorkChain if `bands_kpoints` not set, the seekpath will run # to give a seekpath along the recommonded path. @@ -141,6 +141,9 @@ def should_run_bands(self) -> bool: def run_bands(self): """run bands calculation""" inputs = self.exposed_inputs(PwBandsWorkChain) + inputs["metadata"] = { + "call_link_label": f"bands_with_factor_{self.ctx.nbands_factor}" + } inputs["nbands_factor"] = orm.Float(self.ctx.nbands_factor) inputs["bands_kpoints"] = self.ctx.bands_kpoints @@ -215,6 +218,7 @@ def should_run_band_structure(self): def run_band_structure(self): """run band structure calculation""" inputs = self.exposed_inputs(PwBandsWorkChain) + inputs["metadata"] = {"call_link_label": "band_structure"} inputs["nbands_factor"] = orm.Float(self.ctx.nbands_factor) inputs["bands_kpoints_distance"] = orm.Float( self.inputs.kpoints_distance_band_structure diff --git a/src/aiida_sssp_workflow/workflows/measure/__init__.py b/src/aiida_sssp_workflow/workflows/measure/__init__.py index 321257b7..c5d4fa4b 100644 --- a/src/aiida_sssp_workflow/workflows/measure/__init__.py +++ b/src/aiida_sssp_workflow/workflows/measure/__init__.py @@ -23,12 +23,8 @@ def define(cls, spec): help='The `pw.x` code use for the `PwCalculation`.') spec.input('pseudo', valid_type=UpfData, required=True, help='Pseudopotential to be verified') - spec.input('oxygen_pseudo', valid_type=UpfData, required=True) - spec.input('oxygen_ecutwfc', valid_type=orm.Float, required=True) - spec.input('oxygen_ecutrho', valid_type=orm.Float, required=True) spec.input('protocol', valid_type=orm.Str, required=True, help='The protocol which define input calculation parameters.') - spec.input('configurations', valid_type=orm.List, required=False) spec.input('wavefunction_cutoff', valid_type=orm.Float, required=True, help='The wavefunction cutoff.') spec.input('charge_density_cutoff', valid_type=orm.Float, required=True, help='The charge density cutoff.') spec.input( diff --git a/src/aiida_sssp_workflow/workflows/measure/bands copy.py b/src/aiida_sssp_workflow/workflows/measure/bands copy.py new file mode 100644 index 00000000..f389c06f --- /dev/null +++ b/src/aiida_sssp_workflow/workflows/measure/bands copy.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- +""" +Bands distance of many input pseudos +""" + +from aiida import orm +from aiida.engine import ToContext, if_ +from aiida.plugins import DataFactory + +from aiida_sssp_workflow.utils import ( + LANTHANIDE_ELEMENTS, + MAGNETIC_ELEMENTS, + NONMETAL_ELEMENTS, + get_magnetic_inputs, + get_protocol, + get_standard_structure, + reset_pseudos_for_magnetic, + update_dict, +) +from aiida_sssp_workflow.workflows.common import ( + get_extra_parameters_for_lanthanides, + get_pseudo_N, +) +from aiida_sssp_workflow.workflows.evaluate._bands import BandsWorkChain +from aiida_sssp_workflow.workflows.measure import _BaseMeasureWorkChain + +UpfData = DataFactory("pseudo.upf") + + +def validate_input_pseudos(d_pseudos, _): + """Validate that all input pseudos map to same element""" + element = set(pseudo.element for pseudo in d_pseudos.values()) + + if len(element) > 1: + return f"The pseudos corespond to different elements {element}." + + +class BandsMeasureWorkChain(_BaseMeasureWorkChain): + """ + WorkChain to run bands measure, + run without sym for distance compare and band structure along the path + """ + + @classmethod + def define(cls, spec): + """Define the process specification.""" + # yapf: disable + super().define(spec) + + spec.outline( + cls.setup, + if_(cls.is_magnetic_element)( + cls.extra_setup_for_magnetic_element, + ), + if_(cls.is_lanthanide_element)( + cls.extra_setup_for_lanthanide_element, ), + cls.setup_pw_parameters_from_protocol, + cls.run_bands_evaluation, + cls.finalize, + ) + + spec.expose_outputs(BandsWorkChain) + + def setup(self): + """ + This step contains all preparation before actaul setup, e.g. set + the context of element, base_structure, base pw_parameters and pseudos. + """ + # parse pseudo and output its header information + self.ctx.pw_parameters = {} + + element = self.ctx.element = self.inputs.pseudo.element + self.ctx.pseudos = {self.ctx.element: self.inputs.pseudo} + + self.ctx.structure = get_standard_structure( + element, + prop="bands", + ) + + # extra setting for bands convergence + self.ctx.is_metal = element not in NONMETAL_ELEMENTS + + # set up the ecutwfc and ecutrho + self.ctx.ecutwfc = self.inputs.wavefunction_cutoff.value + self.ctx.ecutrho = self.inputs.charge_density_cutoff.value + + def is_magnetic_element(self): + """Check if the element is magnetic""" + return self.ctx.element in MAGNETIC_ELEMENTS + + def extra_setup_for_magnetic_element(self): + """Extra setup for magnetic element""" + self.ctx.structure, magnetic_extra_parameters = get_magnetic_inputs( + self.ctx.structure + ) + self.ctx.pw_parameters = update_dict( + self.ctx.pw_parameters, magnetic_extra_parameters + ) + + # override pseudos setting + # required for O, Mn, Cr where the kind names varies for sites + self.ctx.pseudos = reset_pseudos_for_magnetic( + self.inputs.pseudo, self.ctx.structure + ) + + def is_lanthanide_element(self): + """Check if the element is rare earth""" + return self.ctx.element in LANTHANIDE_ELEMENTS + + def extra_setup_for_lanthanide_element(self): + """Extra setup for rare earth element""" + nbnd_factor = 2.0 + pseudo_N = get_pseudo_N() + self.ctx.pseudos["N"] = pseudo_N + pseudo_RE = self.inputs.pseudo + nbnd = nbnd_factor * (pseudo_N.z_valence + pseudo_RE.z_valence) + pw_parameters = get_extra_parameters_for_lanthanides(self.ctx.element, nbnd) + + self.ctx.pw_parameters = update_dict(self.ctx.pw_parameters, pw_parameters) + + def setup_pw_parameters_from_protocol(self): + """Input validation""" + # pylint: disable=invalid-name, attribute-defined-outside-init + + # Read from protocol if parameters not set from inputs + protocol = get_protocol(category="bands", name=self.inputs.protocol.value) + self._DEGAUSS = protocol["degauss"] + self._OCCUPATIONS = protocol["occupations"] + self._SMEARING = protocol["smearing"] + self._CONV_THR_PER_ATOM = protocol["conv_thr_per_atom"] + + self._INIT_NBANDS_FACTOR = protocol["init_nbands_factor"] + self._FERMI_SHIFT = protocol["fermi_shift"] + + self.ctx.init_nbands_factor = self._INIT_NBANDS_FACTOR + self.ctx.fermi_shift = self._FERMI_SHIFT + + self.ctx.kpoints_distance_scf = protocol["kpoints_distance_scf"] + self.ctx.kpoints_distance_bands = protocol["kpoints_distance_bands"] + self.ctx.kpoints_distance_band_structure = protocol[ + "kpoints_distance_band_structure" + ] + + parameters = { + "SYSTEM": { + "degauss": self._DEGAUSS, + "occupations": self._OCCUPATIONS, + "smearing": self._SMEARING, + }, + "ELECTRONS": { + "conv_thr": self._CONV_THR_PER_ATOM, + }, + "CONTROL": { + "calculation": "scf", + }, + } + + self.ctx.pw_parameters = update_dict(self.ctx.pw_parameters, parameters) + + self.logger.info( + f"The pw parameters for convergence is: {self.ctx.pw_parameters}" + ) + + def _get_inputs(self, pseudos): + """ + get inputs for the bands evaluation with given pseudo + """ + ecutwfc, ecutrho = self._get_pw_cutoff( + self.ctx.structure, self.ctx.ecutwfc, self.ctx.ecutrho + ) + + parameters = { + "SYSTEM": { + "ecutwfc": round(ecutwfc, 1), + "ecutrho": round(ecutrho, 1), + }, + } + parameters = update_dict(parameters, self.ctx.pw_parameters) + + parameters_bands = update_dict(parameters, {}) + parameters_bands["SYSTEM"].pop("nbnd", None) + parameters_bands["CONTROL"]["calculation"] = "bands" + + inputs = { + "structure": self.ctx.structure, + "scf": { + "pw": { + "code": self.inputs.code, + "pseudos": pseudos, + "parameters": orm.Dict(dict=parameters), + "metadata": { + "options": self.inputs.options.get_dict(), + }, + "parallelization": self.inputs.parallelization, + }, + "kpoints_distance": orm.Float(self.ctx.kpoints_distance_scf), + }, + "bands": { + "pw": { + "code": self.inputs.code, + "pseudos": pseudos, + "parameters": orm.Dict(dict=parameters_bands), + "metadata": { + "options": self.inputs.options.get_dict(), + }, + "parallelization": self.inputs.parallelization, + }, + }, + "kpoints_distance_bands": orm.Float(self.ctx.kpoints_distance_bands), + "init_nbands_factor": orm.Float(self.ctx.init_nbands_factor), + "fermi_shift": orm.Float(self.ctx.fermi_shift), + "run_band_structure": orm.Bool(True), + "kpoints_distance_band_structure": orm.Float( + self.ctx.kpoints_distance_band_structure + ), + "clean_workdir": self.inputs.clean_workdir, + } + + return inputs + + def run_bands_evaluation(self): + """run bands evaluation of psp in inputs list""" + inputs = self._get_inputs(self.ctx.pseudos) + + running = self.submit(BandsWorkChain, **inputs) + + self.report(f"launching BandsWorkChain<{running.pk}>") + + return ToContext(bands=running) + + def finalize(self): + """inspect bands run results""" + self.out_many(self.exposed_outputs(self.ctx.bands, BandsWorkChain)) diff --git a/src/aiida_sssp_workflow/workflows/measure/bands.py b/src/aiida_sssp_workflow/workflows/measure/bands.py index f389c06f..f045846e 100644 --- a/src/aiida_sssp_workflow/workflows/measure/bands.py +++ b/src/aiida_sssp_workflow/workflows/measure/bands.py @@ -3,26 +3,27 @@ Bands distance of many input pseudos """ +from pathlib import Path + from aiida import orm -from aiida.engine import ToContext, if_ from aiida.plugins import DataFactory +from aiida.engine import ToContext, ProcessBuilder from aiida_sssp_workflow.utils import ( - LANTHANIDE_ELEMENTS, - MAGNETIC_ELEMENTS, - NONMETAL_ELEMENTS, - get_magnetic_inputs, get_protocol, get_standard_structure, - reset_pseudos_for_magnetic, - update_dict, ) -from aiida_sssp_workflow.workflows.common import ( - get_extra_parameters_for_lanthanides, - get_pseudo_N, +from aiida_sssp_workflow.utils import get_default_mpi_options, extract_pseudo_info +from aiida_sssp_workflow.utils.structure import ( + UNARIE_CONFIGURATIONS, + get_default_configuration, +) +from aiida_sssp_workflow.utils.element import UNSUPPORTED_ELEMENTS +from aiida_sssp_workflow.workflows.evaluate._bands import ( + BandsWorkChain as EvaluateBandsWorkChain, ) -from aiida_sssp_workflow.workflows.evaluate._bands import BandsWorkChain from aiida_sssp_workflow.workflows.measure import _BaseMeasureWorkChain +from aiida_sssp_workflow.workflows.measure.report import BandStructureReport UpfData = DataFactory("pseudo.upf") @@ -35,199 +36,315 @@ def validate_input_pseudos(d_pseudos, _): return f"The pseudos corespond to different elements {element}." -class BandsMeasureWorkChain(_BaseMeasureWorkChain): - """ - WorkChain to run bands measure, +def is_valid_convergence_configuration(value, _=None): + """Check if the configuration is valid""" + # XXX: I am duplicate from _base.py, combine us + valid_configurations = UNARIE_CONFIGURATIONS + if value not in valid_configurations: + return f"Configuration {value} is not valid. Valid configurations are {valid_configurations}" + + +class BandStructureWorkChain(_BaseMeasureWorkChain): + """WorkChain to run bands measure, run without sym for distance compare and band structure along the path """ + _EVALUATE_WORKCHAIN = EvaluateBandsWorkChain + @classmethod def define(cls, spec): """Define the process specification.""" # yapf: disable super().define(spec) + spec.input( + "configuration", + valid_type=orm.Str, + required=False, + validator=is_valid_convergence_configuration, + help="The configuration to use for the workchain, can be DC/BCC/FCC/SC.", + ) spec.outline( - cls.setup, - if_(cls.is_magnetic_element)( - cls.extra_setup_for_magnetic_element, - ), - if_(cls.is_lanthanide_element)( - cls.extra_setup_for_lanthanide_element, ), - cls.setup_pw_parameters_from_protocol, - cls.run_bands_evaluation, - cls.finalize, + cls._setup_pseudos, + cls._setup_protocol, + cls._setup_structure, + cls.run_bands, + cls.inspect_bands, + cls._finalize, ) - spec.expose_outputs(BandsWorkChain) - - def setup(self): - """ - This step contains all preparation before actaul setup, e.g. set - the context of element, base_structure, base pw_parameters and pseudos. - """ - # parse pseudo and output its header information - self.ctx.pw_parameters = {} + spec.expose_outputs(EvaluateBandsWorkChain) + spec.output( + "configuration", + valid_type=orm.Str, + required=True, + help="The configuration used for the convergence.", + ) + spec.output( + "structure", + valid_type=orm.StructureData, + required=True, + help="The structure used for the convergence.", + ) + spec.output( # XXX: I am same as transferibility workchain. combine me. + "report", + valid_type=orm.Dict, + required=True, + help="The output report of convergence verification, it is a dict contains the full information of convergence test, the mapping of cutoffs to the UUID of the evaluation workchain etc.", + ) - element = self.ctx.element = self.inputs.pseudo.element + def _setup_pseudos(self): + """Setup pseudos""" + # XXX: this is same as trans WF, consider combine + pseudo_info = extract_pseudo_info( + self.inputs.pseudo.get_content(), + ) + self.ctx.element = pseudo_info.element self.ctx.pseudos = {self.ctx.element: self.inputs.pseudo} + @property + def element(self): + """Syntax sugar for self.ctx.element""" + # TODO: same as from convergence/_base, consider combine + if "element" not in self.ctx: + raise AttributeError( + "element is not set in the context, your step must after _setup_pseudos" + ) + + return self.ctx.element + + @property + def pseudos(self): + """Syntax sugar for self.ctx.pseudos""" + # TODO: same as from convergence/_base, consider combine + if "pseudos" not in self.ctx: + raise AttributeError( + "pseudos is not set in the context, your step must after _setup_pseudos" + ) + + return self.ctx.pseudos + + def _setup_protocol(self): + """unzip and parse protocol parameters to context""" + # XXX: this is same as trans WF, consider combine + protocol = get_protocol(category="bands", name=self.inputs.protocol.value) + self.ctx.protocol = protocol + + @property + def protocol(self): + """Syntax sugar for self.ctx.protocol""" + # XXX: this is same as trans WF, consider combine + if "protocol" not in self.ctx: + raise AttributeError( + "protocol is not set in the context, your step must after _setup_protocol" + ) + + return self.ctx.protocol + + def _setup_structure(self): + """Set up the configuration to be run for band structure""" + # We only has unaries from Z=1 (hydrogen) to Z=96 (curium), so raise an exception + # if larger Z elements e.g from Bk comes to valified + if self.element in UNSUPPORTED_ELEMENTS: + return self.exit_codes.ERROR_UNSUPPORTED_ELEMENT.format(self.element) + + if "configuration" in self.inputs: + configuration = self.inputs.configuration + else: + # will use the default configuration set in the protocol (mapping.json) + configuration = get_default_configuration( + self.ctx.element, property="bands" + ) + self.ctx.structure = get_standard_structure( - element, - prop="bands", + self.ctx.element, configuration=configuration ) - # extra setting for bands convergence - self.ctx.is_metal = element not in NONMETAL_ELEMENTS - - # set up the ecutwfc and ecutrho - self.ctx.ecutwfc = self.inputs.wavefunction_cutoff.value - self.ctx.ecutrho = self.inputs.charge_density_cutoff.value + self.out("configuration", configuration) + self.out("structure", self.ctx.structure) - def is_magnetic_element(self): - """Check if the element is magnetic""" - return self.ctx.element in MAGNETIC_ELEMENTS + @property + def structure(self): + """Syntax sugar to get self structure""" + return self.ctx.structure - def extra_setup_for_magnetic_element(self): - """Extra setup for magnetic element""" - self.ctx.structure, magnetic_extra_parameters = get_magnetic_inputs( - self.ctx.structure - ) - self.ctx.pw_parameters = update_dict( - self.ctx.pw_parameters, magnetic_extra_parameters + @classmethod + def get_builder( + cls, + pseudo: Path | UpfData, + protocol: str, + wavefunction_cutoff: float, + charge_density_cutoff: float, + code: orm.AbstractCode, + configuration: list | None = None, + parallelization: dict | None = None, + mpi_options: dict | None = None, + clean_workdir: bool = True, # default to clean workdir + ) -> ProcessBuilder: + """Return a builder to run this EOS convergence workchain""" + builder = super().get_builder() + builder.protocol = orm.Str(protocol) + if isinstance(pseudo, Path): + builder.pseudo = UpfData.get_or_create(pseudo) + else: + builder.pseudo = pseudo + builder.wavefunction_cutoff = orm.Float(wavefunction_cutoff) + builder.charge_density_cutoff = orm.Float(charge_density_cutoff) + builder.code = code + + if configuration is not None: + builder.configuration = orm.Str(configuration) + + # Set the default label and description + # The default label is set to be the base file name of PP + # The description include which configuration and which protocol is using. + builder.metadata.call_link_label = "band_structure_verification" + builder.metadata.label = ( + pseudo.filename if isinstance(pseudo, UpfData) else pseudo.name ) - - # override pseudos setting - # required for O, Mn, Cr where the kind names varies for sites - self.ctx.pseudos = reset_pseudos_for_magnetic( - self.inputs.pseudo, self.ctx.structure + builder.metadata.description = ( + f"""Run on protocol '{protocol}' | configuration '{configuration if configuration is not None else "default"}' | """ + f" base (ecutwfc, ecutrho) = ({wavefunction_cutoff}, {charge_density_cutoff})" ) + builder.clean_workdir = orm.Bool(clean_workdir) + builder.code = code - def is_lanthanide_element(self): - """Check if the element is rare earth""" - return self.ctx.element in LANTHANIDE_ELEMENTS + if parallelization: + builder.parallelization = orm.Dict(parallelization) + else: + builder.parallelization = orm.Dict() - def extra_setup_for_lanthanide_element(self): - """Extra setup for rare earth element""" - nbnd_factor = 2.0 - pseudo_N = get_pseudo_N() - self.ctx.pseudos["N"] = pseudo_N - pseudo_RE = self.inputs.pseudo - nbnd = nbnd_factor * (pseudo_N.z_valence + pseudo_RE.z_valence) - pw_parameters = get_extra_parameters_for_lanthanides(self.ctx.element, nbnd) + if mpi_options: + builder.mpi_options = orm.Dict(mpi_options) + else: + builder.mpi_options = orm.Dict(get_default_mpi_options()) - self.ctx.pw_parameters = update_dict(self.ctx.pw_parameters, pw_parameters) - - def setup_pw_parameters_from_protocol(self): - """Input validation""" - # pylint: disable=invalid-name, attribute-defined-outside-init + return builder + def prepare_evaluate_builder(self): + """Get inputs for the bands evaluation with given pseudo""" # Read from protocol if parameters not set from inputs - protocol = get_protocol(category="bands", name=self.inputs.protocol.value) - self._DEGAUSS = protocol["degauss"] - self._OCCUPATIONS = protocol["occupations"] - self._SMEARING = protocol["smearing"] - self._CONV_THR_PER_ATOM = protocol["conv_thr_per_atom"] + ecutwfc = self.inputs.wavefunction_cutoff.value + ecutrho = self.inputs.charge_density_cutoff.value + + protocol = self.protocol - self._INIT_NBANDS_FACTOR = protocol["init_nbands_factor"] - self._FERMI_SHIFT = protocol["fermi_shift"] + builder = self._EVALUATE_WORKCHAIN.get_builder() - self.ctx.init_nbands_factor = self._INIT_NBANDS_FACTOR - self.ctx.fermi_shift = self._FERMI_SHIFT + builder.clean_workdir = ( + self.inputs.clean_workdir + ) # sync with the main workchain - self.ctx.kpoints_distance_scf = protocol["kpoints_distance_scf"] - self.ctx.kpoints_distance_bands = protocol["kpoints_distance_bands"] - self.ctx.kpoints_distance_band_structure = protocol[ - "kpoints_distance_band_structure" - ] + builder.structure = self.structure + natoms = len(self.structure.sites) - parameters = { + scf_pw_parameters = { + "CONTROL": { + "calculation": "scf", + }, "SYSTEM": { - "degauss": self._DEGAUSS, - "occupations": self._OCCUPATIONS, - "smearing": self._SMEARING, + "degauss": protocol["degauss"], + "occupations": protocol["occupations"], + "smearing": protocol["smearing"], + "ecutwfc": round(ecutwfc), + "ecutrho": round(ecutrho), }, "ELECTRONS": { - "conv_thr": self._CONV_THR_PER_ATOM, + "conv_thr": protocol["conv_thr_per_atom"] * natoms, + "mixing_beta": protocol["mixing_beta"], }, + } + builder.scf.pw["code"] = self.inputs.code + builder.scf.pw["pseudos"] = self.pseudos + builder.scf.pw["parameters"] = orm.Dict(scf_pw_parameters) + builder.scf.pw["parallelization"] = self.inputs.parallelization + builder.scf.pw["metadata"]["options"] = self.inputs.mpi_options.get_dict() + builder.scf.kpoints_distance = orm.Float(protocol["kpoints_distance"]) + + bands_pw_parameters = { "CONTROL": { - "calculation": "scf", + "calculation": "bands", }, - } - - self.ctx.pw_parameters = update_dict(self.ctx.pw_parameters, parameters) - - self.logger.info( - f"The pw parameters for convergence is: {self.ctx.pw_parameters}" - ) - - def _get_inputs(self, pseudos): - """ - get inputs for the bands evaluation with given pseudo - """ - ecutwfc, ecutrho = self._get_pw_cutoff( - self.ctx.structure, self.ctx.ecutwfc, self.ctx.ecutrho - ) - - parameters = { "SYSTEM": { - "ecutwfc": round(ecutwfc, 1), - "ecutrho": round(ecutrho, 1), + "degauss": protocol["degauss"], + "occupations": protocol["occupations"], + "smearing": protocol["smearing"], + "ecutwfc": round(ecutwfc), + "ecutrho": round(ecutrho), }, - } - parameters = update_dict(parameters, self.ctx.pw_parameters) - - parameters_bands = update_dict(parameters, {}) - parameters_bands["SYSTEM"].pop("nbnd", None) - parameters_bands["CONTROL"]["calculation"] = "bands" - - inputs = { - "structure": self.ctx.structure, - "scf": { - "pw": { - "code": self.inputs.code, - "pseudos": pseudos, - "parameters": orm.Dict(dict=parameters), - "metadata": { - "options": self.inputs.options.get_dict(), - }, - "parallelization": self.inputs.parallelization, - }, - "kpoints_distance": orm.Float(self.ctx.kpoints_distance_scf), - }, - "bands": { - "pw": { - "code": self.inputs.code, - "pseudos": pseudos, - "parameters": orm.Dict(dict=parameters_bands), - "metadata": { - "options": self.inputs.options.get_dict(), - }, - "parallelization": self.inputs.parallelization, - }, + "ELECTRONS": { + "conv_thr": protocol["conv_thr_per_atom"] * natoms, + "mixing_beta": protocol["mixing_beta"], }, - "kpoints_distance_bands": orm.Float(self.ctx.kpoints_distance_bands), - "init_nbands_factor": orm.Float(self.ctx.init_nbands_factor), - "fermi_shift": orm.Float(self.ctx.fermi_shift), - "run_band_structure": orm.Bool(True), - "kpoints_distance_band_structure": orm.Float( - self.ctx.kpoints_distance_band_structure - ), - "clean_workdir": self.inputs.clean_workdir, } - return inputs + builder.bands.pw["code"] = self.inputs.code + builder.bands.pw["pseudos"] = self.pseudos + builder.bands.pw["parameters"] = orm.Dict(bands_pw_parameters) + builder.bands.pw["parallelization"] = self.inputs.parallelization + builder.bands.pw["metadata"]["options"] = self.inputs.mpi_options.get_dict() + + # Generic + builder.kpoints_distance_bands = orm.Float(protocol["kpoints_distance"]) + builder.kpoints_distance_band_structure = orm.Float( + protocol["kpoints_distance_bs"] + ) + builder.init_nbands_factor = orm.Int(protocol["init_nbands_factor"]) + builder.fermi_shift = orm.Float(protocol["fermi_shift"]) + builder.run_band_structure = orm.Bool(True) + + return builder - def run_bands_evaluation(self): - """run bands evaluation of psp in inputs list""" - inputs = self._get_inputs(self.ctx.pseudos) + def run_bands(self): + """run bands evaluation""" + builder = self.prepare_evaluate_builder() - running = self.submit(BandsWorkChain, **inputs) + running = self.submit(builder) self.report(f"launching BandsWorkChain<{running.pk}>") return ToContext(bands=running) - def finalize(self): + def inspect_bands(self): """inspect bands run results""" - self.out_many(self.exposed_outputs(self.ctx.bands, BandsWorkChain)) + bands_workchain = self.ctx.bands + self.out_many(self.exposed_outputs(bands_workchain, EvaluateBandsWorkChain)) + + outgoing: orm.LinkManager = bands_workchain.base.links.get_outgoing() + band_structure_node = outgoing.get_node_by_label("band_structure") + + # The band node into record is the one with largest bands_factor + # ['band_with_factor_x'] -> {'band_with_factor_x': x} and sort + bands_label = sorted( + { + k: int(k.split("_")[-1]) + for k in outgoing.all_link_labels() + if "bands_with_factor" in k + }, + key=lambda x: x[1], + )[-1] + bands_node = outgoing.get_node_by_label(bands_label) + + band_dict = { + "bands": { + "uuid": bands_node.uuid, + "exit_status": bands_node.exit_status, + }, + "band_structure": { + "uuid": band_structure_node.uuid, + "exit_status": band_structure_node.exit_status, + }, + } + + try: + validated_report = BandStructureReport.construct(band_dict) + self.report("BandStructureReport report is validated") + except Exception as e: + self.report(f"BandStructureReport in sot validated: {e}") + raise e + else: + self.out("report", orm.Dict(dict=validated_report.model_dump()).store()) + + def _finalize(self): + """Final""" + # TODO: diff --git a/src/aiida_sssp_workflow/workflows/measure/report.py b/src/aiida_sssp_workflow/workflows/measure/report.py index f0621aa9..27874779 100644 --- a/src/aiida_sssp_workflow/workflows/measure/report.py +++ b/src/aiida_sssp_workflow/workflows/measure/report.py @@ -22,3 +22,19 @@ def validate_eos_dict(cls, d): if not all(k in VALID_CONFIGURATIONS for k in d.keys()): raise ValueError(f"configuration should be one of {VALID_CONFIGURATIONS}") return d + + +class SingleBandEntry(BaseModel): + uuid: str + exit_status: int + + +class BandStructureReport(BaseModel): + bands: SingleBandEntry + band_structure: SingleBandEntry + + @classmethod + def construct(cls, band_dict: dict[str, dict]): + """Construct the BandStructureReport from dict data, one kv for bands, one for band structure""" + + return cls(**{k: SingleBandEntry(**v) for (k, v) in band_dict.items()}) diff --git a/src/aiida_sssp_workflow/workflows/measure/transferability.py b/src/aiida_sssp_workflow/workflows/measure/transferability.py index 635e1ab9..e70cd669 100644 --- a/src/aiida_sssp_workflow/workflows/measure/transferability.py +++ b/src/aiida_sssp_workflow/workflows/measure/transferability.py @@ -48,6 +48,10 @@ class EOSTransferabilityWorkChain(_BaseMeasureWorkChain): def define(cls, spec): """Define the process specification.""" super().define(spec) + spec.input("oxygen_pseudo", valid_type=UpfData, required=True) + spec.input("oxygen_ecutwfc", valid_type=orm.Float, required=True) + spec.input("oxygen_ecutrho", valid_type=orm.Float, required=True) + spec.input("configurations", valid_type=orm.List, required=False) spec.outline( cls._setup_pseudos, diff --git a/tests/utils/test_structure.py b/tests/utils/test_structure.py index 166d067d..e6444456 100644 --- a/tests/utils/test_structure.py +++ b/tests/utils/test_structure.py @@ -1,8 +1,13 @@ import pytest +import itertools + from aiida.engine import run_get_node from aiida import orm from aiida_sssp_workflow.utils import get_default_configuration, get_standard_structure +from aiida_sssp_workflow.utils.element import ALL_ELEMENTS, UNSUPPORTED_ELEMENTS + +ALL_SUPPORTED_ELEMENTS = [e for e in ALL_ELEMENTS if e not in UNSUPPORTED_ELEMENTS] @pytest.mark.parametrize( @@ -44,3 +49,48 @@ def test_get_standard_structure(element, configuration, data_regression): ) data_regression.check(r.get_cif().get_content()) + + +@pytest.mark.parametrize( + "element, target_property", + [ + (e, p) + for e, p in itertools.product(ALL_SUPPORTED_ELEMENTS, ["band", "convergence"]) + ], +) +def test_get_standard_structure_for_available_properties(element, target_property): + """Go through all elements covered configuration from get_default_configuration""" + # Loop over all elements + configuration = get_default_configuration( + element, + property=target_property, + ) + + structure = get_standard_structure( + element, + configuration=configuration, + ) + + assert isinstance(structure, orm.StructureData) + + +@pytest.mark.parametrize( + "element, target_property", + [ + (e, p) + for e, p in itertools.product(UNSUPPORTED_ELEMENTS, ["band", "convergence"]) + ], +) +def test_get_standard_structure_raise_for_unsupported(element, target_property): + """Go through all elements covered configuration from get_default_configuration""" + # Loop over all elements + configuration = get_default_configuration( + element, + property=target_property, + ) + + with pytest.raises(ValueError, match="Unknown configuration N/A"): + get_standard_structure( + element, + configuration=configuration, + ) diff --git a/tests/workflows/convergence/test_caching.py b/tests/workflows/convergence/test_caching.py index b8e7820b..5f63269d 100644 --- a/tests/workflows/convergence/test_caching.py +++ b/tests/workflows/convergence/test_caching.py @@ -37,6 +37,7 @@ def test_caching_bands( # check the first scf of reference # The pw calculation + # XXX: use link label to find the expected calcjob node source_ref_wf = [ p for p in source_node.called diff --git a/tests/workflows/measure/test_bands.py b/tests/workflows/measure/test_bands.py new file mode 100644 index 00000000..37a1fe66 --- /dev/null +++ b/tests/workflows/measure/test_bands.py @@ -0,0 +1,77 @@ +import pytest + +from aiida import orm +from aiida.plugins import DataFactory, WorkflowFactory +from aiida.engine import ProcessBuilder, run_get_node + +from aiida_sssp_workflow.workflows.measure.report import BandStructureReport + +UpfData = DataFactory("pseudo.upf") + + +@pytest.mark.slow +def test_run_default(pseudo_path, code_generator, serialize_inputs, data_regression): + """Test band structure verification workflow""" + _WorkChain = WorkflowFactory("sssp_workflow.measure.bands") + + builder: ProcessBuilder = _WorkChain.get_builder( + pseudo=pseudo_path("Al"), + protocol="test", + configuration="SC", + wavefunction_cutoff=25, + charge_density_cutoff=100, + code=code_generator("pw"), + clean_workdir=True, + ) + + # run the workchain + result, node = run_get_node(builder) + + assert node.is_finished_ok + assert node.label == "Al.paw.pbe.z_3.ld1.psl.v0.1.upf" + assert "SC" in node.description + assert "(25, 100)" in node.description + assert "test" in node.description + + assert "report" in result + assert "bands" in result + assert "band_structure" in result + + report = BandStructureReport.construct(result["report"].get_dict()) + bands_node = orm.load_node(report.bands.uuid) + band_structure_node = orm.load_node(report.band_structure.uuid) + + pw_parameters_bands = bands_node.called[1].called[0].inputs.parameters + assert isinstance(pw_parameters_bands["SYSTEM"]["ecutwfc"], int) + assert pw_parameters_bands["SYSTEM"]["ecutwfc"] == 25 + assert pw_parameters_bands["SYSTEM"]["ecutrho"] == 100 + + pw_parameters_bs = band_structure_node.called[2].called[0].inputs.parameters + assert isinstance(pw_parameters_bs["SYSTEM"]["ecutwfc"], int) + assert pw_parameters_bs["SYSTEM"]["ecutwfc"] == 25 + assert pw_parameters_bs["SYSTEM"]["ecutrho"] == 100 + + data_regression.check(serialize_inputs(band_structure_node.inputs)) + + +@pytest.mark.parametrize("clean_workdir", [True, False]) +def test_builder_default_args_passing( + clean_workdir, + pseudo_path, + code_generator, + serialize_builder, + data_regression, +): + """Test transferability workflow builder is correctly created with default args""" + _WorkChain = WorkflowFactory("sssp_workflow.measure.bands") + + builder: ProcessBuilder = _WorkChain.get_builder( + pseudo=pseudo_path("Al"), + protocol="test", + wavefunction_cutoff=25, + charge_density_cutoff=100, + code=code_generator("pw"), + clean_workdir=clean_workdir, + ) + + data_regression.check(serialize_builder(builder)) diff --git a/tests/workflows/measure/test_bands/test_builder_default_args_passing_False_.yml b/tests/workflows/measure/test_bands/test_builder_default_args_passing_False_.yml new file mode 100644 index 00000000..066883de --- /dev/null +++ b/tests/workflows/measure/test_bands/test_builder_default_args_passing_False_.yml @@ -0,0 +1,17 @@ +charge_density_cutoff: 100.0 +clean_workdir: false +code: pw-docker@localhost +metadata: + call_link_label: band_structure_verification + description: Run on protocol 'test' | configuration 'default' | base (ecutwfc, + ecutrho) = (25, 100) + label: Al.paw.pbe.z_3.ld1.psl.v0.1.upf +mpi_options: + max_wallclock_seconds: 1800 + resources: + num_machines: 1 + withmpi: false +parallelization: {} +protocol: test +pseudo: Al +wavefunction_cutoff: 25.0 diff --git a/tests/workflows/measure/test_bands/test_builder_default_args_passing_True_.yml b/tests/workflows/measure/test_bands/test_builder_default_args_passing_True_.yml new file mode 100644 index 00000000..a23279cc --- /dev/null +++ b/tests/workflows/measure/test_bands/test_builder_default_args_passing_True_.yml @@ -0,0 +1,17 @@ +charge_density_cutoff: 100.0 +clean_workdir: true +code: pw-docker@localhost +metadata: + call_link_label: band_structure_verification + description: Run on protocol 'test' | configuration 'default' | base (ecutwfc, + ecutrho) = (25, 100) + label: Al.paw.pbe.z_3.ld1.psl.v0.1.upf +mpi_options: + max_wallclock_seconds: 1800 + resources: + num_machines: 1 + withmpi: false +parallelization: {} +protocol: test +pseudo: Al +wavefunction_cutoff: 25.0 diff --git a/tests/workflows/measure/test_bands/test_run_default.yml b/tests/workflows/measure/test_bands/test_run_default.yml new file mode 100644 index 00000000..0463c238 --- /dev/null +++ b/tests/workflows/measure/test_bands/test_run_default.yml @@ -0,0 +1,43 @@ +bands: + max_iterations: 5 + pw: + code: pw-docker@localhost + parallelization: {} + parameters: + CONTROL: + calculation: bands + ELECTRONS: + conv_thr: 1.0e-06 + mixing_beta: 0.4 + SYSTEM: + degauss: 0.0045 + ecutrho: 100 + ecutwfc: 25 + occupations: smearing + smearing: fd + pseudos: + Al: Al +bands_kpoints_distance: 0.1 +clean_workdir: false +nbands_factor: 7.0 +scf: + kpoints_distance: 0.25 + max_iterations: 5 + pw: + code: pw-docker@localhost + parallelization: {} + parameters: + CONTROL: + calculation: scf + ELECTRONS: + conv_thr: 1.0e-06 + mixing_beta: 0.4 + SYSTEM: + degauss: 0.0045 + ecutrho: 100 + ecutwfc: 25 + occupations: smearing + smearing: fd + pseudos: + Al: Al +structure: Al diff --git a/tests/workflows/measure/test_report.py b/tests/workflows/measure/test_report.py index 3c79e82a..75a1af98 100644 --- a/tests/workflows/measure/test_report.py +++ b/tests/workflows/measure/test_report.py @@ -1,10 +1,14 @@ import pytest +from pydantic import ValidationError -from aiida_sssp_workflow.workflows.measure.report import TransferabilityReport +from aiida_sssp_workflow.workflows.measure.report import ( + TransferabilityReport, + BandStructureReport, +) -def test_construct_report_entry(generate_uuid): - """Test convergence report model""" +def test_construct_transferibility_report_entry(generate_uuid): + """Test TransferabilityReport report model""" report1 = { "uuid": generate_uuid("1"), "exit_status": 0, @@ -36,3 +40,48 @@ def test_raise_when_configuration_not_valid(generate_uuid): with pytest.raises(ValueError, match="should be one of"): TransferabilityReport.construct({"XO": report1, "WRONG": report2}) + + +def test_construct_band_report_entry(generate_uuid): + """Test band report model""" + uuid1 = generate_uuid("1") + uuid2 = generate_uuid("2") + + bands = { + "uuid": uuid1, + "exit_status": 0, + } + + band_structure = { + "uuid": uuid2, + "exit_status": 0, + } + + expected_report = BandStructureReport.construct( + {"bands": bands, "band_structure": band_structure} + ) + + # regression test + got_report = BandStructureReport.construct(expected_report.model_dump()) + + assert got_report == expected_report + + # Check the field can be reached with desired ref + assert expected_report.bands.uuid == uuid1 + assert expected_report.band_structure.uuid == uuid2 + + +def test_raise_when_key_of_band_report_not_valid(generate_uuid): + report1 = { + "uuid": generate_uuid("1"), + "exit_status": 0, + } + report2 = { + "uuid": generate_uuid("2"), + "exit_status": 0, + } + + with pytest.raises( + ValidationError, match="1 validation error for BandStructureReport" + ): + BandStructureReport.construct({"bands": report1, "WRONG": report2})