From 31d3120521a969d4c6db2323a64a1e7b353de4ca Mon Sep 17 00:00:00 2001 From: Ferdinand Schlatt Date: Mon, 16 Dec 2024 07:30:30 +0100 Subject: [PATCH] move imports --- tests/test_callbacks.py | 15 ++++++--------- tests/test_schedulers.py | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index c165d11..54ca132 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -7,12 +7,7 @@ from _pytest.fixtures import SubRequest from lightning_ir import BiEncoderModule, LightningIRDataModule, LightningIRModule, LightningIRTrainer, RunDataset -from lightning_ir.lightning_utils.callbacks import ( - IndexCallback, - RegisterLocalDatasetCallback, - ReRankCallback, - SearchCallback, -) +from lightning_ir.callbacks import IndexCallback, RegisterLocalDatasetCallback, ReRankCallback, SearchCallback from lightning_ir.retrieve import ( FaissFlatIndexConfig, FaissIVFIndexConfig, @@ -63,12 +58,14 @@ def test_index_callback( assert index_callback.indexer.num_embeddings and index_callback.indexer.num_docs assert index_callback.indexer.num_embeddings >= index_callback.indexer.num_docs + dataset_id = doc_datamodule.inference_datasets[0].dataset_id + index_dir = index_dir / dataset_id assert (index_dir / "index.faiss").exists() or (index_dir / "index.pt").exists() assert (index_dir / "doc_ids.txt").exists() doc_ids_path = index_dir / "doc_ids.txt" doc_ids = doc_ids_path.read_text().split() for idx, doc_id in enumerate(doc_ids): - assert doc_id == f"doc_id_{idx+1}" + assert doc_id == f"doc_id_{idx + 1}" assert (index_dir / "config.json").exists() @@ -87,7 +84,7 @@ def get_index( raise ValueError("Unknown search_config type") index_dir = DATA_DIR / "indexes" / f"{index_type}-{bi_encoder_module.config.similarity_function}" if index_dir.exists(): - return index_dir / "lightning-ir" + return index_dir index_callback = IndexCallback(index_config=index_config, index_dir=index_dir) @@ -97,7 +94,7 @@ def get_index( callbacks=[index_callback], ) trainer.test(bi_encoder_module, datamodule=doc_datamodule) - return index_dir / "lightning-ir" + return index_dir @pytest.mark.parametrize( diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index 65a74b2..ef3c1a5 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -5,12 +5,12 @@ from lightning import LightningDataModule, LightningModule, Trainer from torch.utils.data import DataLoader, Dataset -from lightning_ir.lightning_utils.lr_schedulers import ( +from lightning_ir.schedulers.lr_schedulers import ( ConstantLRSchedulerWithLinearWarmup, LinearLRSchedulerWithLinearWarmup, WarmupLRScheduler, ) -from lightning_ir.lightning_utils.schedulers import ( +from lightning_ir.schedulers.schedulers import ( GenericConstantSchedulerWithLinearWarmup, GenericConstantSchedulerWithQuadraticWarmup, GenericLinearSchedulerWithLinearWarmup,