Skip to content

Commit

Permalink
Merge pull request #8 from Project-Resilience/huggingface
Browse files Browse the repository at this point in the history
Added HuggingFace persistor implementation
  • Loading branch information
danyoungday authored Jul 30, 2024
2 parents 368bcb0 + b964cdf commit 54305ac
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ suggestion-mode=yes
disable=

# Default set of "always good" names
good-names=_
good-names=_,X_train,X_test

recursive=y
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
coverage==7.6.0
flake8==7.1.0
huggingface_hub==0.24.3
pandas==1.5.3
pylint==3.2.6
72 changes: 72 additions & 0 deletions src/prsdk/persistence/persistors/hf_persistor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Persistor for models to and from HuggingFace repo.
"""
from pathlib import Path

from huggingface_hub import HfApi, snapshot_download

from persistence.persistors.persistor import Persistor


class HuggingFacePersistor(Persistor):
"""
Persists models to and from HuggingFace repo.
"""
def write_readme(self, model_path: str):
"""
Writes readme to model save path to upload.
TODO: Need to add more info to the readme and make it a proper template.
"""
model_path = Path(model_path)
with open(model_path / "README.md", "w", encoding="utf-8") as file:
file.write("This is a demo model created for Project Resilience")

def persist(self, model, model_path: Path, repo_id: str, **persistence_args):
"""
Serializes the model to a local path using the file_serializer,
then uploads the model to a HuggingFace repo.
"""
# Save model and write readme
self.serializer.save(model, model_path)
self.write_readme(model_path)

# Get token if it exists
token = persistence_args.get("token", None)

api = HfApi()
# Create repo if it doesn't exist
api.create_repo(
repo_id=repo_id,
repo_type="model",
exist_ok=True,
token=token
)

# Upload model to repo
api.upload_folder(
folder_path=model_path,
repo_id=repo_id,
repo_type="model",
token=token
)

def from_pretrained(self, path_or_url: str, **hf_args):
"""
Loads a model from a HuggingFace repo pointed to by path_or_url.
Defaults to downloading to the HuggingFace cache directory. If you want to download to a different directory,
pass the local_dir argument in hf_args.
:param path_or_url: path to the model or url to the huggingface repo.
:param hf_args: arguments to pass to the snapshot_download function from huggingface.
"""
path = Path(path_or_url)
if path.exists() and path.is_dir():
return self.serializer.load(path)

url_path = path_or_url.replace("/", "--")
local_dir = hf_args.get("local_dir", f"~/.cache/huggingface/project-resilience/{url_path}")

if not Path(local_dir).exists() or not Path(local_dir).is_dir():
hf_args["local_dir"] = local_dir
snapshot_download(repo_id=path_or_url, **hf_args)

return self.serializer.load(Path(local_dir))
2 changes: 2 additions & 0 deletions src/prsdk/prescriptors/prescriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd


# pylint: disable=too-few-public-methods
class Prescriptor(ABC):
"""
Abstract class for prescriptors to allow us to experiment with different implementations.
Expand All @@ -24,3 +25,4 @@ def prescribe(self, context_df: pd.DataFrame) -> pd.DataFrame:
:return: A dataframe containing the context and the prescribed actions.
"""
raise NotImplementedError
# pylint: enable=too-few-public-methods

0 comments on commit 54305ac

Please sign in to comment.