Skip to content

Commit

Permalink
Added model registry and moved training files into src/ folder
Browse files Browse the repository at this point in the history
  • Loading branch information
GeroVanMi committed May 5, 2024
1 parent c32f3db commit 862efdd
Show file tree
Hide file tree
Showing 16 changed files with 28 additions and 27 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Data and Model Checkpoints
data/
models/
frontend/.streamlit/secrets.toml

# W&B log files
wandb/

# Cloud authentication keys
keys/
gha-creds-*.json
.env
frontend/.streamlit/secrets.toml

# IntelliJ IDEs
.idea/
Expand Down
1 change: 1 addition & 0 deletions cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ steps:
"push",
"europe-west1-docker.pkg.dev/$PROJECT_ID/training-images/pokemon-trainer",
]
# Run Lightning Executor
- name: "gcr.io/cloud-builders/docker"
id: RunTrainingPipeline
waitFor:
Expand Down
19 changes: 0 additions & 19 deletions training/setup_gc.sh

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class Configuration:
local_dataset_path = Path("./data/")
training_bucket_name = "zhaw_algorithmic_quartet_training_images"

output_dir = "./models/ddpm-pokemon-128"
model_name = "pokemon-generator"
output_dir = f"./models/{model_name}"

push_to_hub = False # whether to upload the saved model to the HF Hub
hub_private_repo = False
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import wandb
import yaml

Expand Down Expand Up @@ -31,7 +33,7 @@ def create_run(self, experiment_name=None):
"""
# Append mode to experiment name if provided
if experiment_name:
experiment_name = f"{self.mode.upper()}_{experiment_name}"
experiment_name = experiment_name
else:
experiment_name = self.mode.upper()

Expand Down Expand Up @@ -71,6 +73,18 @@ def log_image(self, image, caption):
if self.run and self.mode != "production":
self.run.log({"image": [wandb.Image(image, caption=caption)]})

def link_model(self, model_path: str, model_name: str):
"""
Uploads a model to the W&B model registry. Only works if the given path actually
exists and the W&B run was initalized.
Args:
model_path: Path to the file that stores the trained model.
model_name: How the model should be called in the registry.
"""
if self.run and Path(model_path).exists():
self.run.link_model(path=model_path, registered_model_name=model_name)

def finish_run(self):
"""
Finishes the active W&B run.
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 6 additions & 5 deletions training/pipeline.py → training/src/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import argparse

import torch
from configurations.Configuration import Configuration
from configurations.create_config import create_config_from_arguments
Expand All @@ -25,9 +23,7 @@ def prepare_data(config: Configuration):
config.training_bucket_name, max_results=config.num_images
)

return load_dataset(
str(config.local_dataset_path.resolve()), split="train[0:6]"
)
return load_dataset(str(config.local_dataset_path.resolve()), split="train")

if not config.local_dataset_path.exists():
download_bucket_with_transfer_manager(config.training_bucket_name)
Expand Down Expand Up @@ -113,6 +109,11 @@ def transform(examples):
wandb_config,
)

wandb_config.link_model(
config.output_dir,
config.model_name,
)


if __name__ == "__main__":
config = create_config_from_arguments()
Expand Down
File renamed without changes.

0 comments on commit 862efdd

Please sign in to comment.