Skip to content

Commit

Permalink
improve warning readable
Browse files Browse the repository at this point in the history
  • Loading branch information
SumGuo-88 committed Jan 10, 2025
1 parent 8763165 commit d5596bf
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
7 changes: 6 additions & 1 deletion deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ def get_frame_index_for_elements(self):
element_counts = defaultdict(lambda: {"frames": 0, "indices": []})
set_files = self._data_system.dirs
base_offset = 0
global_type_name = {}
for set_file in set_files:
element_data = self._data_system._load_type_mix(set_file)
unique_elements = np.unique(element_data)
type_name = self._data_system.build_reidx_to_name_map(element_data,set_file)
for new_idx, elem_name in type_name.items():
if new_idx not in global_type_name:
global_type_name[new_idx] = elem_name
for elem in unique_elements:
frames_with_elem = np.any(element_data == elem, axis=1)
row_indices = np.where(frames_with_elem)[0]
Expand All @@ -73,7 +78,7 @@ def get_frame_index_for_elements(self):
element_counts[elem]["indices"].extend(row_indices_global.tolist())
base_offset += element_data.shape[0]
element_counts = dict(element_counts)
return element_counts
return element_counts, global_type_name

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
Expand Down
11 changes: 8 additions & 3 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def make_stat_input(
log.info(f"Packing data for statistics from {len(datasets)} systems")
total_element_types = set()
global_element_counts = {}
global_type_name = {}
collect_ele = defaultdict(int)
if datasets[0].mixed_type:
if enable_element_completion:
Expand Down Expand Up @@ -160,7 +161,10 @@ def finalize_stats(sys_stat):

# get frame index
if datasets[0].mixed_type and enable_element_completion:
element_counts = dataset.get_frame_index_for_elements()
element_counts, type_map = dataset.get_frame_index_for_elements()
for new_idx, elem_name in type_name.items():
if new_idx not in global_type_name:
global_type_name[new_idx] = elem_name
for elem, data in element_counts.items():
indices = data["indices"]
count = data["frames"]
Expand Down Expand Up @@ -195,10 +199,11 @@ def finalize_stats(sys_stat):
if datasets[0].mixed_type and enable_element_completion:
for elem, data in global_element_counts.items():
indices_count = data["count"]
element_name = global_type_name.get(elem, f"<unknown-{elem}>")
if indices_count < min_frames_per_element_forstat:
log.warning(
f"The number of frames in your datasets with element {elem} is {indices_count}, "
f"which is less than the required {min_frames_per_element_forstat}"
f"The number of frames in your datasets with element {element_name} is {indices_count}, "
f"which is less than the set {min_frames_per_element_forstat}"
)
collect_elements = collect_ele.keys()
missing_elements = total_element_types - collect_elements
Expand Down
14 changes: 14 additions & 0 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,20 @@ def _load_type_mix(self, set_name: DPPath):
real_type = atom_type_mix_
return real_type

def build_reidx_to_name_map(self,typemix, set_name: DPPath):
type_map = self.type_map
type_path = set_name / "real_atom_types.npy"
real_type = type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms])
type_map_array = np.array(type_map, dtype=object)
reidx_to_name = {}
N, M = real_type.shape
for i in range(N):
for j in range(M):
old_val = int(real_type[i, j])
new_val = int(typemix[i, j])
reidx_to_name[new_val] = type_map_array[old_val]
return reidx_to_name

def _make_idx_map(self, atom_type):
natoms = atom_type.shape[0]
idx = np.arange(natoms, dtype=np.int64)
Expand Down

0 comments on commit d5596bf

Please sign in to comment.