From f3e6e68410e2b5fa0dde8f1403994884e00f836e Mon Sep 17 00:00:00 2001
From: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>
Date: Fri, 8 Nov 2024 21:42:54 +0100
Subject: [PATCH] Simplify (#12)

---
 src/aiida_dftk/calculations.py   | 109 ++++++++++++++-----------
 src/aiida_dftk/parsers.py        |  18 +++--
 src/aiida_dftk/workflows/base.py |  19 ++----
 3 files changed, 63 insertions(+), 83 deletions(-)

diff --git a/src/aiida_dftk/calculations.py b/src/aiida_dftk/calculations.py
index f5b0b66..c335d91 100644
--- a/src/aiida_dftk/calculations.py
+++ b/src/aiida_dftk/calculations.py
@@ -4,6 +4,7 @@
 import os
 import json
 import typing as ty
+from pathlib import Path
 
 from aiida import orm
 from aiida.common import datastructures, exceptions
@@ -15,11 +16,9 @@
 class DftkCalculation(CalcJob):
     """`CalcJob` implementation for DFTK."""
 
-    _DEFAULT_PREFIX = 'DFTK'
-    _DEFAULT_INPUT_EXTENSION = 'json'
-    _DEFAULT_STDOUT_EXTENSION = 'txt'
-    _DEFAULT_SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
-    _SUPPORTED_POSTSCF = ['compute_forces_cart', 'compute_stresses_cart','compute_bands']
+    SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
+    # TODO: don't limit postscf
+    _SUPPORTED_POSTSCF = ['compute_forces_cart', 'compute_stresses_cart', 'compute_bands']
     _PSEUDO_SUBFOLDER = './pseudo/'
     _MIN_OUTPUT_BUFFER_TIME = 60
 
@@ -37,24 +36,25 @@
     def define(cls, spec):
         """Define the process specification."""
         super().define(spec)
         # Inputs
-        spec.input('metadata.options.prefix', valid_type=str, default=cls._DEFAULT_PREFIX)
-        spec.input('metadata.options.stdout_extension', valid_type=str, default=cls._DEFAULT_STDOUT_EXTENSION)
-        spec.input('metadata.options.withmpi', valid_type=bool, default=True)
-        spec.input('metadata.options.max_wallclock_seconds', valid_type=int, default=1800)
-
         spec.input('structure', valid_type=orm.StructureData, help='structure')
         spec.input_namespace('pseudos', valid_type=UpfData, help='The pseudopotentials.', dynamic=True)
         spec.input('kpoints', valid_type=orm.KpointsData, help='kpoint mesh or kpoint path')
         spec.input('parameters', valid_type=orm.Dict, help='input parameters')
-        spec.input('settings', valid_type=orm.Dict, required=False, help='Various special settings.')
         spec.input('parent_folder', valid_type=orm.RemoteData, required=False, help='A remote folder used for restarts.')
 
         options = spec.inputs['metadata']['options']
+        options['parser_name'].default = 'dftk'
+        options['input_filename'].default = 'run_dftk.json'
+        options['max_wallclock_seconds'].default = 1800
+
+        # TODO: Why is this here?
         options['resources'].default = {'num_machines': 1, 'num_mpiprocs_per_machine': 1}
-        options['input_filename'].default = f'{cls._DEFAULT_PREFIX}.{cls._DEFAULT_INPUT_EXTENSION}'
+        options['withmpi'].default = True
 
         # Exit codes
+        # TODO: Log file should be removed in favor of using stdout. Needs a change in AiidaDFTK.jl.
+        # TODO: Code 100 is already used in the super class!
         spec.exit_code(100, 'ERROR_MISSING_LOG_FILE', message='The output file containing DFTK logs is missing.')
         spec.exit_code(101, 'ERROR_MISSING_SCFRES_FILE', message='The output file containing SCF results is missing.')
         spec.exit_code(102, 'ERROR_MISSING_FORCES_FILE', message='The output file containing forces is missing.')
@@ -67,7 +67,9 @@ def define(cls, spec):
 
         # Outputs
         spec.output('output_parameters', valid_type=orm.Dict, help='output parameters')
+        # TODO: doesn't seem to be used?
         spec.output('output_structure', valid_type=orm.Dict, required=False, help='output structure')
+        # TODO: doesn't seem to be used?
         spec.output(
             'output_kpoints', valid_type=orm.KpointsData, required=False, help='kpoints array, if generated by DFTK'
         )
@@ -82,7 +84,7 @@ def define(cls, spec):
 
         spec.default_output_node = 'output_parameters'
 
-    def validate_options(self):
+    def _validate_options(self):
         """Validate the options input.
 
         Check that the withmpi option is set to True if the number of mpiprocs is greater than 1.
         """
         options = self.inputs.metadata.options
         if options.withmpi is False and options.resources.get('num_mpiprocs_per_machine', 1) > 1:
+            # TODO: does aiida not already check this?
             raise exceptions.InputValidationError('MPI is required when num_mpiprocs_per_machine > 1.')
         if options.max_wallclock_seconds < self._MIN_OUTPUT_BUFFER_TIME:
             raise exceptions.InputValidationError(
                 f'max_wallclock_seconds must be greater than {self._MIN_OUTPUT_BUFFER_TIME}.'
             )
 
-    def validate_inputs(self):
+    def _validate_inputs(self):
         """Validate input parameters.
 
         Check that the post-SCF function(s) are supported.
         """
@@ -107,21 +110,20 @@ def validate_inputs(self):
             if postscf['$function'] not in self._SUPPORTED_POSTSCF:
                 raise exceptions.InputValidationError(f"Unsupported postscf function: {postscf['$function']}")
 
-    def validate_pseudos(self):
-        """Valdiate the pseudopotentials.
+    def _validate_pseudos(self):
+        """Validate the pseudopotentials.
 
         Check that there is a one-to-one map of kinds in the structure to pseudopotentials.
         """
-        kinds = [kind.name for kind in self.inputs.structure.kinds]
-        if set(kinds) != set(self.inputs.pseudos.keys()):
-            pseudos_str = ', '.join(list(self.inputs.pseudos.keys()))
-            kinds_str = ', '.join(list(kinds))
+        kinds = set(kind.name for kind in self.inputs.structure.kinds)
+        pseudos = set(self.inputs.pseudos.keys())
+        if kinds != pseudos:
             raise exceptions.InputValidationError(
                 'Mismatch between the defined pseudos and the list of kinds of the structure.\n'
-                f'Pseudos: {pseudos_str};\nKinds:{kinds_str}'
+                f'Pseudos: {pseudos};\nKinds: {kinds}'
             )
 
-    def validate_kpoints(self):
+    def _validate_kpoints(self):
         """Validate the k-points input.
 
         Check that the input k-points provide a k-points mesh.
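
For reference, the mesh-style `kpoints` input that `_validate_kpoints` insists on is built with the standard aiida-core API. A minimal sketch (the 4x4x4 mesh and the half-grid shift are arbitrary example values):

    from aiida import orm

    # Build a k-point mesh; _generate_inputdata later unpacks it via
    # kpoints.get_kpoints_mesh() into the 'kgrid' and 'kshift' basis kwargs.
    kpoints = orm.KpointsData()
    kpoints.set_kpoints_mesh([4, 4, 4], offset=[0.5, 0.5, 0.5])

A `KpointsData` node carrying only an explicit k-point list (a band path, say) would fail this validation, since only a mesh can be mapped onto DFTK's `kgrid`/`kshift` basis arguments.
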
@@ -159,13 +161,11 @@ def _generate_inputdata(
             local_copy_pseudo_list.append((pseudo.uuid, pseudo.filename, f'{self._PSEUDO_SUBFOLDER}{pseudo.filename}'))
 
         data['basis_kwargs']['kgrid'], data['basis_kwargs']['kshift'] = kpoints.get_kpoints_mesh()
 
-        # set the maxtime for the SCF cycle
-        # if max_wallclock_seconds is smaller than 600 seconds, set the maxtime as max_wallclock_seconds - MIN_OUTPUT_BUFFER_TIME
-        # else set the maxtime as int(0.95 * max_wallclock_seconds)
-        if self.inputs.metadata.options.max_wallclock_seconds < self._MIN_OUTPUT_BUFFER_TIME * 10:
-            maxtime = self.inputs.metadata.options.max_wallclock_seconds - self._MIN_OUTPUT_BUFFER_TIME
-        else:
-            maxtime = int(0.9 * self.inputs.metadata.options.max_wallclock_seconds)
+        # Set the maxtime for the SCF cycle, with a margin of _MIN_OUTPUT_BUFFER_TIME or 10%, whichever margin is larger.
+        maxtime = min(
+            self.inputs.metadata.options.max_wallclock_seconds - self._MIN_OUTPUT_BUFFER_TIME,
+            0.9 * self.inputs.metadata.options.max_wallclock_seconds,
+        )
         data['scf']['maxtime'] = maxtime
 
         DftkCalculation._merge_dicts(data, parameters.get_dict())
@@ -178,6 +178,11 @@ def _generate_cmdline_params(self) -> ty.List[str]:
         cmd_params.extend(['-e', 'using AiidaDFTK; AiidaDFTK.run()', self.metadata.options.input_filename])
         return cmd_params
 
+    @staticmethod
+    def get_log_file(input_filename: str) -> str:
+        """Get the name of the log file based on the name of the input file."""
+        return Path(input_filename).stem + '.log'
+
     def _generate_retrieve_list(self, parameters: orm.Dict) -> list:
         """Generate the list of files to retrieve based on the type of calculation requested in the input parameters.
 
@@ -190,10 +195,9 @@
             f"{item['$function']}.json" if item['$function'] == 'compute_bands' else f"{item['$function']}.hdf5"
             for item in parameters['postscf']
         ]
-        retrieve_list.append(f'{self._DEFAULT_PREFIX}.log')
+        retrieve_list.append(DftkCalculation.get_log_file(self.inputs.metadata.options.input_filename))
         retrieve_list.append('timings.json')
-        retrieve_list.append(f'{self._DEFAULT_PREFIX}.{self._DEFAULT_STDOUT_EXTENSION}')
-        retrieve_list.append(f'{self._DEFAULT_SCFRES_SUMMARY_NAME}')
+        retrieve_list.append(self.SCFRES_SUMMARY_NAME)
         return retrieve_list
 
@@ -203,22 +207,18 @@ def prepare_for_submission(self, folder):
         """Create the input file(s) from the input nodes and store them in the sandbox folder for transfer to the computer where the calculation runs.
         :return: `aiida.common.datastructures.CalcInfo` instance
         """
-        # Process the `settings`` so that capitalization isn't an issue
-        settings = self.inputs.settings.get_dict()
-
-        self.validate_options()
-        self.validate_inputs()
-        self.validate_pseudos()
-        self.validate_kpoints()
+        self._validate_options()
+        self._validate_inputs()
+        self._validate_pseudos()
+        self._validate_kpoints()
 
         # Create lists which specify files to copy and symlink
         remote_copy_list = []
         remote_symlink_list = []
 
         # Generate the input file content
-        arguments = [self.inputs.parameters, self.inputs.structure, self.inputs.pseudos, self.inputs.kpoints]
-        input_filecontent, local_copy_list = self._generate_inputdata(*arguments)
+        input_filecontent, local_copy_list = self._generate_inputdata(self.inputs.parameters, self.inputs.structure, self.inputs.pseudos, self.inputs.kpoints)
 
         # write input file
         input_filename = folder.get_abs_path(self.metadata.options.input_filename)
@@ -227,25 +227,17 @@
 
         # List the files (scfres.jld2) to copy or symlink in the case of a restart
         if 'parent_folder' in self.inputs:
-            # Symlink by default if on the same computer, otherwise copy by default
+            # Symlink if on the same computer, otherwise copy
             same_computer = self.inputs.code.computer.uuid == self.inputs.parent_folder.computer.uuid
-            if settings.pop('PARENT_FOLDER_SYMLINK', same_computer):
-                remote_symlink_list.append(
-                    (
-                        self.inputs.parent_folder.computer.uuid,
-                        os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
-                        self.inputs.parameters['scf']['checkpointfile']
-                    )
-                )
-
+            checkpointfile_info = (
+                self.inputs.parent_folder.computer.uuid,
+                os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
+                self.inputs.parameters['scf']['checkpointfile']
+            )
+            if same_computer:
+                remote_symlink_list.append(checkpointfile_info)
             else:
-                remote_copy_list.append(
-                    (
-                        self.inputs.parent_folder.computer.uuid,
-                        os.path.join(self.inputs.parent_folder.get_remote_path(), self.inputs.parameters['scf']['checkpointfile']),
-                        self.inputs.parameters['scf']['checkpointfile']
-                    )
-                )
+                remote_copy_list.append(checkpointfile_info)
 
         # prepare command line parameters
         cmdline_params = self._generate_cmdline_params()
@@ -257,7 +249,6 @@
         codeinfo = datastructures.CodeInfo()
         codeinfo.code_uuid = self.inputs.code.uuid
         codeinfo.cmdline_params = cmdline_params
-        codeinfo.stdout_name = f'{self._DEFAULT_PREFIX}.{self._DEFAULT_STDOUT_EXTENSION}'
 
         # Set up the `CalcInfo` so AiiDA knows what to do with everything
         calcinfo = datastructures.CalcInfo()

diff --git a/src/aiida_dftk/parsers.py b/src/aiida_dftk/parsers.py
index 356b3d2..470533d 100644
--- a/src/aiida_dftk/parsers.py
+++ b/src/aiida_dftk/parsers.py
@@ -23,9 +23,7 @@
 class DftkParser(Parser):
     """`Parser` implementation for DFTK."""
 
-    # TODO: I don't like this!
-    _DEFAULT_SCFRES_SUMMARY_NAME = 'self_consistent_field.json'
-    _DEFAULT_LOG_FILE_NAME = 'DFTK.log'
+    # TODO: DEFAULT_ prefix should be removed. I don't think that these names can be changed.
     _DEFAULT_ENERGY_UNIT = 'hartree'
     _DEFAULT_FORCE_FUNCNAME = 'compute_forces_cart'
     _DEFAULT_FORCE_UNIT = 'hartree/bohr'
@@ -36,17 +34,17 @@ class DftkParser(Parser):
 
     def parse(self, **kwargs):
         """Parse DFTK output files."""
-        # TODO: log recovery doesn't even seem to work? :(
-        if self._DEFAULT_LOG_FILE_NAME not in self.retrieved.base.repository.list_object_names():
+        log_file_name = DftkCalculation.get_log_file(self.node.get_options()["input_filename"])
+        if log_file_name not in self.retrieved.base.repository.list_object_names():
             return self.exit_codes.ERROR_MISSING_LOG_FILE
 
+        # TODO: how to make this log available? This unfortunately doesn't output to the process report.
         # TODO: maybe DFTK could log in a way that allows us to map its log levels to aiida's
-        self.logger.info(self.retrieved.base.repository.get_object_content(self._DEFAULT_LOG_FILE_NAME))
+        self.logger.info(self.retrieved.base.repository.get_object_content(log_file_name))
 
-        # TODO: double check this if
        # if ran_out_of_walltime (terminated abnormally)
        if self.node.exit_status == DftkCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME.status:
-            # if _DEFAULT_SCFRES_SUMMARY_NAME is not in the list self.retrieved.list_object_names(), SCF terminated illy
-            if self._DEFAULT_SCFRES_SUMMARY_NAME not in self.retrieved.list_object_names():
+            # if the SCF summary file is not in the list of retrieved files, the SCF terminated abnormally
+            if DftkCalculation.SCFRES_SUMMARY_NAME not in self.retrieved.list_object_names():
                 return self.exit_codes.ERROR_SCF_OUT_OF_WALLTIME
             # POSTSCF terminated abnormally
             else:
@@ -55,7 +53,7 @@
         # Check retrieve list to know which files the calculation is expected to have produced.
         try:
             self._parse_optional_result(
-                self._DEFAULT_SCFRES_SUMMARY_NAME,
+                DftkCalculation.SCFRES_SUMMARY_NAME,
                 self.exit_codes.ERROR_MISSING_SCFRES_FILE,
                 self._parse_output_parameters,
             )

diff --git a/src/aiida_dftk/workflows/base.py b/src/aiida_dftk/workflows/base.py
index 53a9611..8c591c3 100644
--- a/src/aiida_dftk/workflows/base.py
+++ b/src/aiida_dftk/workflows/base.py
@@ -38,7 +38,6 @@ def define(cls, spec):
 
         spec.outline(
             cls.setup,
-            cls.validate_parameters,
             cls.validate_kpoints,
             cls.validate_pseudos,
             cls.validate_resources,
@@ -71,16 +70,7 @@ def setup(self):
         self.ctx.restart_calc = None
         self.ctx.inputs = AttributeDict(self.exposed_inputs(DftkCalculation, 'dftk'))
 
-    def validate_parameters(self):
-        """Validate inputs that might depend on each other and cannot be validated by the spec.
-
-        Also define dictionary `inputs` in the context, that will contain the inputs for the calculation that will be
-        launched in the `run_calculation` step.
-        """
-        #super().setup()
-        self.ctx.inputs.parameters = self.ctx.inputs.parameters.get_dict()
-        self.ctx.inputs.settings = self.ctx.inputs.settings.get_dict() if 'settings' in self.ctx.inputs else {}
-
+    # TODO: We probably want to handle the kpoint distance on the Julia side instead.
     def validate_kpoints(self):
         """Validate the inputs related to k-points.
 
@@ -121,6 +111,7 @@ def validate_pseudos(self):
             return self.exit_codes.ERROR_INVALID_INPUT_PSEUDO_POTENTIALS  # pylint: disable=no-member
 
+    # TODO: This is weird. Shouldn't aiida already handle this internally?
     def validate_resources(self):
         """Validate the inputs related to the resources.
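
The `get_log_file` helper introduced in this patch keeps the calculation and parser sides consistent: the parser now derives the log file name from the same `input_filename` option that was used at submission, instead of the hard-coded `DFTK.log`. A minimal sketch of the mapping, using the defaults set above:

    from pathlib import Path

    def get_log_file(input_filename: str) -> str:
        """Get the name of the log file based on the name of the input file."""
        # 'run_dftk.json' -> stem 'run_dftk' -> 'run_dftk.log'
        return Path(input_filename).stem + '.log'

    assert get_log_file('run_dftk.json') == 'run_dftk.log'
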
@@ -154,8 +145,7 @@ def report_error_handled(self, calculation, action):
         :param calculation: the failed calculation node
         :param action: a string message with the action taken
         """
-        arguments = [calculation.process_label, calculation.pk, calculation.exit_status, calculation.exit_message]
-        self.report('{}<{}> failed with exit status {}: {}'.format(*arguments))
+        self.report(f'{calculation.process_label}<{calculation.pk}> failed with exit status {calculation.exit_status}: {calculation.exit_message}')
         self.report(f'Action taken: {action}')
 
     @process_handler(priority=500, exit_codes=[DftkCalculation.exit_codes.ERROR_SCF_CONVERGENCE_NOT_REACHED])
     def handle_scf_convergence_not_reached(self, _):
         return None
 
     # Just as a blueprint, delete after ^ is implemented
+    # TODO: What exactly is this doing?
     @process_handler(priority=580, exit_codes=[
         DftkCalculation.exit_codes.ERROR_SCF_CONVERGENCE_NOT_REACHED,
         DftkCalculation.exit_codes.ERROR_POSTSCF_OUT_OF_WALLTIME
     ])
     def handle_recoverable_SCF_unconverged_and_POSTSCF_out_of_walltime_(self, calculation):
-        """Handle `RROR_SCF_CONVERGENCE_NOT_REACHED` and `ERROR_POSTSCF_OUT_OF_WALLTIME` exit code: calculations shut down neatly and we can simply restart."""
+        """Handle the `ERROR_SCF_CONVERGENCE_NOT_REACHED` and `ERROR_POSTSCF_OUT_OF_WALLTIME` exit codes: calculations shut down neatly and we can simply restart."""
         try:
             self.ctx.inputs.structure = calculation.outputs.output_structure
         except exceptions.NotExistent:
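
For context on the simplified restart logic in `prepare_for_submission`: with the `settings` input gone, a caller triggers a restart purely by passing the previous run's remote folder, and the symlink-versus-copy choice now depends only on whether both runs are on the same computer. A minimal sketch from the caller's side (the `CalculationFactory` entry point name and the checkpoint file name are assumptions for illustration):

    from aiida import orm
    from aiida.plugins import CalculationFactory

    DftkCalculation = CalculationFactory('dftk')  # entry point name assumed

    previous = orm.load_node(1234)  # a finished DftkCalculation; pk is hypothetical

    builder = DftkCalculation.get_builder()
    builder.parameters = orm.Dict(dict={'scf': {'checkpointfile': 'scfres.jld2'}})  # example name
    builder.parent_folder = previous.outputs.remote_folder
    # prepare_for_submission symlinks the checkpoint when both runs share a computer,
    # and falls back to a remote copy otherwise.
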