-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from Project-Resilience/predictors
Added predictors and their serializers
- Loading branch information
Showing
20 changed files
with
776 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,10 @@ | ||
coverage==7.6.0 | ||
flake8==7.1.0 | ||
huggingface_hub==0.24.3 | ||
joblib==1.2.0 | ||
numpy==1.23.5 | ||
pandas==1.5.3 | ||
pylint==3.2.6 | ||
pylint==3.2.6 | ||
scikit-learn==1.2.2 | ||
tensorboard==2.13.0 | ||
torch==2.3.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
""" | ||
Immutable NamedTuple for storing the context, actions, and outcomes for a given project. | ||
Note: We choose to use NamedTuple over dataclasses because NamedTuple is immutable. | ||
""" | ||
from typing import NamedTuple | ||
|
||
|
||
class CAOMapping(NamedTuple):
    """
    Groups the context, action, and outcome names that describe a project.
    Fields are positional, in the order: context, actions, outcomes.
    """
    # Names belonging to the context group.
    context: list[str]
    # Names belonging to the actions group.
    actions: list[str]
    # Names belonging to the outcomes group.
    outcomes: list[str]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
""" | ||
A simple custom PyTorch dataset is created here. This is used to keep our | ||
datasets standard between models. It is used in both Torch prescription | ||
and Neural Network training. | ||
""" | ||
import numpy as np | ||
import torch | ||
from torch.utils.data.dataset import Dataset | ||
|
||
|
||
class TorchDataset(Dataset):
    """
    Simple custom torch dataset that pairs each entry of X with the entry of y at the same index.
    :param X: data, converted to a float32 tensor.
    :param y: labels, converted to a tensor with the dtype torch infers from the input.
    :param device: device both tensors are created on. Defaults to "cpu".
    :raises ValueError: if X and y do not have the same length.
    """
    def __init__(self, X: np.ndarray, y: np.ndarray, device="cpu"):
        super().__init__()
        self.X = torch.tensor(X, dtype=torch.float32, device=device)
        self.y = torch.tensor(y, device=device)
        # Raise explicitly rather than assert: assert statements are removed when Python runs with -O,
        # which would let mismatched data through silently.
        if len(self.X) != len(self.y):
            raise ValueError("X and y must have the same length")

    def __len__(self):
        """
        Returns the number of samples, i.e. the length of X.
        """
        return len(self.X)

    def __getitem__(self, idx: int) -> tuple:
        """
        Returns the (X, y) pair stored at the given index.
        :param idx: index of the sample to fetch.
        """
        return self.X[idx], self.y[idx]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
src/prsdk/persistence/serializers/neural_network_serializer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
""" | ||
Serializer for the Neural Network Predictor class. | ||
""" | ||
import json | ||
from pathlib import Path | ||
|
||
import joblib | ||
import torch | ||
|
||
from data.cao_mapping import CAOMapping | ||
from persistence.serializers.serializer import Serializer | ||
from predictors.neural_network.torch_neural_net import TorchNeuralNet | ||
from predictors.neural_network.neural_net_predictor import NeuralNetPredictor | ||
|
||
|
||
class NeuralNetSerializer(Serializer):
    """
    Serializer for the NeuralNetPredictor.
    Saves config necessary to recreate the model, the model itself, and the scaler for the data to a folder.
    The folder contains three files: config.json, model.pt (the model's state dict), and scaler.joblib.
    """
    def save(self, model: NeuralNetPredictor, path: Path):
        """
        Saves model, config, and scaler into format for loading.
        Generates path to folder if it does not exist.
        The model's weights are written as CPU tensors; the predictor's own model is left on its current device.
        :param model: the neural network predictor to save.
        :param path: path to folder to save model files.
        :raises ValueError: if the predictor's model is None (not fitted yet).
        """
        if model.model is None:
            raise ValueError("Model not fitted yet.")
        # Coerce so plain strings work as well as Path objects, matching SKLearnSerializer.load.
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Note: we don't save the model's device, as it's not guaranteed to be available on load
        config = {
            "context": model.cao.context,
            "actions": model.cao.actions,
            "outcomes": model.cao.outcomes,
            "features": model.features,
            "label": model.label,
            "hidden_sizes": model.hidden_sizes,
            "linear_skip": model.linear_skip,
            "dropout": model.dropout,
            "epochs": model.epochs,
            "batch_size": model.batch_size,
            "optim_params": model.optim_params,
            "train_pct": model.train_pct,
            "step_lr_params": model.step_lr_params,
        }
        with open(save_path / "config.json", "w", encoding="utf-8") as file:
            json.dump(config, file)

        # Write CPU copies of the weights instead of calling model.model.to("cpu"):
        # moving the live model would change the caller's model device as a side effect of saving.
        cpu_state_dict = {
            name: value.detach().cpu() if isinstance(value, torch.Tensor) else value
            for name, value in model.model.state_dict().items()
        }
        torch.save(cpu_state_dict, save_path / "model.pt")
        joblib.dump(model.scaler, save_path / "scaler.joblib")

    def load(self, path: Path) -> NeuralNetPredictor:
        """
        Loads a model from a given folder. Creates empty model with config, then loads model state dict and scaler.
        NOTE: We don't put the model back on the device it was trained on. This has to be done manually.
        :param path: path to folder containing model files.
        :return: a NeuralNetPredictor with its model (in eval mode, weights mapped to CPU) and scaler restored.
        :raises FileNotFoundError: if the folder does not exist or any of the three model files is missing.
        """
        # Coerce so plain strings work as well as Path objects, matching SKLearnSerializer.load.
        load_path = Path(path)
        if not load_path.exists() or not load_path.is_dir():
            raise FileNotFoundError(f"Path {path} does not exist.")
        required_files = ("config.json", "model.pt", "scaler.joblib")
        if not all((load_path / file_name).exists() for file_name in required_files):
            raise FileNotFoundError("Model files not found in path.")

        # Initialize model with config
        with open(load_path / "config.json", "r", encoding="utf-8") as file:
            config = json.load(file)
        # Grab CAO out of config; the remaining keys are passed to the predictor as its config
        cao = CAOMapping(config.pop("context"), config.pop("actions"), config.pop("outcomes"))
        nnp = NeuralNetPredictor(cao, config)

        nnp.model = TorchNeuralNet(len(config["features"]),
                                   config["hidden_sizes"],
                                   config["linear_skip"],
                                   config["dropout"])
        # NOTE(review): torch.load and joblib.load both unpickle data, which can execute arbitrary code.
        # Only load folders that come from a trusted source.
        # Set map_location to CPU to avoid issues with GPU availability
        nnp.model.load_state_dict(torch.load(load_path / "model.pt", map_location="cpu"))
        nnp.model.eval()
        nnp.scaler = joblib.load(load_path / "scaler.joblib")
        return nnp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
""" | ||
Serializer for the SKLearnPredictor class. | ||
""" | ||
import json | ||
from pathlib import Path | ||
|
||
import joblib | ||
|
||
from data.cao_mapping import CAOMapping | ||
from persistence.serializers.serializer import Serializer | ||
from predictors.sklearn_predictors.sklearn_predictor import SKLearnPredictor | ||
|
||
|
||
class SKLearnSerializer(Serializer):
    """
    Serializer for the SKLearnPredictor.
    Uses joblib to save the model and json to save the config used to load it.
    The folder contains two files: config.json and model.joblib.
    """
    def save(self, model: SKLearnPredictor, path: Path):
        """
        Saves model and config into format for loading.
        Generates path to folder if it does not exist.
        :param model: the sklearn predictor to save.
        :param path: path to folder to save model files.
        """
        # Coerce so plain strings work as well as Path objects, the same way load does.
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Add CAO to a copy of the config so the predictor's own config is not modified
        config = dict(model.config.items())
        cao_dict = {"context": model.cao.context, "actions": model.cao.actions, "outcomes": model.cao.outcomes}
        config.update(cao_dict)

        with open(save_path / "config.json", "w", encoding="utf-8") as file:
            json.dump(config, file)
        joblib.dump(model.model, save_path / "model.joblib")

    def load(self, path: Path) -> "SKLearnPredictor":
        """
        Loads saved model and config from a local folder.
        :param path: path to folder to load model files from.
        :return: an SKLearnPredictor built from the saved CAO, model, and remaining config.
        :raises FileNotFoundError: if the folder does not exist or either model file is missing.
        """
        load_path = Path(path)
        if not load_path.exists() or not load_path.is_dir():
            raise FileNotFoundError(f"Path {path} does not exist.")
        required_files = ("config.json", "model.joblib")
        if not all((load_path / file_name).exists() for file_name in required_files):
            raise FileNotFoundError("Model files not found in path.")

        # Extract CAO from config; the remaining keys are passed to the predictor as its config
        with open(load_path / "config.json", "r", encoding="utf-8") as file:
            config = json.load(file)
        cao = CAOMapping(config.pop("context"), config.pop("actions"), config.pop("outcomes"))

        # NOTE(review): joblib.load unpickles data, which can execute arbitrary code.
        # Only load folders that come from a trusted source.
        model = joblib.load(load_path / "model.joblib")
        return SKLearnPredictor(cao, model, config)
Oops, something went wrong.