Skip to content

Commit

Permalink
move imports
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Dec 16, 2024
1 parent 044ab99 commit 31d3120
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
15 changes: 6 additions & 9 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
from _pytest.fixtures import SubRequest

from lightning_ir import BiEncoderModule, LightningIRDataModule, LightningIRModule, LightningIRTrainer, RunDataset
from lightning_ir.lightning_utils.callbacks import (
IndexCallback,
RegisterLocalDatasetCallback,
ReRankCallback,
SearchCallback,
)
from lightning_ir.callbacks import IndexCallback, RegisterLocalDatasetCallback, ReRankCallback, SearchCallback
from lightning_ir.retrieve import (
FaissFlatIndexConfig,
FaissIVFIndexConfig,
Expand Down Expand Up @@ -63,12 +58,14 @@ def test_index_callback(
assert index_callback.indexer.num_embeddings and index_callback.indexer.num_docs
assert index_callback.indexer.num_embeddings >= index_callback.indexer.num_docs

dataset_id = doc_datamodule.inference_datasets[0].dataset_id
index_dir = index_dir / dataset_id
assert (index_dir / "index.faiss").exists() or (index_dir / "index.pt").exists()
assert (index_dir / "doc_ids.txt").exists()
doc_ids_path = index_dir / "doc_ids.txt"
doc_ids = doc_ids_path.read_text().split()
for idx, doc_id in enumerate(doc_ids):
assert doc_id == f"doc_id_{idx+1}"
assert doc_id == f"doc_id_{idx + 1}"
assert (index_dir / "config.json").exists()


Expand All @@ -87,7 +84,7 @@ def get_index(
raise ValueError("Unknown search_config type")
index_dir = DATA_DIR / "indexes" / f"{index_type}-{bi_encoder_module.config.similarity_function}"
if index_dir.exists():
return index_dir / "lightning-ir"
return index_dir

index_callback = IndexCallback(index_config=index_config, index_dir=index_dir)

Expand All @@ -97,7 +94,7 @@ def get_index(
callbacks=[index_callback],
)
trainer.test(bi_encoder_module, datamodule=doc_datamodule)
return index_dir / "lightning-ir"
return index_dir


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from lightning import LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

from lightning_ir.lightning_utils.lr_schedulers import (
from lightning_ir.schedulers.lr_schedulers import (
ConstantLRSchedulerWithLinearWarmup,
LinearLRSchedulerWithLinearWarmup,
WarmupLRScheduler,
)
from lightning_ir.lightning_utils.schedulers import (
from lightning_ir.schedulers.schedulers import (
GenericConstantSchedulerWithLinearWarmup,
GenericConstantSchedulerWithQuadraticWarmup,
GenericLinearSchedulerWithLinearWarmup,
Expand Down

0 comments on commit 31d3120

Please sign in to comment.