Skip to content

Commit

Permalink
Split merge_trajectory_data to merge_trajectory_data_unique and merge…
Browse files Browse the repository at this point in the history
…_trajectory_data_non_unique
  • Loading branch information
yakutovicha committed Mar 7, 2024
1 parent db359a0 commit f0acb73
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
8 changes: 6 additions & 2 deletions aiida_cp2k/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
###############################################################################
"""AiiDA-CP2K utils"""

from .datatype_helpers import merge_trajectory_data
from .datatype_helpers import (
merge_trajectory_data_non_unique,
merge_trajectory_data_unique,
)
from .input_generator import (
Cp2kInput,
add_ext_restart_section,
Expand Down Expand Up @@ -43,5 +46,6 @@
"merge_Dict",
"ot_has_small_bandgap",
"resize_unit_cell",
"merge_trajectory_data",
"merge_trajectory_data_unique",
"merge_trajectory_data_non_unique",
]
43 changes: 37 additions & 6 deletions aiida_cp2k/utils/datatype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,12 @@ def write_pseudos(inp, pseudos, folder):
_write_gdt(inp, pseudos, folder, "POTENTIAL_FILE_NAME", "POTENTIAL")


@engine.calcfunction
def merge_trajectory_data(*trajectories, unique_stepids=False):
def _merge_trajectories_into_dictionary(*trajectories, unique_stepids=False):
if len(trajectories) < 0:
return None
final_trajectory = orm.TrajectoryData()
final_trajectory_dict = {}

array_names = trajectories[0].get_arraynames()
symbols = trajectories[0].symbols

for array_name in array_names:
if any(array_name not in traj.get_arraynames() for traj in trajectories):
Expand All @@ -437,10 +434,44 @@ def merge_trajectory_data(*trajectories, unique_stepids=False):
)
final_trajectory_dict[array_name] = merged_array

# If unique_stepids is True, we only keep the unique stepids.
# The other arrays are then also reduced to the unique stepids.
if unique_stepids:
stepids = np.concatenate([traj.get_stepids() for traj in trajectories], axis=0)
final_trajectory_dict["stepids"], unique_indices = np.unique(
stepids, return_index=True
)

for array_name in array_names:
final_trajectory_dict[array_name] = final_trajectory_dict[array_name][
unique_indices
]

return final_trajectory_dict


def _dictionary_to_trajectory(trajectory_dict, symbols):
final_trajectory = orm.TrajectoryData()
final_trajectory.set_trajectory(
symbols=symbols, positions=final_trajectory_dict.pop("positions")
symbols=symbols, positions=trajectory_dict.pop("positions")
)
for array_name, array in final_trajectory_dict.items():
for array_name, array in trajectory_dict.items():
final_trajectory.set_array(array_name, array)

return final_trajectory


@engine.calcfunction
def merge_trajectory_data_unique(*trajectories):
trajectory_dict = _merge_trajectories_into_dictionary(
*trajectories, unique_stepids=True
)
return _dictionary_to_trajectory(trajectory_dict, trajectories[0].symbols)


@engine.calcfunction
def merge_trajectory_data_non_unique(*trajectories):
trajectory_dict = _merge_trajectories_into_dictionary(
*trajectories, unique_stepids=False
)
return _dictionary_to_trajectory(trajectory_dict, trajectories[0].symbols)
2 changes: 1 addition & 1 deletion aiida_cp2k/workchains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def results(self):
trajectories = self._collect_all_trajetories()
if trajectories:
self.report("Work chain completed successfully, collecting all trajectories")
self.out("output_trajectory", utils.merge_trajectory_data(*trajectories))
self.out("output_trajectory", utils.merge_trajectory_data_unique(*trajectories))

def overwrite_input_structure(self):
if "output_structure" in self.ctx.children[self.ctx.iteration-1].outputs:
Expand Down
9 changes: 6 additions & 3 deletions test/test_datatype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import pytest
from aiida import orm

from aiida_cp2k.utils import merge_trajectory_data
from aiida_cp2k.utils import (
merge_trajectory_data_non_unique,
merge_trajectory_data_unique,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -46,12 +49,12 @@ def get_trajectory(step1=1, step2=20):
unique_elements.extend(range(step_range[0], step_range[1] + 1))
total_lenght_unique = len(set(unique_elements))

merged_trajectory = merge_trajectory_data(*trajectories)
merged_trajectory = merge_trajectory_data_non_unique(*trajectories)
assert (
len(merged_trajectory.get_stepids()) == total_length
), "The merged trajectory has the wrong length."

merged_trajectory_unique = merge_trajectory_data(*trajectories, unique_stepids=True)
merged_trajectory_unique = merge_trajectory_data_unique(*trajectories)
assert (
len(merged_trajectory_unique.get_stepids()) == total_lenght_unique
), "The merged trajectory with unique stepids has the wrong length."

0 comments on commit f0acb73

Please sign in to comment.