-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from Project-Resilience/abc
Added abstract classes for predictor/prescriptor persistor/serializer
- Loading branch information
Showing
9 changed files
with
152 additions
and
21 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
dist/ | ||
*.egg-info/ | ||
__pycache__/ | ||
__pycache__/ | ||
.coverage | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
coverage==7.6.0 | ||
flake8==7.1.0 | ||
pandas==1.5.3 | ||
pylint==3.2.6 |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Persistor abstract class. Wraps a serializer and provides an interface for persisting models | ||
(ex to HuggingFace) and loading models from a persistence location. | ||
""" | ||
from pathlib import Path | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from persistence.serializers.serializer import Serializer | ||
|
||
|
||
class Persistor(ABC): | ||
""" | ||
Abstract class for persistors to inherit from. | ||
Wraps around a serializer to cache the persisted models onto disk before loading them. | ||
""" | ||
def __init__(self, serializer: Serializer): | ||
self.serializer = serializer | ||
|
||
@abstractmethod | ||
def persist(self, model, model_path: Path, repo_id: str, **persistence_args): | ||
""" | ||
Serializes a model using the serializer, then uploads the model to a persistence location. | ||
:param model: The python object model to persist. | ||
:param model_path: The path on disk to save the model to before persisting it. | ||
:param repo_id: The ID used to point to the model in whatever method we use to persist it. | ||
:param persistence_args: Additional arguments to pass to the persistence method. | ||
""" | ||
raise NotImplementedError("Persisting not implemented") | ||
|
||
@abstractmethod | ||
def from_pretrained(self, path_or_url: str, **persistence_args): | ||
""" | ||
Loads a model from where it was persisted from. | ||
:param path_or_url: The path or URL to load the model from. | ||
:param persistence_args: Additional arguments to pass to the loading method. | ||
""" | ||
raise NotImplementedError("Loading not implemented") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
""" | ||
Abstract class responsible for defining the interface of the serializer classes. | ||
""" | ||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
|
||
|
||
class Serializer(ABC): | ||
""" | ||
Abstract class responsible for saving and loading predictor/prescriptor models locally. | ||
Save and load should be compatible with each other but don't necessarily have to be the same as other models. | ||
Save should take an object and save it to a path. | ||
Load should take a path and return an object. | ||
""" | ||
@abstractmethod | ||
def save(self, model, path: Path) -> None: | ||
""" | ||
Saves a model to disk. | ||
:param model: The model as a python object to save. | ||
:param path: The path to save the model to. | ||
""" | ||
raise NotImplementedError("Saving not implemented") | ||
|
||
@abstractmethod | ||
def load(self, path: Path): | ||
""" | ||
Takes a path and returns a model. | ||
:param path: The path to load the model from. | ||
""" | ||
raise NotImplementedError("Loading not implemented") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
""" | ||
Abstract class for predictors to inherit from. | ||
""" | ||
from abc import ABC, abstractmethod | ||
|
||
import pandas as pd | ||
|
||
|
||
class Predictor(ABC): | ||
""" | ||
Abstract class for predictors to inherit from. | ||
Predictors must be able to be fit and predict on a DataFrame. | ||
It is up to the Predictor to keep track of the proper label to label the output DataFrame. | ||
""" | ||
def __init__(self, context: list[str], actions: list[str], outcomes: list[str]): | ||
""" | ||
Initializes the Predictor with the context, actions, and outcomes. | ||
:param context: list of context columns | ||
:param actions: list of action columns | ||
:param outcomes: list of outcome columns | ||
""" | ||
self.context = context | ||
self.actions = actions | ||
self.outcomes = outcomes | ||
|
||
@abstractmethod | ||
def fit(self, X_train: pd.DataFrame, y_train: pd.Series): | ||
""" | ||
Fits the model to the training data. | ||
:param X_train: DataFrame with input data: | ||
The input data consists of a DataFrame with context and actions. | ||
It is up to the model to decide which columns to use. | ||
:param y_train: pandas Series with target data. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def predict(self, context_actions_df: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Creates a DataFrame with predictions for the input DataFrame. | ||
The Predictor model is expected to keep track of the label so that it can label the output | ||
DataFrame properly. Additionally, the output DataFrame must have the same index as the input DataFrame. | ||
:param context_actions_df: DataFrame with context and actions input data. | ||
:return: DataFrame with predictions | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
Abstract prescriptor class to be implemented. | ||
""" | ||
from abc import ABC, abstractmethod | ||
|
||
import pandas as pd | ||
|
||
|
||
class Prescriptor(ABC): | ||
""" | ||
Abstract class for prescriptors to allow us to experiment with different implementations. | ||
""" | ||
def __init__(self, context: list[str], actions: list[str]): | ||
# We keep track of the context and actions to ensure that the prescriptor is compatible with the environment. | ||
self.context = context | ||
self.actions = actions | ||
|
||
@abstractmethod | ||
def prescribe(self, context_df: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Takes in a context dataframe and prescribes actions. | ||
Outputs a concatenation of the context and actions. | ||
:param context_df: A dataframe containing rows of context data. | ||
:return: A dataframe containing the context and the prescribed actions. | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,17 @@ | ||
""" | ||
Dummy test for the dummy module. | ||
Dummy test to test Github actions. | ||
""" | ||
import unittest | ||
|
||
from src.prsdk.dummy import compute_percent_change | ||
|
||
|
||
class TestDummy(unittest.TestCase): | ||
""" | ||
Tests for the dummy module. | ||
A fake test that always returns true. | ||
""" | ||
def test_pct_change(self): | ||
def test_dummy(self): | ||
""" | ||
Tests the compute_percent_change function. | ||
It should return the input divided by 100. | ||
A test that always returns true. | ||
""" | ||
self.assertEqual(compute_percent_change(100), 1.0) | ||
self.assertEqual(compute_percent_change(50), 0.5) | ||
self.assertEqual(compute_percent_change(0), 0.0) | ||
# pylint: disable=redundant-unittest-assert | ||
self.assertTrue(True) | ||
# pylint: enable=redundant-unittest-assert |