diff --git a/aiida_yambo_wannier90/workflows/__init__.py b/aiida_yambo_wannier90/workflows/__init__.py index 48f9080..4d3990b 100644 --- a/aiida_yambo_wannier90/workflows/__init__.py +++ b/aiida_yambo_wannier90/workflows/__init__.py @@ -26,6 +26,7 @@ from aiida_wannier90_workflows.utils.kpoints import ( get_explicit_kpoints, get_mesh_from_kpoints, + get_path_from_kpoints ) from aiida_wannier90_workflows.utils.workflows.builder.setter import set_kpoints from aiida_wannier90_workflows.workflows import ( @@ -626,9 +627,7 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements """Initialize context variables.""" self.ctx.current_structure = self.inputs.structure - - if "bands_kpoints" in self.inputs: - self.ctx.bands_kpoints = self.inputs.bands_kpoints + # Converged mesh from YamboConvergence self.ctx.kpoints_gw_conv = None @@ -676,7 +675,13 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements def should_run_seekpath(self): """Run seekpath if the `inputs.bands_kpoints` is not provided.""" - return "bands_kpoints" not in self.inputs + if "bands_kpoints" in self.inputs: + self.ctx.current_kpoint_path = get_path_from_kpoints( + self.inputs["bands_kpoints"] + ) + return False + else: + return True def run_seekpath(self): """Run the structure through SeeKpath to get the primitive and normalized structure.""" @@ -692,7 +697,11 @@ def run_seekpath(self): self.ctx.current_structure = result["primitive_structure"] - self.ctx.current_bands_kpoints = result["explicit_kpoints"] + # Add `kpoint_path` for Wannier bands + self.ctx.current_kpoint_path = get_path_from_kpoints( + result["explicit_kpoints"] + ) + structure_formula = self.inputs.structure.get_formula() primitive_structure_formula = result["primitive_structure"].get_formula() @@ -1056,11 +1065,12 @@ def prepare_wannier90_pp_inputs(self) -> AttributeDict: inputs.wannier90.structure = self.ctx.current_structure - #params = inputs.wannier90.parameters.get_dict() - #params["bands_plot"] = False - #inputs.wannier90.parameters = orm.Dict(params) + params = inputs.wannier90.parameters.get_dict() + params["bands_plot"] = False + inputs.wannier90.parameters = orm.Dict(params) - inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints + if self.ctx.current_kpoint_path: + inputs.wannier90.kpoint_path = self.ctx.current_kpoint_path # Use commensurate kmesh if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90: @@ -1172,7 +1182,8 @@ def prepare_wannier90_inputs(self) -> AttributeDict: ) inputs.structure = self.ctx.current_structure - inputs.bands_kpoints = self.ctx.current_bands_kpoints + if self.ctx.current_kpoint_path: + inputs.wannier90.wannier90.kpoint_path = self.ctx.current_kpoint_path # Use commensurate kmesh if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90: @@ -1258,7 +1269,8 @@ def prepare_wannier90_qp_inputs(self) -> AttributeDict: ) inputs.wannier90.structure = self.ctx.current_structure - inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints + if self.ctx.current_kpoint_path: + inputs.kpoint_path = self.ctx.current_kpoint_path if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90: set_kpoints( diff --git a/examples/example_01.py b/examples/example_01.py index 0920d17..a252fa8 100755 --- a/examples/example_01.py +++ b/examples/example_01.py @@ -12,7 +12,7 @@ from aiida_wannier90_workflows.cli.params import RUN from aiida_wannier90_workflows.utils.workflows.builder.serializer import print_builder from aiida_wannier90_workflows.utils.kpoints import get_explicit_kpoints_from_mesh -from aiida_wannier90_workflows.utils.workflows.builder.setter import set_parallelization, set_num_bands, set_kpoints +from aiida_wannier90_workflows.utils.workflows.builder.setter import set_parallelization, set_num_bands from aiida_wannier90_workflows.utils.workflows.builder.submit import submit_and_add_group from aiida_wannier90_workflows.common.types import WannierProjectionType from aiida_wannier90_workflows.workflows import Wannier90BandsWorkChain