Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/moverseai/moai
Browse files Browse the repository at this point in the history
  • Loading branch information
tzole1155 committed Jun 7, 2024
2 parents 0b4cf05 + de25301 commit fc1ee63
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 5 deletions.
4 changes: 2 additions & 2 deletions moai/data/datasets/generic/npz.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
filename: str = "",
):
self.file = load_npz_file(filename)
log.info(f"Loaded an .npz file producing [{list(self.file.keys())}].")
log.info(f"Loaded an .npz file producing {list(self.file.keys())}.")

def __len__(self) -> int:
return len(self.file[toolz.first(self.file)])
Expand All @@ -63,7 +63,7 @@ def __init__(
):
self.file = load_npz_file(filename)
self.length = length
log.info(f"Loaded an .npz file producing [{list(self.file.keys())}].")
log.info(f"Loaded an .npz file producing {list(self.file.keys())}.")

def __len__(self) -> int:
return self.length
Expand Down
46 changes: 46 additions & 0 deletions moai/engine/progressbar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import rich.progress
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

__all__ = ["MoaiProgressBar"]

# progress_bar = RichProgressBar(
# theme=RichProgressBarTheme(
# description="green_yellow",
# progress_bar="green1",
# progress_bar_finished="green1",
# progress_bar_pulse="#6206E0",
# batch_progress="green_yellow",
# time="grey82",
# processing_speed="grey82",
# metrics="grey82",
# metrics_text_delimiter="\n",
# metrics_format=".3e",
# )


# NOTE: check https://github.com/Textualize/rich/discussions/482
# NOTE: check https://github.com/facebookresearch/EGG/blob/a139946a73d45553360a7f897626d1ae20759f12/egg/core/callbacks.py#L335
# NOTE: check https://github.com/Textualize/rich/discussions/921
class MoaiProgressBar(RichProgressBar):
def __init__(self) -> None:
super().__init__(
theme=RichProgressBarTheme(metrics_text_delimiter="|"),
)

# return [
# TextColumn("[progress.description]{task.description}"),
# CustomBarColumn(
# complete_style=self.theme.progress_bar,
# finished_style=self.theme.progress_bar_finished,
# pulse_style=self.theme.progress_bar_pulse,
# ),
# BatchesProcessedColumn(style=self.theme.batch_progress),
# CustomTimeColumn(style=self.theme.time),
# ProcessingSpeedColumn(style=self.theme.processing_speed),
# ]
def configure_columns(self, trainer: "pl.Trainer") -> list:
original = super().configure_columns(trainer)
moai_column = rich.progress.TextColumn(":moai:")
spinner_column = rich.progress.SpinnerColumn(spinner_name="dots5")
return [moai_column, spinner_column] + original
4 changes: 3 additions & 1 deletion moai/engine/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytorch_lightning as L
from omegaconf.omegaconf import DictConfig

from moai.engine.progressbar import MoaiProgressBar
from moai.engine.run_callback import RunCallback

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,7 +93,8 @@ def __init__(
[hyu.instantiate(logger) for logger in loggers.values()] if loggers else []
)
pytl_callbacks = [
RunCallback()
RunCallback(),
MoaiProgressBar(),
] # TODO: only when moai model is used, should not be used for custom models
pytl_callbacks.extend(
[hyu.instantiate(c) for c in callbacks.values()]
Expand Down
5 changes: 4 additions & 1 deletion moai/parameters/initialization/schemes/zero_flow_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ def __init__(
self.keys = keys

def __call__(self, module: torch.nn.Module) -> None:
zeroed_keys = []
for key in self.keys:
try:
m = get_parameter(module.named_flows, key)
if m is not None:
log.info(f"Zeroing out parameter: [cyan italic]{key}[/].")
with torch.no_grad(): # TODO: remove this and add in root apply call
m.zero_()
m.grad = None
zeroed_keys.append(key)
except:
break
all_zeroed_keys = ",".join(zeroed_keys)
log.info(f"Zeroing out parameters: [cyan italic]\[{all_zeroed_keys}][/].")
3 changes: 2 additions & 1 deletion requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
black==24.4.2
pre-commit==3.7.1
isort==5.13.2
yamllint==1.35.1
yamllint==1.35.1
pytest==8.2.0

0 comments on commit fc1ee63

Please sign in to comment.