Skip to content

Commit

Permalink
Support loading repo artifact from cloned directory
Browse files Browse the repository at this point in the history
  • Loading branch information
Innixma committed Nov 1, 2024
1 parent b9380dc commit 3b0a932
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
3 changes: 3 additions & 0 deletions tabrepo/repository/evaluation_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,14 @@ def from_dir(
cls,
path: str,
prediction_format: Literal["memmap", "memopt", "mem"] = "memmap",
update_relative_path: bool = True,
) -> Self:
from tabrepo.contexts.context import BenchmarkContext

path_context = str(Path(path) / "context.json")
context = BenchmarkContext.from_json(path=path_context)
if update_relative_path:
context.benchmark_paths.relative_path = str(Path(path))

repo = context.load_repo(prediction_format=prediction_format)
return repo
Expand Down
45 changes: 44 additions & 1 deletion tst/test_repository.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy

import shutil
from typing import Callable

import numpy as np
Expand Down Expand Up @@ -305,6 +305,49 @@ def test_repository_save_load():
verify_equivalent_repository(repo1=repo_float64, repo2=repo_loaded_float64, verify_ensemble=True, exact=False)


def test_repository_save_load_with_moving_files():
    """Test that repo save and load work when moving files to different directories.

    The repository is saved to one directory, loaded in every prediction
    format, copied to a second directory and loaded from there. The test then
    deletes the original and the copy in turn to check which loaded
    repositories still depend on files on disk.

    All files live inside a fresh temporary directory so the test never
    touches (or deletes) a pre-existing ``tmp_repo`` / ``tmp_repo_copy`` in the
    current working directory, and nothing is left behind if an assertion
    fails partway through.
    """
    # Local imports keep this test self-contained with respect to the
    # module-level import block.
    import tempfile
    from pathlib import Path

    repo = load_repo_artificial(include_hyperparameters=True)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Passed as plain strings, matching how paths were supplied before.
        save_path = str(Path(tmp_dir) / "tmp_repo")
        copy_path = str(Path(tmp_dir) / "tmp_repo_copy")

        repo.to_dir(path=save_path)
        repo_loaded = EvaluationRepository.from_dir(path=save_path)
        repo_loaded_mem = EvaluationRepository.from_dir(path=save_path, prediction_format="mem")
        repo_loaded_memopt = EvaluationRepository.from_dir(path=save_path, prediction_format="memopt")
        verify_equivalent_repository(repo1=repo, repo2=repo_loaded, verify_ensemble=True, exact=True)
        verify_equivalent_repository(repo1=repo, repo2=repo_loaded_mem, verify_ensemble=True, exact=True)
        verify_equivalent_repository(repo1=repo, repo2=repo_loaded_memopt, verify_ensemble=True, exact=True)

        # copytree requires that the destination does not exist yet, which is
        # guaranteed inside the fresh temporary directory.
        shutil.copytree(save_path, copy_path)

        repo_loaded_copy = EvaluationRepository.from_dir(path=copy_path)
        verify_equivalent_repository(repo1=repo, repo2=repo_loaded_copy, verify_ensemble=True, exact=True)

        # verify that the original stops working after deleting the original files
        repo_loaded.predict_test(dataset="abalone", fold=0, config=repo_loaded.configs()[0])
        shutil.rmtree(save_path)
        with pytest.raises(FileNotFoundError):
            repo_loaded.predict_test(dataset="abalone", fold=0, config=repo_loaded.configs()[0])

        # verify in-memory repos don't require the original files
        verify_equivalent_repository(repo1=repo, repo2=repo_loaded_mem, verify_ensemble=True, exact=True)
        verify_equivalent_repository(repo1=repo, repo2=repo_loaded_memopt, verify_ensemble=True, exact=True)

        # verify that the copy works even after deleting the original files
        verify_equivalent_repository(repo1=repo, repo2=repo_loaded_copy, verify_ensemble=True, exact=True)

        # verify that the copy stops working after deleting the copied files
        repo_loaded_copy.predict_test(dataset="abalone", fold=0, config=repo_loaded_copy.configs()[0])
        shutil.rmtree(copy_path)
        with pytest.raises(FileNotFoundError):
            repo_loaded_copy.predict_test(dataset="abalone", fold=0, config=repo_loaded_copy.configs()[0])


def _assert_predict_multi_binary_as_multiclass(repo, fun: Callable, dataset, configs, n_rows, n_classes):
problem_type = repo.dataset_info(dataset=dataset)["problem_type"]
predict_multi = fun(dataset=dataset, fold=2, configs=configs)
Expand Down

0 comments on commit 3b0a932

Please sign in to comment.