Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 9, 2024
1 parent 4dd775d commit 25d9ec6
Show file tree
Hide file tree
Showing 15 changed files with 246 additions and 185 deletions.
23 changes: 18 additions & 5 deletions src/aiida_sssp_workflow/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aiida.cmdline.params import options, types
from aiida.cmdline.utils import echo
from aiida.engine import ProcessBuilder, run_get_node, submit
from aiida.plugins import DataFactory, WorkflowFactory
from aiida.plugins import WorkflowFactory

from aiida_pseudo.data.pseudo.upf import UpfData
from aiida_sssp_workflow.cli import cmd_root
Expand All @@ -25,6 +25,7 @@

VerificationWorkChain = WorkflowFactory("sssp_workflow.verification")


def guess_properties_list(property: list) -> Tuple[List[str], str]:
# if the property is not specified, use the default list with all properties calculated.
# otherwise, use the specified properties.
Expand All @@ -43,21 +44,27 @@ def guess_properties_list(property: list) -> Tuple[List[str], str]:

return properties_list, extra_desc


def guess_is_convergence(properties_list: list) -> bool:
"""Check if it is a convergence test"""

return any([c for c in properties_list if c.startswith("convergence")])


def guess_is_full_convergence(properties_list: list) -> bool:
"""Check if all properties are run for convergence test"""

return len([c for c in properties_list if c.startswith("convergence")]) == len(DEFAULT_CONVERGENCE_PROPERTIES_LIST)
return len([c for c in properties_list if c.startswith("convergence")]) == len(
DEFAULT_CONVERGENCE_PROPERTIES_LIST
)


def guess_is_measure(properties_list: list) -> bool:
"""Check if it is a measure test"""

return any([c for c in properties_list if c.startswith("measure")])


def guess_is_ph(properties_list: list) -> bool:
"""Check if it has a measure test"""

Expand Down Expand Up @@ -175,15 +182,19 @@ def launch(
is_ph = guess_is_ph(properties_list)

if is_ph and not ph_code:
echo.echo_critical("ph_code must be provided since we run on it for phonon frequencies.")
echo.echo_critical(
"ph_code must be provided since we run on it for phonon frequencies."
)

if is_convergence and len(configuration) > 1:
echo.echo_critical(
"Only one configuration is allowed for convergence workflow."
)

if is_measure and not is_full_convergence:
echo.echo_warning("Full convergence tests are not run, so we use maximum cutoffs for transferability verification.")
echo.echo_warning(
"Full convergence tests are not run, so we use maximum cutoffs for transferability verification."
)

# Load the curent AiiDA profile and log to user
_profile = aiida.load_profile()
Expand Down Expand Up @@ -211,7 +222,9 @@ def launch(
clean_workdir=clean_workdir,
)

builder.metadata.label = f"({protocol} at {pw_code.computer.label} - {conf_label}) {pseudo.stem}"
builder.metadata.label = (
f"({protocol} at {pw_code.computer.label} - {conf_label}) {pseudo.stem}"
)
builder.metadata.description = f"""Calculation is run on protocol: {protocol}; on {pw_code.computer.label}; on configuration {conf_label}; on pseudo {pseudo.stem}."""

builder.pw_code = pw_code
Expand Down
1 change: 0 additions & 1 deletion src/aiida_sssp_workflow/protocol/criteria.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,3 @@ v2024.1001:
bounds: [0.0, 10] # when error eta_c < 20 meV
eps: 1.0e-3
unit: meV/atom

1 change: 1 addition & 0 deletions src/aiida_sssp_workflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_default_mpi_options(
"withmpi": with_mpi,
}


def serialize_data(data):
from aiida.orm import (
AbstractCode,
Expand Down
23 changes: 12 additions & 11 deletions src/aiida_sssp_workflow/utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiida_sssp_workflow.utils.pseudo import DualType, get_dual_type


def get_protocol(category: str, name: str | None=None):
def get_protocol(category: str, name: str | None = None):
"""Load and read protocol from faml file to a verbose dict
if name not set, return whole protocol."""
import_path = resources.path("aiida_sssp_workflow.protocol", f"{category}.yml")
Expand All @@ -17,24 +17,25 @@ def get_protocol(category: str, name: str | None=None):
else:
return protocol_dict

def generate_cutoff_list(protocol_name: str, element: str, pp_type: str) -> List[Tuple[int, int]]:
"""From the control protocol name, get the cutoff list
"""

def generate_cutoff_list(
protocol_name: str, element: str, pp_type: str
) -> List[Tuple[int, int]]:
"""From the control protocol name, get the cutoff list"""
match get_dual_type(pp_type, element):
case DualType.NC:
dual_type = 'nc_dual_scan'
dual_type = "nc_dual_scan"
case DualType.AUGLOW:
dual_type = 'nonnc_dual_scan'
dual_type = "nonnc_dual_scan"
case DualType.AUGHIGH:
dual_type = 'nonnc_high_dual_scan'
dual_type = "nonnc_high_dual_scan"

dual_scan_list = get_protocol('control', protocol_name)[dual_type]
dual_scan_list = get_protocol("control", protocol_name)[dual_type]
if len(dual_scan_list) > 0:
max_dual = int(max(dual_scan_list))
else:
max_dual = 8

ecutwfc_list = get_protocol('control', protocol_name)['wfc_scan']

return [(e, e*max_dual) for e in ecutwfc_list]
ecutwfc_list = get_protocol("control", protocol_name)["wfc_scan"]

return [(e, e * max_dual) for e in ecutwfc_list]
15 changes: 9 additions & 6 deletions src/aiida_sssp_workflow/utils/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,21 @@ class PseudoInfo(BaseModel):
# source_lib: str
# ...


class DualType(Enum):
NC = "nc"
AUGLOW = "charge augmentation low"
AUGHIGH = "charge augmentation high"


def get_dual_type(pp_type: str, element: str) -> DualType:
if element in HIGH_DUAL_ELEMENTS and pp_type != 'nc':
return DualType.AUGHIGH
elif pp_type == 'nc':
return DualType.NC
else:
return DualType.AUGLOW
if element in HIGH_DUAL_ELEMENTS and pp_type != "nc":
return DualType.AUGHIGH
elif pp_type == "nc":
return DualType.NC
else:
return DualType.AUGLOW


def extract_pseudo_info(pseudo_text: str) -> PseudoInfo:
"""Giving a pseudo, extract the pseudo info and return as a `PseudoInfo` object"""
Expand Down
2 changes: 0 additions & 2 deletions src/aiida_sssp_workflow/workflows/convergence/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def configuration(self):

return self.ctx.configuration


@property
def pseudos(self):
"""Syntax sugar for self.ctx.pseudos"""
Expand Down Expand Up @@ -279,7 +278,6 @@ def get_builder(
if ret := is_valid_cutoff_list(cutoff_list):
raise ValueError(ret)


builder.cutoff_list = orm.List(list=cutoff_list)
builder.clean_workdir = orm.Bool(clean_workdir)

Expand Down
24 changes: 11 additions & 13 deletions src/aiida_sssp_workflow/workflows/convergence/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def compute_xy(
report = ConvergenceReport.construct(**report_dict)

reference_node = orm.load_node(report.reference.uuid)
band_structure_r: orm.BandsData = reference_node.outputs.bands.band_structure
band_structure_r: orm.BandsData = reference_node.outputs.bands.band_structure
band_parameters_r: orm.Dict = reference_node.outputs.bands.band_parameters

bandsdata_r = {
Expand All @@ -173,7 +173,7 @@ def compute_xy(
}

# smearing width is from degauss
smearing = reference_node.inputs.bands.pw.parameters.get_dict()['SYSTEM']['degauss']
smearing = reference_node.inputs.bands.pw.parameters.get_dict()["SYSTEM"]["degauss"]
fermi_shift = reference_node.inputs.fermi_shift.value

# always do smearing on high bands and not include the spin since we didn't turn on the spin for all
Expand All @@ -188,15 +188,14 @@ def compute_xy(
if node_point.exit_status != 0:
# TODO: log to a warning file for where the node is not finished_okay
continue

x = node_point.wavefunction_cutoff
xs.append(x)

node = orm.load_node(node_point.uuid)
node = orm.load_node(node_point.uuid)
band_structure_p: orm.BandsData = node.outputs.bands.band_structure
band_parameters_p: orm.Dict = node.outputs.bands.band_parameters


# The raw implementation of `get_bands_distance` is in `aiida_sssp_workflow/calculations/bands_distance.py`
bandsdata_p = {
"number_of_electrons": band_parameters_p["number_of_electrons"],
Expand All @@ -221,14 +220,13 @@ def compute_xy(

# eta_c is the y, others are write into as metadata
ys.append(eta_c)


return {
'x': xs,
'y': ys,
'metadata': {
'shift_c': shift_c,
'max_diff_c': max_diff_c,
'unit': unit,
}
"x": xs,
"y": ys,
"metadata": {
"shift_c": shift_c,
"max_diff_c": max_diff_c,
"unit": unit,
},
}
18 changes: 9 additions & 9 deletions src/aiida_sssp_workflow/workflows/convergence/cohesive_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho) -> ProcessBuilder:

return builder


def compute_xy(
node: orm.Node,
) -> dict[str, Any]:
Expand All @@ -181,29 +182,28 @@ def compute_xy(

reference_node = orm.load_node(report.reference.uuid)
output_parameters_r: orm.Dict = reference_node.outputs.output_parameters
y_ref = output_parameters_r['cohesive_energy_per_atom']
y_ref = output_parameters_r["cohesive_energy_per_atom"]

xs = []
ys = []
for node_point in report.convergence_list:
if node_point.exit_status != 0:
# TODO: log to a warning file for where the node is not finished_okay
continue

x = node_point.wavefunction_cutoff
xs.append(x)

node = orm.load_node(node_point.uuid)
output_parameters_p: orm.Dict = node.outputs.output_parameters

y = abs((output_parameters_p['cohesive_energy_per_atom'] - y_ref) / y_ref) * 100
y = abs((output_parameters_p["cohesive_energy_per_atom"] - y_ref) / y_ref) * 100
ys.append(y)

return {
'x': xs,
'y': ys,
'metadata': {
'unit': '%',
}
"x": xs,
"y": ys,
"metadata": {
"unit": "%",
},
}

20 changes: 10 additions & 10 deletions src/aiida_sssp_workflow/workflows/convergence/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,43 +130,43 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho) -> ProcessBuilder:

return builder


def compute_xy(
report: ConvergenceReport,
) -> dict[str, Any]:
"""From report calculate the xy data, xs are cutoffs and ys are band distance from reference"""
reference_node = orm.load_node(report.reference.uuid)
output_parameters_r: orm.Dict = reference_node.outputs.output_parameters
ref_V0, ref_B0, ref_B1 = output_parameters_r['birch_murnaghan_results']


ref_V0, ref_B0, ref_B1 = output_parameters_r["birch_murnaghan_results"]

xs = []
ys = []
for node_point in report.convergence_list:
if node_point.exit_status != 0:
# TODO: log to a warning file for where the node is not finished_okay
continue

x = node_point.wavefunction_cutoff
xs.append(x)

node = orm.load_node(node_point.uuid)
output_parameters_p: orm.Dict = node.outputs.output_parameters

V0, B0, B1 = output_parameters_p['birch_murnaghan_results']
V0, B0, B1 = output_parameters_p["birch_murnaghan_results"]

y_nu = rel_errors_vec_length(ref_V0, ref_B0, ref_B1, V0, B0, B1)

ys.append(y_nu)

return {
'x': xs,
'y': ys,
'metadata': {
'unit': 'n/a',
}
"x": xs,
"y": ys,
"metadata": {
"unit": "n/a",
},
}


# def compute_xy_epsilon(
# report: ConvergenceReport,
# ) -> dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho):

return builder


def compute_xy(
node: orm.Node,
) -> dict[str, Any]:
Expand All @@ -188,7 +189,7 @@ def compute_xy(
if node_point.exit_status != 0:
# TODO: log to a warning file for where the node is not finished_okay
continue

x = node_point.wavefunction_cutoff
xs.append(x)

Expand Down Expand Up @@ -230,10 +231,9 @@ def compute_xy(
ys.append(y)

return {
'x': xs,
'y': ys,
'metadata': {
'unit': '%',
}
"x": xs,
"y": ys,
"metadata": {
"unit": "%",
},
}

Loading

0 comments on commit 25d9ec6

Please sign in to comment.