From 3b0a9325a33d452c917700a6bb0144f376a82733 Mon Sep 17 00:00:00 2001 From: innixma Date: Fri, 1 Nov 2024 20:05:08 +0000 Subject: [PATCH] Support loading repo artifact from cloned directory --- tabrepo/repository/evaluation_repository.py | 3 ++ tst/test_repository.py | 45 ++++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/tabrepo/repository/evaluation_repository.py b/tabrepo/repository/evaluation_repository.py index 0f379788..7d43dafc 100644 --- a/tabrepo/repository/evaluation_repository.py +++ b/tabrepo/repository/evaluation_repository.py @@ -291,11 +291,14 @@ def from_dir( cls, path: str, prediction_format: Literal["memmap", "memopt", "mem"] = "memmap", + update_relative_path: bool = True, ) -> Self: from tabrepo.contexts.context import BenchmarkContext path_context = str(Path(path) / "context.json") context = BenchmarkContext.from_json(path=path_context) + if update_relative_path: + context.benchmark_paths.relative_path = str(Path(path)) repo = context.load_repo(prediction_format=prediction_format) return repo diff --git a/tst/test_repository.py b/tst/test_repository.py index dd4f0782..af07fbae 100644 --- a/tst/test_repository.py +++ b/tst/test_repository.py @@ -1,5 +1,5 @@ import copy - +import shutil from typing import Callable import numpy as np @@ -305,6 +305,49 @@ def test_repository_save_load(): verify_equivalent_repository(repo1=repo_float64, repo2=repo_loaded_float64, verify_ensemble=True, exact=False) +def test_repository_save_load_with_moving_files(): + """test repo save and load work when moving files to different directories""" + + save_path = "tmp_repo" + copy_path = "tmp_repo_copy" + shutil.rmtree(save_path, ignore_errors=True) + shutil.rmtree(copy_path, ignore_errors=True) + + repo = load_repo_artificial(include_hyperparameters=True) + + repo.to_dir(path=save_path) + repo_loaded = EvaluationRepository.from_dir(path=save_path) + repo_loaded_mem = EvaluationRepository.from_dir(path=save_path, prediction_format="mem") + repo_loaded_memopt = EvaluationRepository.from_dir(path=save_path, prediction_format="memopt") + verify_equivalent_repository(repo1=repo, repo2=repo_loaded, verify_ensemble=True, exact=True) + verify_equivalent_repository(repo1=repo, repo2=repo_loaded_mem, verify_ensemble=True, exact=True) + verify_equivalent_repository(repo1=repo, repo2=repo_loaded_memopt, verify_ensemble=True, exact=True) + + shutil.copytree(save_path, copy_path) + + repo_loaded_copy = EvaluationRepository.from_dir(path=copy_path) + verify_equivalent_repository(repo1=repo, repo2=repo_loaded_copy, verify_ensemble=True, exact=True) + + # verify that the original stops working after deleting the original files + repo_loaded.predict_test(dataset="abalone", fold=0, config=repo_loaded.configs()[0]) + shutil.rmtree(save_path) + with pytest.raises(FileNotFoundError): + repo_loaded.predict_test(dataset="abalone", fold=0, config=repo_loaded.configs()[0]) + + # verify in-memory repos don't require the original files + verify_equivalent_repository(repo1=repo, repo2=repo_loaded_mem, verify_ensemble=True, exact=True) + verify_equivalent_repository(repo1=repo, repo2=repo_loaded_memopt, verify_ensemble=True, exact=True) + + # verify that the copy works even after deleting the original files + verify_equivalent_repository(repo1=repo, repo2=repo_loaded_copy, verify_ensemble=True, exact=True) + + # verify that the copy stops working after deleting the copied files + repo_loaded_copy.predict_test(dataset="abalone", fold=0, config=repo_loaded_copy.configs()[0]) + shutil.rmtree(copy_path) + with pytest.raises(FileNotFoundError): + repo_loaded_copy.predict_test(dataset="abalone", fold=0, config=repo_loaded_copy.configs()[0]) + + def _assert_predict_multi_binary_as_multiclass(repo, fun: Callable, dataset, configs, n_rows, n_classes): problem_type = repo.dataset_info(dataset=dataset)["problem_type"] predict_multi = fun(dataset=dataset, fold=2, configs=configs)