Skip to content

Commit

Permalink
Merge pull request #6 from Project-Resilience/abc
Browse files Browse the repository at this point in the history
Added abstract classes for predictor/prescriptor persistor/serializer
  • Loading branch information
danyoungday authored Jul 29, 2024
2 parents d2f1930 + 59e18a5 commit 368bcb0
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 21 deletions.
File renamed without changes.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
dist/
*.egg-info/
__pycache__/
__pycache__/
.coverage
.vscode
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
coverage==7.6.0
flake8==7.1.0
pandas==1.5.3
pylint==3.2.6
10 changes: 0 additions & 10 deletions src/prsdk/dummy.py

This file was deleted.

38 changes: 38 additions & 0 deletions src/prsdk/persistence/persistors/persistor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Persistor abstract class. Wraps a serializer and provides an interface for persisting models
(ex to HuggingFace) and loading models from a persistence location.
"""
from pathlib import Path

from abc import ABC, abstractmethod

from persistence.serializers.serializer import Serializer


class Persistor(ABC):
"""
Abstract class for persistors to inherit from.
Wraps around a serializer to cache the persisted models onto disk before loading them.
"""
def __init__(self, serializer: Serializer):
self.serializer = serializer

@abstractmethod
def persist(self, model, model_path: Path, repo_id: str, **persistence_args):
"""
Serializes a model using the serializer, then uploads the model to a persistence location.
:param model: The python object model to persist.
:param model_path: The path on disk to save the model to before persisting it.
:param repo_id: The ID used to point to the model in whatever method we use to persist it.
:param persistence_args: Additional arguments to pass to the persistence method.
"""
raise NotImplementedError("Persisting not implemented")

@abstractmethod
def from_pretrained(self, path_or_url: str, **persistence_args):
"""
Loads a model from where it was persisted from.
:param path_or_url: The path or URL to load the model from.
:param persistence_args: Additional arguments to pass to the loading method.
"""
raise NotImplementedError("Loading not implemented")
30 changes: 30 additions & 0 deletions src/prsdk/persistence/serializers/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Abstract class responsible for defining the interface of the serializer classes.
"""
from abc import ABC, abstractmethod
from pathlib import Path


class Serializer(ABC):
"""
Abstract class responsible for saving and loading predictor/prescriptor models locally.
Save and load should be compatible with each other but don't necessarily have to be the same as other models.
Save should take an object and save it to a path.
Load should take a path and return an object.
"""
@abstractmethod
def save(self, model, path: Path) -> None:
"""
Saves a model to disk.
:param model: The model as a python object to save.
:param path: The path to save the model to.
"""
raise NotImplementedError("Saving not implemented")

@abstractmethod
def load(self, path: Path):
"""
Takes a path and returns a model.
:param path: The path to load the model from.
"""
raise NotImplementedError("Loading not implemented")
46 changes: 46 additions & 0 deletions src/prsdk/predictors/predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Abstract class for predictors to inherit from.
"""
from abc import ABC, abstractmethod

import pandas as pd


class Predictor(ABC):
"""
Abstract class for predictors to inherit from.
Predictors must be able to be fit and predict on a DataFrame.
It is up to the Predictor to keep track of the proper label to label the output DataFrame.
"""
def __init__(self, context: list[str], actions: list[str], outcomes: list[str]):
"""
Initializes the Predictor with the context, actions, and outcomes.
:param context: list of context columns
:param actions: list of action columns
:param outcomes: list of outcome columns
"""
self.context = context
self.actions = actions
self.outcomes = outcomes

@abstractmethod
def fit(self, X_train: pd.DataFrame, y_train: pd.Series):
"""
Fits the model to the training data.
:param X_train: DataFrame with input data:
The input data consists of a DataFrame with context and actions.
It is up to the model to decide which columns to use.
:param y_train: pandas Series with target data.
"""
raise NotImplementedError

@abstractmethod
def predict(self, context_actions_df: pd.DataFrame) -> pd.DataFrame:
"""
Creates a DataFrame with predictions for the input DataFrame.
The Predictor model is expected to keep track of the label so that it can label the output
DataFrame properly. Additionally, the output DataFrame must have the same index as the input DataFrame.
:param context_actions_df: DataFrame with context and actions input data.
:return: DataFrame with predictions
"""
raise NotImplementedError
26 changes: 26 additions & 0 deletions src/prsdk/prescriptors/prescriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Abstract prescriptor class to be implemented.
"""
from abc import ABC, abstractmethod

import pandas as pd


class Prescriptor(ABC):
"""
Abstract class for prescriptors to allow us to experiment with different implementations.
"""
def __init__(self, context: list[str], actions: list[str]):
# We keep track of the context and actions to ensure that the prescriptor is compatible with the environment.
self.context = context
self.actions = actions

@abstractmethod
def prescribe(self, context_df: pd.DataFrame) -> pd.DataFrame:
"""
Takes in a context dataframe and prescribes actions.
Outputs a concatenation of the context and actions.
:param context_df: A dataframe containing rows of context data.
:return: A dataframe containing the context and the prescribed actions.
"""
raise NotImplementedError
17 changes: 7 additions & 10 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
"""
Dummy test for the dummy module.
Dummy test to test Github actions.
"""
import unittest

from src.prsdk.dummy import compute_percent_change


class TestDummy(unittest.TestCase):
"""
Tests for the dummy module.
A fake test that always returns true.
"""
def test_pct_change(self):
def test_dummy(self):
"""
Tests the compute_percent_change function.
It should return the input divided by 100.
A test that always returns true.
"""
self.assertEqual(compute_percent_change(100), 1.0)
self.assertEqual(compute_percent_change(50), 0.5)
self.assertEqual(compute_percent_change(0), 0.0)
# pylint: disable=redundant-unittest-assert
self.assertTrue(True)
# pylint: enable=redundant-unittest-assert

0 comments on commit 368bcb0

Please sign in to comment.