Skip to content

Commit

Permalink
refactor getdata
Browse files Browse the repository at this point in the history
  • Loading branch information
caic99 committed Nov 27, 2024
1 parent a4a36c1 commit ee9d8f8
Showing 1 changed file with 13 additions and 41 deletions.
54 changes: 13 additions & 41 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,47 +1053,19 @@ def save_model(self, save_path, lr=0.0, step=0) -> None:
checkpoint_files[0].unlink()

def get_data(self, is_train=True, task_key="Default"):
if not self.multi_task:
if is_train:
try:
batch_data = next(iter(self.training_data))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
with torch.device("cpu"):
self.training_data = BufferedIterator(
iter(self.training_dataloader)
)
batch_data = next(iter(self.training_data))
else:
if self.validation_data is None:
return {}, {}, {}
try:
batch_data = next(iter(self.validation_data))
except StopIteration:
self.validation_data = BufferedIterator(
iter(self.validation_dataloader)
)
batch_data = next(iter(self.validation_data))
else:
if is_train:
try:
batch_data = next(iter(self.training_data[task_key]))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
self.training_data[task_key] = BufferedIterator(
iter(self.training_dataloader[task_key])
)
batch_data = next(iter(self.training_data[task_key]))
else:
if self.validation_data[task_key] is None:
return {}, {}, {}
try:
batch_data = next(iter(self.validation_data[task_key]))
except StopIteration:
self.validation_data[task_key] = BufferedIterator(
iter(self.validation_dataloader[task_key])
)
batch_data = next(iter(self.validation_data[task_key]))
data, dataloader = (self.training_data, self.training_dataloader) \
if is_train else (self.validation_data, self.validation_dataloader)
if data is None and not is_train:
return {}, {}, {}
if self.multi_task:
data=data[task_key]
dataloader=dataloader[task_key]
try:
batch_data = next(iter(data))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
data = BufferedIterator(iter(dataloader))
batch_data = next(iter(data))

for key in batch_data.keys():
if key == "sid" or key == "fid" or key == "box" or "find_" in key:
Expand Down

0 comments on commit ee9d8f8

Please sign in to comment.