Added predictors and their serializers #13
Conversation
src/prsdk/data/cao_mapping.py
This is an immutable struct that we can use to store context, actions, and outcomes inside our models. This allows us to check inputs/outputs, etc. when we compare models in project-specific contexts.
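As a rough sketch, an immutable CAO struct could be a frozen dataclass; the field names here come from the diff below, but the exact definition in `src/prsdk/data/cao_mapping.py` may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CAOMapping:
    """Immutable grouping of the column names a model consumes and produces."""
    context: list[str]   # columns describing the state of the world
    actions: list[str]   # columns the decision-maker controls
    outcomes: list[str]  # columns the model predicts


cao = CAOMapping(context=["temp"], actions=["policy"], outcomes=["emissions"])
```

Because the dataclass is frozen, attempting to reassign a field raises `dataclasses.FrozenInstanceError`, which is what makes the struct safe to share between models.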
config = {
    "context": model.cao.context,
    "actions": model.cao.actions,
    "outcomes": model.cao.outcomes,
}
We store the context, actions, and outcomes in our serialization now
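Putting the hunk in context, the save path could look roughly like this; `save_config` and the stand-in model object are illustrative names, not the serializer's actual API.

```python
import json
from pathlib import Path
from types import SimpleNamespace


def save_config(model, path: Path) -> None:
    """Write config.json with the CAO fields merged in (sketch; names are assumptions)."""
    path.mkdir(parents=True, exist_ok=True)
    config = dict(model.config.items())
    config.update({
        "context": model.cao.context,
        "actions": model.cao.actions,
        "outcomes": model.cao.outcomes,
    })
    with open(path / "config.json", "w", encoding="utf-8") as file:
        json.dump(config, file)


# Stand-in model object with a config dict and a cao attribute
model = SimpleNamespace(
    config={"hidden_sizes": [4]},
    cao=SimpleNamespace(context=["c1"], actions=["a1"], outcomes=["o1"]),
)
```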
with open(path / "config.json", "r", encoding="utf-8") as file:
    config = json.load(file)
# Grab CAO out of config
cao = CAOMapping(config.pop("context"), config.pop("actions"), config.pop("outcomes"))
We also reconstruct our cao when loading
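The load side of the round trip could be sketched like this; the `pop()` calls matter because they strip the CAO keys out of the dict, leaving only the model's own kwargs behind. `CAOMapping` here is a minimal stand-in for the prsdk class, and `load_cao_and_config` is a hypothetical helper name.

```python
import json
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class CAOMapping:  # minimal stand-in for prsdk's CAOMapping
    context: list[str]
    actions: list[str]
    outcomes: list[str]


def load_cao_and_config(path: Path) -> tuple[CAOMapping, dict]:
    """Rebuild the CAO mapping from config.json; pop() leaves only the model kwargs."""
    with open(path / "config.json", "r", encoding="utf-8") as file:
        config = json.load(file)
    cao = CAOMapping(config.pop("context"), config.pop("actions"), config.pop("outcomes"))
    return cao, config
```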
# Extract CAO from config
with open(load_path / "config.json", "r", encoding="utf-8") as file:
    config = json.load(file)
cao = CAOMapping(config.pop("context"), config.pop("actions"), config.pop("outcomes"))
We reconstruct cao when loading
# Add CAO to the config
config = dict(model.config.items())
cao_dict = {"context": model.cao.context, "actions": model.cao.actions, "outcomes": model.cao.outcomes}
config.update(cao_dict)
Dump our cao into the config now
    Data is automatically standardized and the scaler is saved with the model.
    TODO: We want to be able to have custom scaling in the future.
    """
    def __init__(self, cao: CAOMapping, model_config: dict):
New cao arg to pass in to every predictor
import torch


class TorchNeuralNet(torch.nn.Module):
Renamed to TorchNeuralNet from ELUCNeuralNet because it's task-agnostic
class Predictor(ABC):
    """
    Abstract class for predictors to inherit from.
    Predictors must be able to be fit and predict on a DataFrame.
    It is up to the Predictor to keep track of the proper label to label the output DataFrame.
    """
-    def __init__(self, context: list[str], actions: list[str], outcomes: list[str]):
+    def __init__(self, cao: CAOMapping):
Instead of taking 3 manual lists, we now just pass the singular object
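The new base class shape could be sketched as follows; `CAOMapping` is again a minimal stand-in, and the abstract method bodies are placeholders for the real DataFrame-based signatures.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class CAOMapping:  # minimal stand-in for prsdk's CAOMapping
    context: list[str]
    actions: list[str]
    outcomes: list[str]


class Predictor(ABC):
    """Abstract predictor: one cao object replaces the old three list arguments."""
    def __init__(self, cao: CAOMapping):
        self.cao = cao

    @abstractmethod
    def fit(self, X_train, y_train):
        """Fit on a DataFrame of context + actions and a Series of outcomes."""

    @abstractmethod
    def predict(self, context_actions_df):
        """Return a DataFrame labeled with the outcome columns."""
```

A subclass then only forwards one object through `super().__init__(cao)` instead of threading three lists.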
    Simple abstract class for sklearn predictors.
    Keeps track of features fit on and label to predict.
    """
    def __init__(self, cao: CAOMapping, model, model_config: dict):
Takes cao now instead of 3 distinct lists
Unit tests for persistence. We don't do any actual saving, though
Same tests as before
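A disk-free persistence test could look roughly like this: it exercises the merge/pop round trip on the config dict directly instead of writing files. The test class and method names are illustrative, not the repo's actual test names.

```python
import unittest


class TestSerializerConfig(unittest.TestCase):
    """Checks the serializer's config handling without writing anything to disk."""

    def test_cao_round_trips_through_config(self):
        cao = {"context": ["c1"], "actions": ["a1"], "outcomes": ["o1"]}
        config = {"hidden_sizes": [4]}
        config.update(cao)  # mirrors the save step
        # mirrors the load step: popping the CAO keys restores the original config
        restored = {k: config.pop(k) for k in ("context", "actions", "outcomes")}
        self.assertEqual(restored, cao)
        self.assertEqual(config, {"hidden_sizes": [4]})
```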
    config["linear_skip"],
    config["dropout"])
# Set map_location to CPU to avoid issues with GPU availability
nnp.model.load_state_dict(torch.load(path / "model.pt", map_location="cpu"))
We set the map location to CPU to avoid errors if we're loading a state dict that was saved on a different device. This is technically not necessary because we move to CPU on save, but it helps with backward compatibility
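The loading pattern can be sketched in isolation with a plain `Linear` layer and an in-memory buffer standing in for `path / "model.pt"`:

```python
import io

import torch

# Save a state dict, then load it with map_location="cpu" — this remaps any
# GPU-saved tensors so loading works on machines without CUDA.
model = torch.nn.Linear(4, 2)
buffer = io.BytesIO()  # in-memory stand-in for path / "model.pt"
torch.save(model.state_dict(), buffer)
buffer.seek(0)

restored = torch.nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
```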
with open(path / "config.json", "w", encoding="utf-8") as file:
    json.dump(config, file)
# Put model on CPU before saving
model.model.to("cpu")
Move to CPU so that we don't error if we go from M1 to NVIDIA or otherwise change architectures like that
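The save side of the same idea, sketched with a hypothetical helper name and an in-memory buffer in place of a file path:

```python
import io

import torch


def save_device_agnostic(model: torch.nn.Module, buffer) -> None:
    """Move parameters to CPU before saving so the checkpoint carries no device tags."""
    model.to("cpu")  # e.g. coming from mps (M1) or cuda
    torch.save(model.state_dict(), buffer)


net = torch.nn.Linear(3, 1)
buf = io.BytesIO()
save_device_agnostic(net, buf)
buf.seek(0)
```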
…h may no longer match
…ethod to manually set model's device.
Fixed #12
Transferred over the predictors from MVP. Updated them all to hold on to a CAO mapping object since we can't hard-code that anymore.
Added some unit tests for hf persistor.