Skip to content

Commit

Permalink
cleanup: printing messages
Browse files Browse the repository at this point in the history
  • Loading branch information
nmvrs committed Jun 7, 2024
1 parent 1f2f66a commit de25301
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions moai/data/datasets/generic/npz.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
filename: str = "",
):
self.file = load_npz_file(filename)
log.info(f"Loaded an .npz file producing [{list(self.file.keys())}].")
log.info(f"Loaded an .npz file producing {list(self.file.keys())}.")

def __len__(self) -> int:
return len(self.file[toolz.first(self.file)])
Expand All @@ -63,7 +63,7 @@ def __init__(
):
self.file = load_npz_file(filename)
self.length = length
log.info(f"Loaded an .npz file producing [{list(self.file.keys())}].")
log.info(f"Loaded an .npz file producing {list(self.file.keys())}.")

def __len__(self) -> int:
return self.length
Expand Down
5 changes: 4 additions & 1 deletion moai/parameters/initialization/schemes/zero_flow_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ def __init__(
self.keys = keys

def __call__(self, module: torch.nn.Module) -> None:
zeroed_keys = []
for key in self.keys:
try:
m = get_parameter(module.named_flows, key)
if m is not None:
log.info(f"Zeroing out parameter: [cyan italic]{key}[/].")
with torch.no_grad(): # TODO: remove this and add in root apply call
m.zero_()
m.grad = None
zeroed_keys.append(key)
except:
break
all_zeroed_keys = ",".join(zeroed_keys)
log.info(f"Zeroing out parameters: [cyan italic]\[{all_zeroed_keys}][/].")

0 comments on commit de25301

Please sign in to comment.