WIP: How far can overfitting to a few images get us? #3

Open · wants to merge 2 commits into master

3 changes: 3 additions & 0 deletions TODO.md
@@ -0,0 +1,3 @@
+- test the training: model should be able to reach a certain loss on known data, with fixed seed
+- test the ckpt loading in CI, it's cheap. do it on CPU if possible
+- enforce ruff/pyright compliance on main branch?
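
A sketch of what the first TODO item could look like in CI. `train.run` is a hypothetical entry point that returns the final training loss, and the 0.5 threshold is a placeholder to calibrate against a known-good run:

```python
# Hypothetical smoke test for "reach a certain loss on known data, with fixed seed".
# `train.run` and the 0.5 bound are assumptions, not part of this PR.
import torch
from hydra import compose, initialize

import train  # assumed to expose run(cfg) -> float (final train loss)


def test_training_reaches_known_loss():
    with initialize(version_base=None, config_path="train/config"):
        cfg = compose(config_name="few_images", overrides=["seed=0"])
    torch.manual_seed(cfg.seed)
    final_loss = train.run(cfg)
    assert final_loss < 0.5  # calibrate once a known-good run exists
```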
38 changes: 24 additions & 14 deletions train/config.py
@@ -1,3 +1,4 @@
+import torch
 from dataclasses import dataclass, field
 import multiprocessing as mp
 from typing import Optional
@@ -7,48 +8,57 @@
 @dataclass
 class DataConfig:
     root: str = str(Path.home() / "datasets/imagenette")
+    tiny_dataset_size: Optional[int] = None  # If set, use N images each for train/val
     train_val_split: float = 0.9
     num_workers: int = field(default_factory=lambda: min(mp.cpu_count(), 4))
     pin_memory: bool = False


 @dataclass
-class BaseTrainConfig:
-    # Model
+class TrainingConfig:
+    # Core behavior flags
+    peak_performance_test: bool = False  # New flag for peak performance testing
+    compile: bool = True
+    amp: bool = True
+    profiler: bool = False
+
+    # Model architecture
     patch_size: int = 16
     embed_dim: int = 192  # Tiny variant
     decoder_embed_dim: int = 96
-    compile: bool = True

-    # Training
+    # Training parameters
     batch_size: int = 256
     total_samples: int = 1_000_000
-    amp: bool = True
     grad_clip: float = 1.0
     mask_ratio: float = 0.75
-    lr: float = 1.5e-4
-    seed: Optional[int] = None  # override for reproducibility
-    profiler: bool = False

-    lr_layer_decay: float = (
-        1  # layerwise LR decay. qualitatively different behavior if <1
-    )
+    # Optimization parameters
+    lr: float = 1.5e-4
+    lr_layer_decay: float = 1.0  # layerwise LR decay
     warmup_ratio: float = 0.1
     weight_decay: float = 0.05
     beta1: float = 0.9
     beta2: float = 0.999

-    # Data
+    # Random seed
+    seed: Optional[int] = None
+
+    # Data configuration
     data: DataConfig = field(default_factory=DataConfig)

-    # Logging
+    device: str = field(
+        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
+    )
+
+    # Logging and checkpointing
     log_dir: str = "runs"
     ckpt_dir: str = "checkpoints"
     samples_per_viz: int = 1000
     samples_per_val: int = 10000
     samples_per_ckpt: int = 50000

-    # From original MAE-Lite repo.
+    # Pretrained model
     pretrained_path: Optional[str] = "ckpt/mae_tiny_400e.pth.tar"

     def __post_init__(self):
38 changes: 38 additions & 0 deletions train/config/base.yaml
@@ -0,0 +1,38 @@
+defaults:
+  - base_schema  # structured-config schema for validation
+  - _self_  # values in this file override the schema defaults
+
+data:
+  root: ${oc.env:HOME}/datasets/imagenette
+  train_val_split: 0.9
+  num_workers: 4
+  pin_memory: false
+
+compile: true
+amp: true
+profiler: false
+
+patch_size: 16
+embed_dim: 192
+decoder_embed_dim: 96
+
+batch_size: 256
+total_samples: 1_000_000
+grad_clip: 1.0
+mask_ratio: 0.75
+
+lr: 1.5e-4
+lr_layer_decay: 1.0
+warmup_ratio: 0.1
+weight_decay: 0.05
+beta1: 0.9
+beta2: 0.999
+
+log_dir: runs
+ckpt_dir: checkpoints
+samples_per_viz: 1000
+samples_per_val: 10000
+samples_per_ckpt: 50000
+
+peak_performance_test: false
+seed: 42
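
For the `base_schema` entry in the defaults list to resolve, the `TrainingConfig` dataclass has to be registered with Hydra's ConfigStore. A minimal sketch, assuming the registration lives next to the dataclass (the exact location is not shown in this diff):

```python
# Sketch: register the dataclass schema that base.yaml's `base_schema` refers to.
from hydra.core.config_store import ConfigStore

from train.config import TrainingConfig

cs = ConfigStore.instance()
cs.store(name="base_schema", node=TrainingConfig)
```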
2 changes: 1 addition & 1 deletion train/config/data_bound.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - base_train_config # Reference the schema
+  - base
   - _self_

 # Training parameters
23 changes: 23 additions & 0 deletions train/config/few_images.yaml
@@ -0,0 +1,23 @@
+defaults:
+  - base
+  - _self_
+
+# Training parameters
+batch_size: 4
+data:
+  tiny_dataset_size: 4
+
+total_samples: 2000
+grad_clip: 0.5
+mask_ratio: 0.75
+lr: 8e-5
+lr_layer_decay: 1
+warmup_ratio: 0.5
+weight_decay: 1e-4
+beta1: 0.95
+beta2: 0.9995
+
+# Logging and checkpointing
+samples_per_viz: 100
+samples_per_val: 100
+samples_per_ckpt: 100_000
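
For scale: batch_size 4 with total_samples 2000 gives 500 optimizer steps, and warmup_ratio 0.5 means the learning rate ramps for the first 250 of them, so half of this run is warmup. samples_per_ckpt of 100_000 exceeds total_samples, which effectively disables checkpointing for this run.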
2 changes: 1 addition & 1 deletion train/config/large_batch.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - base_train_config
+  - base
   - _self_

 # Training parameters
16 changes: 16 additions & 0 deletions train/config/peak.yaml
@@ -0,0 +1,16 @@
+defaults:
+  - base
+  - _self_
+
+peak_performance_test: true
+batch_size: 512  # Larger batch for throughput
+total_samples: 100_000  # Shorter run for perf testing
+profiler: true  # Enable profiling
+
+# Disable overhead
+samples_per_viz: 999999999
+samples_per_val: 999999999
+samples_per_ckpt: 999999999
+
+compile: true
+amp: true
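
Assuming a standard Hydra entry point (the script name and `main` function are assumptions, not shown in this diff), any of these configs would be selected at launch with a config-name override, e.g. `python train/train.py --config-name peak`; a sketch:

```python
# Sketch of the assumed Hydra entry point; the real script would run
# training instead of printing the resolved config.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="config", config_name="base")
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))  # under peak.yaml: profiler: true, batch_size: 512


if __name__ == "__main__":
    main()
```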
12 changes: 9 additions & 3 deletions train/config/config.yaml → train/config/preload.yaml
@@ -1,11 +1,17 @@
 defaults:
-  - base_train_config
+  - base
   - _self_

 # Training parameters
 batch_size: 128
-total_samples: 100_000
-lr: 3e-4
+data:
+  tiny_dataset_size: 256
+
+total_samples: 100_000  # don't need that many to recover color
+grad_clip: 1.0
+mask_ratio: 0.75
+lr: 1e-4
+lr_layer_decay: 1
 warmup_ratio: 0.1
 weight_decay: 0.01
 beta1: 0.9
2 changes: 1 addition & 1 deletion train/config/sweep.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - config
+  - base
   - _self_

 # Overrides for sweep.
124 changes: 82 additions & 42 deletions train/data.py
@@ -1,12 +1,14 @@
-from typing import Tuple
 import torch
-from torch.utils.data import DataLoader, random_split
+from torch.utils.data import DataLoader, Subset
 from torchvision import datasets, transforms
+import logging
+
+logger = logging.getLogger(__name__)

 IMAGENETTE_STATS = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}


-def create_transforms(train: bool, size: int = 224) -> transforms.Compose:
+def create_transform(train: bool, size: int = 224) -> transforms.Compose:
     """Create transform pipeline with efficient resize."""
     resize_size = int(size * 1.143)  # Slightly larger for random crops

@@ -30,51 +32,89 @@ def create_transforms(train: bool, size: int = 224) -> transforms.Compose:
     )


-def get_dataloaders(cfg) -> Tuple[DataLoader, DataLoader]:
-    """Get dataloaders with optimized settings."""
-    # Create transforms
-    train_transform = create_transforms(True, cfg.patch_size * 14)  # 16 * 14 = 224
-    val_transform = create_transforms(False, cfg.patch_size * 14)
+class TensorCycler:
+    """Provides DataLoader-like iteration over a fixed tensor."""

-    # Load dataset
-    full_dataset = datasets.ImageFolder(root=cfg.data.root, transform=train_transform)
-    val_dataset = datasets.ImageFolder(root=cfg.data.root, transform=val_transform)
+    def __init__(self, images: torch.Tensor, batch_size: int, device: str = "cuda"):
+        self.images = images
+        self.batch_size = batch_size
+        self.device = device
+        self.length = len(images)
+        assert (
+            self.batch_size <= self.length
+        ), f"Batch size ({batch_size}) cannot be larger than dataset size ({self.length})"

-    # Split datasets
-    train_size = int(len(full_dataset) * cfg.data.train_val_split)
-    val_size = len(full_dataset) - train_size
+    def __iter__(self):
+        while True:  # Infinite iteration
+            idx = torch.randperm(self.length, device=self.device)[: self.batch_size]
+            yield self.images[idx], torch.zeros(self.batch_size, device=self.device)

-    if cfg.seed is not None:
-        generator = torch.Generator().manual_seed(cfg.seed)
-    else:
-        generator = None
+    def __len__(self):
+        return self.length // self.batch_size

-    train_dataset, _ = random_split(
-        full_dataset, [train_size, val_size], generator=generator
-    )
-    _, val_dataset = random_split(
-        val_dataset, [train_size, val_size], generator=generator
-    )

-    # Create dataloaders with optimized settings
-    train_loader = DataLoader(
-        train_dataset,
-        batch_size=cfg.batch_size,
-        shuffle=True,
-        num_workers=cfg.data.num_workers,
-        pin_memory=cfg.data.pin_memory,
-        persistent_workers=True if cfg.data.num_workers > 0 else False,
-        prefetch_factor=2 if cfg.data.num_workers > 0 else None,
-    )
+def get_preprocessed_tensors(dataset, indices, device="cuda") -> torch.Tensor:
+    """Process specific indices through dataset and store results."""
+    loader = DataLoader(
+        Subset(dataset, indices),
+        batch_size=len(indices),  # Process all at once
+        num_workers=1,
+        shuffle=False,
+    )
+    # Single batch processing
+    images, _ = next(iter(loader))
+    return images.to(device)

-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=cfg.batch_size,
-        shuffle=False,
-        num_workers=cfg.data.num_workers,
-        pin_memory=cfg.data.pin_memory,
-        persistent_workers=True if cfg.data.num_workers > 0 else False,
-        prefetch_factor=2 if cfg.data.num_workers > 0 else None,
-    )

-    return train_loader, val_loader
+def get_dataloaders(cfg):
+    """Get either normal dataloaders or tensor cyclers."""
+    if cfg.data.tiny_dataset_size is None:
+        # Normal training path
+        train_transform = create_transform(True, cfg.patch_size * 14)
+        val_transform = create_transform(False, cfg.patch_size * 14)
+
+        train_dataset = datasets.ImageFolder(cfg.data.root, transform=train_transform)
+        val_dataset = datasets.ImageFolder(cfg.data.root, transform=val_transform)
+
+        logger.info(
+            f"Creating regular dataloaders with {len(train_dataset)} total images"
+        )
+
+        return map(
+            lambda ds: DataLoader(
+                ds,
+                batch_size=cfg.batch_size,
+                shuffle=True,
+                num_workers=cfg.data.num_workers,
+                pin_memory=cfg.data.pin_memory,
+                persistent_workers=cfg.data.num_workers > 0,
+                prefetch_factor=2 if cfg.data.num_workers > 0 else None,
+            ),
+            (train_dataset, val_dataset),
+        )
+
+    # Tiny dataset path
+    N = cfg.data.tiny_dataset_size
+    logger.info(f"Creating tiny dataset with {N} images each for train/val")
+
+    # Create transforms and base datasets
+    transforms_list = [
+        create_transform(is_train, cfg.patch_size * 14) for is_train in (True, False)
+    ]
+
+    datasets_list = list(
+        map(lambda t: datasets.ImageFolder(cfg.data.root, transform=t), transforms_list)
+    )
+
+    # Get separate indices for train and val
+    all_indices = torch.randperm(len(datasets_list[0]))
+    indices_list = [all_indices[:N].tolist(), all_indices[N : 2 * N].tolist()]
+
+    # Process images and create cyclers
+    return map(
+        lambda ds, idx: TensorCycler(
+            get_preprocessed_tensors(ds, idx, cfg.device), cfg.batch_size, cfg.device
+        ),
+        datasets_list,
+        indices_list,
+    )
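
A quick usage sketch of the tiny-dataset path (illustrative only; assumes Imagenette is available at the configured root and that `TrainingConfig` builds with its defaults):

```python
# Illustrative: with tiny_dataset_size set, get_dataloaders yields two
# TensorCyclers that sample batches from preprocessed, device-resident tensors.
from train.config import DataConfig, TrainingConfig
from train.data import get_dataloaders

cfg = TrainingConfig(batch_size=4, data=DataConfig(tiny_dataset_size=4))
train_loader, val_loader = get_dataloaders(cfg)  # a map of two TensorCyclers

images, dummy_labels = next(iter(train_loader))  # labels are zeros in this path
assert images.shape[0] == cfg.batch_size
```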