Skip to content

Commit

Permalink
Fix broken pipe issue with upload
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney committed Jan 5, 2024
1 parent ee4171c commit 04ec274
Showing 1 changed file with 33 additions and 48 deletions.
81 changes: 33 additions & 48 deletions sparse_autoencoder/train/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:

def setup_wandb() -> RuntimeHyperparameters:
    """Start a WandB run and return its configuration as a plain dictionary.

    Returns:
        The contents of ``wandb.config``, copied into a ``dict``.
    """
    # Clear the module-level run reference before initialising. The original author notes that
    # this works around a broken pipe bug in wandb.
    wandb.run = None
    wandb.init()

    run_config = wandb.config
    return dict(run_config)  # type: ignore

Expand Down Expand Up @@ -296,58 +297,42 @@ def run_training_pipeline(

def train() -> None:
"""Train the sparse autoencoder using the hyperparameters from the WandB sweep."""
# NOTE(review): this listing looks like a flattened diff rather than a single function body.
# The `try:`/`except` variant and the flat variant that follows it appear to be the removed and
# added sides of the same change, interleaved without markers and with indentation stripped.
# Confirm against the repository before treating these lines as one executable body.
try:
# Set up WandB
hyperparameters = setup_wandb()
# The run name is read from the module-level `wandb.run` right after `setup_wandb()` returns.
run_name: str = wandb.run.name # type: ignore

# Set up the device for training
device = get_device()

# Set up the source model
source_model = setup_source_model(hyperparameters)

# Set up the autoencoder
autoencoder = setup_autoencoder(hyperparameters, device)

# Set up the loss function
loss_function = setup_loss_function(hyperparameters)

# Set up the optimizer
optimizer = setup_optimizer(autoencoder, hyperparameters)

# Set up the activation resampler
activation_resampler = setup_activation_resampler(hyperparameters)

# Set up the source data
source_data = setup_source_data(hyperparameters)

# Run the training pipeline
run_training_pipeline(
hyperparameters=hyperparameters,
source_model=source_model,
autoencoder=autoencoder,
loss=loss_function,
optimizer=optimizer,
activation_resampler=activation_resampler,
source_data=source_data,
run_name=run_name,
)
# Set up WandB
hyperparameters = setup_wandb()
# The run name is read from the module-level `wandb.run` right after `setup_wandb()` returns.
run_name: str = wandb.run.name # type: ignore

# Explicit exception catching needed to show the stack trace in wandb sweeps
except Exception as _exception: # noqa: BLE001
# Format the stack trace (the argument caps the trace at 50 stack entries)
full_stack_trace = traceback.format_exc(50)
# Set up the device for training
device = get_device()

# Keep only the trace lines that do not mention "wandb/sdk"
stack_trace = "\n".join(
line for line in full_stack_trace.splitlines() if "wandb/sdk" not in line
)
# Set up the source model
source_model = setup_source_model(hyperparameters)

# Set up the autoencoder
autoencoder = setup_autoencoder(hyperparameters, device)

# Set up the loss function
loss_function = setup_loss_function(hyperparameters)

# Set up the optimizer
optimizer = setup_optimizer(autoencoder, hyperparameters)

# Set up the activation resampler
activation_resampler = setup_activation_resampler(hyperparameters)

# Also print the stack trace to stderr
print(stack_trace, file=sys.stderr) # noqa: T201
# Set up the source data
source_data = setup_source_data(hyperparameters)

# Exit current run with an error code
sys.exit(1)
# Run the training pipeline
run_training_pipeline(
hyperparameters=hyperparameters,
source_model=source_model,
autoencoder=autoencoder,
loss=loss_function,
optimizer=optimizer,
activation_resampler=activation_resampler,
source_data=source_data,
run_name=run_name,
)


def sweep(sweep_config: SweepConfig | None = None, sweep_id: str | None = None) -> None:
Expand Down

0 comments on commit 04ec274

Please sign in to comment.