From d5596bf9130c974e87a2e4d018be1153c437a50a Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 14:42:29 +0800 Subject: [PATCH] improve warning readable --- deepmd/pt/utils/dataset.py | 7 ++++++- deepmd/pt/utils/stat.py | 11 ++++++++--- deepmd/utils/data.py | 14 ++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 267619b69d..e1deaaed8b 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -62,9 +62,14 @@ def get_frame_index_for_elements(self): element_counts = defaultdict(lambda: {"frames": 0, "indices": []}) set_files = self._data_system.dirs base_offset = 0 + global_type_name = {} for set_file in set_files: element_data = self._data_system._load_type_mix(set_file) unique_elements = np.unique(element_data) + type_name = self._data_system.build_reidx_to_name_map(element_data,set_file) + for new_idx, elem_name in type_name.items(): + if new_idx not in global_type_name: + global_type_name[new_idx] = elem_name for elem in unique_elements: frames_with_elem = np.any(element_data == elem, axis=1) row_indices = np.where(frames_with_elem)[0] @@ -73,7 +78,7 @@ def get_frame_index_for_elements(self): element_counts[elem]["indices"].extend(row_indices_global.tolist()) base_offset += element_data.shape[0] element_counts = dict(element_counts) - return element_counts + return element_counts, global_type_name def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 36a77c198c..7c0a8aa265 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -60,6 +60,7 @@ def make_stat_input( log.info(f"Packing data for statistics from {len(datasets)} systems") total_element_types = set() global_element_counts = {} + global_type_name = {} collect_ele = defaultdict(int) if datasets[0].mixed_type: if enable_element_completion: @@ -160,7 +161,10 @@ def finalize_stats(sys_stat): # get frame index if datasets[0].mixed_type and enable_element_completion: - element_counts = dataset.get_frame_index_for_elements() + element_counts, type_map = dataset.get_frame_index_for_elements() + for new_idx, elem_name in type_name.items(): + if new_idx not in global_type_name: + global_type_name[new_idx] = elem_name for elem, data in element_counts.items(): indices = data["indices"] count = data["frames"] @@ -195,10 +199,11 @@ def finalize_stats(sys_stat): if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = data["count"] + element_name = global_type_name.get(elem, f"") if indices_count < min_frames_per_element_forstat: log.warning( - f"The number of frames in your datasets with element {elem} is {indices_count}, " - f"which is less than the required {min_frames_per_element_forstat}" + f"The number of frames in your datasets with element {element_name} is {indices_count}, " + f"which is less than the set {min_frames_per_element_forstat}" ) collect_elements = collect_ele.keys() missing_elements = total_element_types - collect_elements diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index fa01452bac..f5b9397df8 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -706,6 +706,20 @@ def _load_type_mix(self, set_name: DPPath): real_type = atom_type_mix_ return real_type + def build_reidx_to_name_map(self,typemix, set_name: DPPath): + type_map = self.type_map + type_path = set_name / "real_atom_types.npy" + real_type = type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms]) + type_map_array = np.array(type_map, dtype=object) + reidx_to_name = {} + N, M = real_type.shape + for i in range(N): + for j in range(M): + old_val = int(real_type[i, j]) + new_val = int(typemix[i, j]) + reidx_to_name[new_val] = type_map_array[old_val] + return reidx_to_name + def _make_idx_map(self, atom_type): natoms = atom_type.shape[0] idx = np.arange(natoms, dtype=np.int64)