diff --git a/TODO.md b/TODO.md
new file mode 100644
index 0000000..39fbbc8
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1,3 @@
+- test the training: model should be able to reach a certain loss on known data, with fixed seed
+- test the ckpt loading in CI, it's cheap. do it on CPU if possible
+- enforce ruff/pyright compliance on main branch?
diff --git a/train/config.py b/train/config.py
index 0ce56ba..b086320 100644
--- a/train/config.py
+++ b/train/config.py
@@ -1,3 +1,4 @@
+import torch
 from dataclasses import dataclass, field
 import multiprocessing as mp
 from typing import Optional
@@ -7,48 +8,57 @@
 @dataclass
 class DataConfig:
     root: str = str(Path.home() / "datasets/imagenette")
+    tiny_dataset_size: Optional[int] = None  # If set, use N images each for train/val
     train_val_split: float = 0.9
     num_workers: int = field(default_factory=lambda: min(mp.cpu_count(), 4))
     pin_memory: bool = False


 @dataclass
-class BaseTrainConfig:
-    # Model
+class TrainingConfig:
+    # Core behavior flags
+    peak_performance_test: bool = False  # Benchmark mode: train on one fixed random batch to measure throughput
+    compile: bool = True
+    amp: bool = True
+    profiler: bool = False
+
+    # Model architecture
     patch_size: int = 16
     embed_dim: int = 192  # Tiny variant
     decoder_embed_dim: int = 96
-    compile: bool = True

-    # Training
+    # Training parameters
     batch_size: int = 256
     total_samples: int = 1_000_000
-    amp: bool = True
     grad_clip: float = 1.0
     mask_ratio: float = 0.75
-    lr: float = 1.5e-4
-    seed: Optional[int] = None  # override for reproducibility
-    profiler: bool = False
-    lr_layer_decay: float = (
-        1  # layerwise LR decay. qualitatively different behavior if <1
-    )
+
+    # Optimization parameters
+    lr: float = 1.5e-4
+    lr_layer_decay: float = 1.0  # layerwise LR decay; qualitatively different behavior if < 1
     warmup_ratio: float = 0.1
     weight_decay: float = 0.05
     beta1: float = 0.9
     beta2: float = 0.999

-    # Data
+    # Random seed
+    seed: Optional[int] = None
+
+    # Data configuration
     data: DataConfig = field(default_factory=DataConfig)

-    # Logging
+    device: str = field(
+        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
+    )
+
+    # Logging and checkpointing
     log_dir: str = "runs"
     ckpt_dir: str = "checkpoints"
     samples_per_viz: int = 1000
     samples_per_val: int = 10000
     samples_per_ckpt: int = 50000

-    # From original MAE-Lite repo.
+    # Pretrained model
     pretrained_path: Optional[str] = "ckpt/mae_tiny_400e.pth.tar"

     def __post_init__(self):
diff --git a/train/config/base.yaml b/train/config/base.yaml
new file mode 100644
index 0000000..1aa50be
--- /dev/null
+++ b/train/config/base.yaml
@@ -0,0 +1,38 @@
+defaults:
+  - base_schema  # structured schema (TrainingConfig) registered in train/main.py
+  - _self_  # values in this file override the schema defaults
+
+data:
+  root: ${oc.env:HOME}/datasets/imagenette
+  train_val_split: 0.9
+  num_workers: 4
+  pin_memory: false
+
+compile: true
+amp: true
+profiler: false
+
+patch_size: 16
+embed_dim: 192
+decoder_embed_dim: 96
+
+batch_size: 256
+total_samples: 1_000_000
+grad_clip: 1.0
+mask_ratio: 0.75
+
+lr: 1.5e-4
+lr_layer_decay: 1.0
+warmup_ratio: 0.1
+weight_decay: 0.05
+beta1: 0.9
+beta2: 0.999
+
+log_dir: runs
+ckpt_dir: checkpoints
+samples_per_viz: 1000
+samples_per_val: 10000
+samples_per_ckpt: 50000
+
+peak_performance_test: false
+seed: 42
diff --git a/train/config/data_bound.yaml b/train/config/data_bound.yaml
index 9d4f56d..f7b80cd 100644
--- a/train/config/data_bound.yaml
+++ b/train/config/data_bound.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - base_train_config  # Reference the schema
+  - base
   - _self_

 # Training parameters
diff --git a/train/config/few_images.yaml b/train/config/few_images.yaml
new file mode 100644
index 0000000..5707f19
--- /dev/null
+++ b/train/config/few_images.yaml
@@ -0,0 +1,23 @@
+defaults:
+  - base
+  - _self_
+
+# Training parameters
+batch_size: 4
+data:
+  tiny_dataset_size: 4
+
+total_samples: 2000
+grad_clip: 0.5
+mask_ratio: 0.75
+lr: 8e-5
+lr_layer_decay: 1
+warmup_ratio: 0.5
+weight_decay: 1e-4
+beta1: 0.95
+beta2: 0.9995
+
+# Logging and checkpointing
+samples_per_viz: 100
+samples_per_val: 100
+samples_per_ckpt: 100_000
diff --git a/train/config/large_batch.yaml b/train/config/large_batch.yaml
index 0146582..b0af3d3 100644
--- a/train/config/large_batch.yaml
+++ b/train/config/large_batch.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - base_train_config
+  - base
   - _self_

 # Training parameters
diff --git a/train/config/peak.yaml b/train/config/peak.yaml
new file mode 100644
index 0000000..2701225
--- /dev/null
+++ b/train/config/peak.yaml
@@ -0,0 +1,16 @@
+defaults:
+  - base
+  - _self_
+
+peak_performance_test: true
+batch_size: 512  # Larger batch for throughput
+total_samples: 100_000  # Shorter run for perf testing
+profiler: true  # Enable profiling
+
+# Disable overhead
+samples_per_viz: 999999999
+samples_per_val: 999999999
+samples_per_ckpt: 999999999
+
+compile: true
+amp: true
diff --git a/train/config/config.yaml b/train/config/preload.yaml
similarity index 57%
rename from train/config/config.yaml
rename to train/config/preload.yaml
index deddc39..0e21dd5 100644
--- a/train/config/config.yaml
+++ b/train/config/preload.yaml
@@ -1,11 +1,17 @@
 defaults:
-  - base_train_config
+  - base
   - _self_

 # Training parameters
 batch_size: 128
-total_samples: 100_000
-lr: 3e-4
+data:
+  tiny_dataset_size: 256
+
+total_samples: 100_000  # don't need that many to recover color
+grad_clip: 1.0
+mask_ratio: 0.75
+lr: 1e-4
+lr_layer_decay: 1
 warmup_ratio: 0.1
 weight_decay: 0.01
 beta1: 0.9
diff --git a/train/config/sweep.yaml b/train/config/sweep.yaml
index f61db28..338eb51 100644
--- a/train/config/sweep.yaml
+++ b/train/config/sweep.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - config
+  - base
   - _self_

 # Overrides for sweep.
diff --git a/train/data.py b/train/data.py
index 9c3970e..68d0662 100644
--- a/train/data.py
+++ b/train/data.py
@@ -1,12 +1,14 @@
-from typing import Tuple
 import torch
-from torch.utils.data import DataLoader, random_split
+from torch.utils.data import DataLoader, Subset
 from torchvision import datasets, transforms
+import logging
+
+logger = logging.getLogger(__name__)

 IMAGENETTE_STATS = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}


-def create_transforms(train: bool, size: int = 224) -> transforms.Compose:
+def create_transform(train: bool, size: int = 224) -> transforms.Compose:
     """Create transform pipeline with efficient resize."""
     resize_size = int(size * 1.143)  # Slightly larger for random crops

@@ -30,51 +32,89 @@ def create_transforms(train: bool, size: int = 224) -> transforms.Compose:
     )


-def get_dataloaders(cfg) -> Tuple[DataLoader, DataLoader]:
-    """Get dataloaders with optimized settings."""
-    # Create transforms
-    train_transform = create_transforms(True, cfg.patch_size * 14)  # 16 * 14 = 224
-    val_transform = create_transforms(False, cfg.patch_size * 14)
+class TensorCycler:
+    """Provides DataLoader-like iteration over a fixed tensor."""

-    # Load dataset
-    full_dataset = datasets.ImageFolder(root=cfg.data.root, transform=train_transform)
-    val_dataset = datasets.ImageFolder(root=cfg.data.root, transform=val_transform)
+    def __init__(self, images: torch.Tensor, batch_size: int, device: str = "cuda"):
+        self.images = images
+        self.batch_size = batch_size
+        self.device = device
+        self.length = len(images)
+        assert (
+            self.batch_size <= self.length
+        ), f"Batch size ({batch_size}) cannot be larger than dataset size ({self.length})"

-    # Split datasets
-    train_size = int(len(full_dataset) * cfg.data.train_val_split)
-    val_size = len(full_dataset) - train_size
+    def __iter__(self):
+        while True:  # Infinite iteration
+            idx = torch.randperm(self.length, device=self.device)[: self.batch_size]
+            yield self.images[idx], torch.zeros(self.batch_size, device=self.device)

-    if cfg.seed is not None:
-        generator = torch.Generator().manual_seed(cfg.seed)
-    else:
-        generator = None
+    def __len__(self):
+        return self.length // self.batch_size

-    train_dataset, _ = random_split(
-        full_dataset, [train_size, val_size], generator=generator
-    )
-    _, val_dataset = random_split(
-        val_dataset, [train_size, val_size], generator=generator
-    )

-    # Create dataloaders with optimized settings
-    train_loader = DataLoader(
-        train_dataset,
-        batch_size=cfg.batch_size,
-        shuffle=True,
-        num_workers=cfg.data.num_workers,
-        pin_memory=cfg.data.pin_memory,
-        persistent_workers=True if cfg.data.num_workers > 0 else False,
-        prefetch_factor=2 if cfg.data.num_workers > 0 else None,
+def get_preprocessed_tensors(dataset, indices, device="cuda") -> torch.Tensor:
+    """Process specific indices through dataset and store results."""
+    loader = DataLoader(
+        Subset(dataset, indices),
+        batch_size=len(indices),  # Process all at once
+        num_workers=1,
+        shuffle=False,
     )
+    # Single batch processing
+    images, _ = next(iter(loader))
+    return images.to(device)

-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=cfg.batch_size,
-        shuffle=False,
-        num_workers=cfg.data.num_workers,
-        pin_memory=cfg.data.pin_memory,
-        persistent_workers=True if cfg.data.num_workers > 0 else False,
-        prefetch_factor=2 if cfg.data.num_workers > 0 else None,
+
+def get_dataloaders(cfg):
+    """Get either normal dataloaders or tensor cyclers."""
+    if cfg.data.tiny_dataset_size is None:
+        # Normal training path
+        train_transform = create_transform(True, cfg.patch_size * 14)
+        val_transform = create_transform(False, cfg.patch_size * 14)
+
+        train_dataset = datasets.ImageFolder(cfg.data.root, transform=train_transform)
+        val_dataset = datasets.ImageFolder(cfg.data.root, transform=val_transform)
+
+        logger.info(
+            f"Creating regular dataloaders with {len(train_dataset)} total images"
+        )
+
+        return map(
+            lambda ds: DataLoader(
+                ds,
+                batch_size=cfg.batch_size,
+                shuffle=True,
+                num_workers=cfg.data.num_workers,
+                pin_memory=cfg.data.pin_memory,
+                persistent_workers=cfg.data.num_workers > 0,
+                prefetch_factor=2 if cfg.data.num_workers > 0 else None,
+            ),
+            (train_dataset, val_dataset),
+        )
+
+    # Tiny dataset path
+    N = cfg.data.tiny_dataset_size
+    logger.info(f"Creating tiny dataset with {N} images each for train/val")
+
+    # Create transforms and base datasets
+    transforms_list = [
+        create_transform(is_train, cfg.patch_size * 14) for is_train in (True, False)
+    ]
+
+    datasets_list = list(
+        map(lambda t: datasets.ImageFolder(cfg.data.root, transform=t), transforms_list)
     )

-    return train_loader, val_loader
+    # Get separate indices for train and val
+    all_indices = torch.randperm(len(datasets_list[0]))
+    indices_list = [all_indices[:N].tolist(), all_indices[N : 2 * N].tolist()]
+
+    # Process images and create cyclers
+    return map(
+        lambda ds, idx: TensorCycler(
+            get_preprocessed_tensors(ds, idx, cfg.device), cfg.batch_size, cfg.device
+        ),
+        datasets_list,
+        indices_list,
+    )
diff --git a/train/main.py b/train/main.py
index b5d9c5c..e8d7344 100644
--- a/train/main.py
+++ b/train/main.py
@@ -2,7 +2,7 @@
 import time
 from pathlib import Path
 from dataclasses import dataclass
-from typing import Optional, cast
+from typing import Optional

 import hydra
 import numpy as np
@@ -14,35 +14,41 @@
 from torch.utils.data import DataLoader

 from fml.model import MAEConfig, MAELite
-from train.config import BaseTrainConfig
+from train.config import TrainingConfig
 from train.data import get_dataloaders
-from train.training import MAETrainer  # Using our refactored trainer
+from train.training import MAETrainer
 from train.opt import create_optimizer_and_scheduler

-torch.set_float32_matmul_precision("high")
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+# Register configs with Hydra
 cs = ConfigStore.instance()
-cs.store(name="base_train_config", node=BaseTrainConfig)
+cs.store(name="base_schema", node=TrainingConfig)


 @dataclass
 class TrainingResources:
+    """Container for training resources to simplify resource management."""
+
     model: MAELite
-    train_loader: DataLoader
-    val_loader: DataLoader
+    train_loader: Optional[DataLoader]
+    val_loader: Optional[DataLoader]
     writer: SummaryWriter
     profiler: Optional[torch.profiler.profile] = None

-
-def setup_paths(cfg: BaseTrainConfig) -> None:
-    for path_attr in ["log_dir", "ckpt_dir"]:
-        path = to_absolute_path(getattr(cfg, path_attr.split(".")[-1]))
-        Path(path).mkdir(parents=True, exist_ok=True)
-        setattr(cfg, path_attr.split(".")[-1], path)
+    def cleanup(self):
+        """Cleanup resources to prevent memory leaks."""
+        if self.profiler:
+            self.profiler.stop()
+        self.writer.close()


-def create_model(
-    cfg: BaseTrainConfig, device: torch.device, logger: logging.Logger
-) -> MAELite:
+def create_model(cfg: TrainingConfig, device: torch.device) -> MAELite:
+    """Create and configure the model."""
     t0 = time.perf_counter()

     model_cfg = MAEConfig(
@@ -55,18 +61,27 @@
     model = MAELite(model_cfg, device)

     if cfg.pretrained_path:
-        model.load_legacy_weights(cfg.pretrained_path)
+        path = to_absolute_path(cfg.pretrained_path)
+        if Path(path).exists():
+            logger.info(f"Loading pretrained weights from {path}")
+            model.load_legacy_weights(path)
+        else:
+            logger.warning(
+                f"Pretrained weights path {path} not found, starting from scratch"
+            )

     if cfg.compile:
         t1 = time.perf_counter()
+        logger.info("Compiling model...")
         model = torch.compile(model)
-        logger.info(f"Compilation: {time.perf_counter() - t1:.2f}s")
+        logger.info(f"Compilation took {time.perf_counter() - t1:.2f}s")

-    logger.info(f"Model creation: {time.perf_counter() - t0:.2f}s")
-    return cast(MAELite, model)
+    logger.info(f"Model creation took {time.perf_counter() - t0:.2f}s")
+    return model


 def setup_profiler(log_dir: str) -> torch.profiler.profile:
+    """Configure the PyTorch profiler."""
     return profile(
         activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
         schedule=schedule(wait=1, warmup=1, active=3),
@@ -77,23 +92,41 @@ def setup_profiler(log_dir: str) -> torch.profiler.profile:
     )


+def setup_paths(cfg: TrainingConfig) -> None:
+    """Setup and validate logging and checkpoint directories."""
+    for path_attr in ["log_dir", "ckpt_dir"]:
+        path = to_absolute_path(getattr(cfg, path_attr))
+        Path(path).mkdir(parents=True, exist_ok=True)
+        setattr(cfg, path_attr, path)
+
+
 def initialize_training(
-    cfg: BaseTrainConfig,
+    cfg: TrainingConfig,
     device: torch.device,
-    logger: logging.Logger,
 ) -> TrainingResources:
-    model = create_model(cfg, device, logger)
-
+    """Initialize all training resources."""
     t0 = time.perf_counter()
-    train_loader, val_loader = get_dataloaders(cfg)
-    logger.info(
-        f"Dataset sizes (in batches) - Train: {len(train_loader)}, Val: {len(val_loader)}"
-    )
-    logger.info(f"Dataloader setup: {time.perf_counter() - t0:.2f}s")

+    # Create model
+    model = create_model(cfg, device)
+
+    # Setup data loading
+    if cfg.peak_performance_test:
+        logger.info("🚀 Running in peak performance test mode - bypassing data loading")
+        train_loader = val_loader = None
+    else:
+        logger.info("Setting up data loaders...")
+        train_loader, val_loader = get_dataloaders(cfg)
+        logger.info(
+            f"Dataset sizes - Train: {len(train_loader)}, Val: {len(val_loader)} batches"
+        )
+
+    # Setup logging and profiling
     writer = SummaryWriter(cfg.log_dir)
     profiler = setup_profiler(cfg.log_dir) if cfg.profiler else None

+    logger.info(f"Initialization took {time.perf_counter() - t0:.2f}s")
+
     return TrainingResources(
         model=model,
         train_loader=train_loader,
@@ -103,55 +136,78 @@ def initialize_training(
     )


-@hydra.main(version_base="1.2", config_path="config", config_name="config")
-def main(cfg: BaseTrainConfig) -> None:
+@hydra.main(version_base="1.2", config_path="config", config_name="base")
+def main(cfg: TrainingConfig) -> None:
+    """Main training entry point."""
     t_start = time.perf_counter()

-    logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger(__name__)
-    logger.info("Initializing training pipeline...")
-
+    # Set random seeds for reproducibility
     if cfg.seed is not None:
         torch.manual_seed(cfg.seed)
         np.random.seed(cfg.seed)
+        logger.info(f"Set random seed to {cfg.seed}")

+    # Setup directories
     setup_paths(cfg)

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    resources = initialize_training(cfg, device, logger)
-    logger.info(f"Setup complete in {time.perf_counter() - t_start:.2f}s")
-
-    if resources.profiler:
-        resources.profiler.start()
-
-    # Create optimizer and scheduler
-    total_steps = cfg.total_samples // cfg.batch_size
-    warmup_steps = int(total_steps * cfg.warmup_ratio)
-    logger.info(
-        f"Preparing to train for {total_steps} steps, warming up over {warmup_steps}"
-    )
-    optimizer, scheduler = create_optimizer_and_scheduler(
-        logger, resources.model, cfg, total_steps
-    )
-
-    # Pass scheduler to trainer
-    trainer = MAETrainer(
-        model=resources.model,
-        train_loader=resources.train_loader,
-        val_loader=resources.val_loader,
-        optimizer=optimizer,
-        scheduler=scheduler,  # Add this
-        cfg=cfg,
-        device=device,
-        logger=logger,
-        profiler=resources.profiler,
-    )
-
-    trainer.train()
-    if resources.profiler:
-        resources.profiler.stop()
-    resources.writer.close()
+    # Select device
+    device = torch.device(cfg.device)
+    logger.info(f"Using device: {device}")
+
+    if device.type == "cuda":
+        torch.set_float32_matmul_precision("high")
+        logger.info(f"CUDA devices: {torch.cuda.device_count()}")
+        logger.info(f"CUDA capabilities: {torch.cuda.get_device_capability()}")
+
+    # Initialize training resources
+    resources = initialize_training(cfg, device)
+    logger.info(f"Setup completed in {time.perf_counter() - t_start:.2f}s")
+
+    try:
+        # Start profiler if enabled
+        if resources.profiler:
+            resources.profiler.start()
+
+        # Create optimizer and scheduler
+        total_steps = cfg.total_samples // cfg.batch_size
+        warmup_steps = int(total_steps * cfg.warmup_ratio)
+        logger.info(
+            f"Preparing for {total_steps} steps"
+            f"{', with ' + str(warmup_steps) + ' warmup steps' if warmup_steps > 0 else ''}"
+        )
+
+        optimizer, scheduler = create_optimizer_and_scheduler(
+            logger, resources.model, cfg, warmup_steps
+        )
+
+        # Create and run trainer
+        trainer = MAETrainer(
+            model=resources.model,
+            train_loader=resources.train_loader,
+            val_loader=resources.val_loader,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            cfg=cfg,
+            device=device,
+            logger=logger,
+            profiler=resources.profiler,
+        )
+
+        if cfg.peak_performance_test:
+            trainer.train_peak_performance()
+        else:
+            trainer.train()
+
+    except KeyboardInterrupt:
+        logger.info("Training interrupted by user")
+    except Exception as e:
+        logger.exception(f"Training failed with error: {str(e)}")
+        raise
+    finally:
+        # Cleanup
+        resources.cleanup()
+
+    logger.info(f"Run completed in {time.perf_counter() - t_start:.2f}s")


 if __name__ == "__main__":
diff --git a/train/training.py b/train/training.py
index 3bc4900..a5ef4fb 100644
--- a/train/training.py
+++ b/train/training.py
@@ -13,6 +13,7 @@

 from fml.model import MAELite
 from fml.utils import denorm
+from .data import TensorCycler


 @dataclass
@@ -56,6 +57,9 @@ def __init__(
         self.samples_since_viz = 0
         self.best_val_loss = float("inf")

+        self.val_vis_images = None
+        self.val_vis_batch = None
+
     @property
     def step_idx(self) -> int:
         return self.imgs_seen // self.cfg.batch_size
@@ -83,22 +87,6 @@ def load_checkpoint(self, path: Path) -> None:
         self.scheduler.load_state_dict(ckpt["scheduler"])
         self.scaler.load_state_dict(ckpt["scaler"])

-    @torch.no_grad()
-    def validation_step(self, images: torch.Tensor) -> StepOutput:
-        self.model.eval()
-        torch.manual_seed(0)  # Consistent masking for validation
-        images = images.to(self.device, non_blocking=True)
-
-        with autocast("cuda", enabled=self.cfg.amp):
-            loss, pred, mask, latent = self.model(
-                images, mask_ratio=self.cfg.mask_ratio
-            )
-            reconstructed = self.model.unpatchify(pred)
-
-        return StepOutput(
-            loss=loss.item(), reconstructed=reconstructed, mask=mask, latent=latent
-        )
-
     def training_step(self, images: torch.Tensor) -> StepOutput:
         images = images.to(self.device, non_blocking=True)

@@ -178,15 +166,41 @@ def validate(self) -> float:
         self.model.eval()
         total_loss = 0.0
         total_samples = 0
-        vis_images = next(iter(self.val_loader))[0]

+        # Get validation iterator
+        if isinstance(self.val_loader, TensorCycler):
+            num_val_batches = 50
+            validation_iter = (
+                next(iter(self.val_loader)) for _ in range(num_val_batches)
+            )
+
+            # For TensorCycler, reuse same images but allow random masks
+            if self.val_vis_images is None:
+                # Just take the first N images from our fixed tensor
+                self.val_vis_images = self.val_loader.images[
+                    : min(4, len(self.val_loader.images))
+                ]
+        else:
+            validation_iter = self.val_loader
+            num_val_batches = len(self.val_loader)
+
+            # For regular DataLoader, store first batch we see for consistent visualization
+            if self.val_vis_images is None:
+                images, _ = next(iter(self.val_loader))
+                self.val_vis_images = images[: min(4, len(images))]
+
+        # Run validation
         for batch_idx, (images, _) in enumerate(
-            tqdm(self.val_loader, desc="Validating", leave=False)
+            tqdm(validation_iter, desc="Validating", total=num_val_batches, leave=False)
         ):
             out = self.validation_step(images)

+            # Log first batch visualization
             if batch_idx == 0:
-                self.log_visuals("val", vis_images, out)
+                # For vis, run a fresh forward pass on our stored images
+                # This preserves random masking while keeping images consistent
+                vis_out = self.validation_step(self.val_vis_images)
+                self.log_visuals("val", self.val_vis_images, vis_out)

             batch_size = images.shape[0]
             total_loss += out.loss * batch_size
@@ -196,6 +210,22 @@ def validate(self) -> float:
         self.log_metrics(StepOutput(avg_loss, None, None, None), prefix="val")
         return avg_loss

+    @torch.no_grad()
+    def validation_step(self, images: torch.Tensor) -> StepOutput:
+        """Run validation step with no seed manipulation."""
+        self.model.eval()
+        images = images.to(self.device, non_blocking=True)
+
+        with autocast("cuda", enabled=self.cfg.amp):
+            loss, pred, mask, latent = self.model(
+                images, mask_ratio=self.cfg.mask_ratio
+            )
+            reconstructed = self.model.unpatchify(pred)
+
+        return StepOutput(
+            loss=loss.item(), reconstructed=reconstructed, mask=mask, latent=latent
+        )
+
     def train(self) -> None:
         self.model.train()
         val_loss = self.validate()
@@ -267,3 +297,55 @@ def train(self) -> None:
         self.logger.info(
             f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f}GB"
         )
+
+    def train_peak_performance(self):
+        """Train with constant tensor to measure peak performance."""
+        self.model.train()
+
+        # Create fixed training tensor
+        B = self.cfg.batch_size
+        H = W = self.cfg.patch_size * 14  # 224 for patch_size=16
+        images = torch.randn(B, 3, H, W, device=self.device)
+
+        # Setup timing
+        t_start = time.perf_counter()
+
+        # Core training loop
+        total_steps = self.cfg.total_samples // B
+        pbar = tqdm(total=self.cfg.total_samples, desc="Peak Training", unit="img")
+
+        for step in range(total_steps):
+            with autocast("cuda", enabled=self.cfg.amp):
+                loss, pred, mask, latent = self.model(
+                    images, mask_ratio=self.cfg.mask_ratio
+                )
+
+            self.scaler.scale(loss).backward()
+
+            if self.cfg.grad_clip > 0:
+                self.scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), self.cfg.grad_clip
+                )
+
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+            if self.scheduler:
+                self.scheduler.step()
+            self.optimizer.zero_grad(set_to_none=True)
+
+            # Update progress
+            self.imgs_seen += B
+            pbar.update(B)
+
+        torch.cuda.synchronize()
+        wall_time = time.perf_counter() - t_start
+
+        self.logger.info("Peak Performance Results:")
+        self.logger.info(f"Total steps: {total_steps}")
+        self.logger.info(f"Wall time: {wall_time:.2f}s")
+        self.logger.info(f"Steps per second: {total_steps / wall_time:.1f}")
+        self.logger.info(f"Images per second: {(total_steps * B) / wall_time:.1f}")
+        self.logger.info(
+            f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f}GB"
+        )