Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log trained SAEs to wandb artifacts #23

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 47 additions & 10 deletions training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
"""
Training dictionaries
"""

import json
import multiprocessing as mp
import os
Expand All @@ -17,14 +13,33 @@
from .trainers.standard import StandardTrainer


def save_checkpoint(wandb_run, model_path, config_path, name, step):
    """Upload a saved model and its config file to wandb as one artifact.

    Args:
        wandb_run: Run object whose ``log_artifact`` method receives the
            artifact.
        model_path: Path of the saved model file to attach.
        config_path: Path of the config file to attach.
        name: Name given to the artifact.
        step: Step label interpolated into the artifact description and
            the confirmation message.
    """
    checkpoint = wandb.Artifact(
        name=name, type="model", description=f"Model checkpoint at step {step}"
    )
    # Attach the model first, then the config, so both travel together.
    for file_path in (model_path, config_path):
        checkpoint.add_file(file_path)
    wandb_run.log_artifact(checkpoint)

    print(f"Model and config saved as artifact at step {step}")


def new_wandb_process(config, log_queue, entity, project):
    """Run a wandb logging loop fed by ``log_queue``.

    Starts a wandb run named ``config["wandb_name"]`` and then consumes
    messages from the queue until the ``"DONE"`` sentinel arrives, after
    which the run is finished.

    Message handling:
        * ``"DONE"`` -- stop consuming.
        * a dict with a truthy ``"artifact"`` entry -- its
          ``"artifact_data"`` dict is forwarded as keyword arguments to
          ``save_checkpoint`` together with the active run.
        * anything else -- passed to ``wandb.log`` exactly once.

    Args:
        config: Run configuration; must contain ``"wandb_name"``.
        log_queue: Queue supplying the messages described above.
        entity: wandb entity passed to ``wandb.init``.
        project: wandb project passed to ``wandb.init``.
    """
    wandb.init(entity=entity, project=project, config=config, name=config["wandb_name"])
    while True:
        # Only the queue read can time out, so it is the only guarded call.
        try:
            log = log_queue.get(timeout=1)
        except Empty:
            # Nothing queued within the timeout; poll again.
            continue

        if log == "DONE":
            break

        if isinstance(log, dict) and log.get("artifact", False):
            # Control message: upload a checkpoint instead of logging it
            # as metrics.
            save_checkpoint(wandb_run=wandb.run, **log["artifact_data"])
        else:
            wandb.log(log)
    wandb.finish()
Expand Down Expand Up @@ -88,7 +103,7 @@ def trainSAE(
run_cfg={},
):
"""
Train SAEs using the given trainers
Train SAEs using the given trainers and save them as wandb artifacts
"""
trainers = []
for config in trainer_configs:
Expand Down Expand Up @@ -141,23 +156,45 @@ def trainSAE(

# saving
if save_steps is not None and step % save_steps == 0:
for dir, trainer in zip(save_dirs, trainers):
for dir, trainer, log_queue in zip(save_dirs, trainers, log_queues):
if dir is not None:
if not os.path.exists(os.path.join(dir, "checkpoints")):
os.mkdir(os.path.join(dir, "checkpoints"))
save_path = os.path.join(dir, "checkpoints", f"ae_{step}.pt")
t.save(
trainer.ae.state_dict(),
os.path.join(dir, "checkpoints", f"ae_{step}.pt"),
save_path,
)
config_path = os.path.join(dir, "config.json")
# Send message to wandb process to save artifact
if use_wandb:
# Prepare artifact data
artifact_data = {
"model_path": save_path,
"config_path": config_path,
"name": f"{trainer.config.get('wandb_name', 'trainer')}_{step}",
"step": step,
}
log_queue.put({"artifact": True, "artifact_data": artifact_data})

# training
for trainer in trainers:
trainer.update(step, act)

# save final SAEs
for save_dir, trainer in zip(save_dirs, trainers):
for save_dir, trainer, log_queue in zip(save_dirs, trainers, log_queues):
if save_dir is not None:
t.save(trainer.ae.state_dict(), os.path.join(save_dir, "ae.pt"))
save_path = os.path.join(save_dir, "ae.pt")
t.save(trainer.ae.state_dict(), save_path)
config_path = os.path.join(save_dir, "config.json")
if use_wandb:
artifact_data = {
"model_path": save_path,
"config_path": config_path,
"name": f"{trainer.config.get('wandb_name', 'trainer')}_final",
"step": 'final',
}
log_queue.put({"artifact": True, "artifact_data": artifact_data})

# Signal wandb processes to finish
if use_wandb:
Expand Down