Skip to content

Commit

Permalink
reorganize indexers and searchers
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Dec 10, 2024
1 parent 8cdd245 commit 7f21aef
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 24 deletions.
11 changes: 5 additions & 6 deletions lightning_ir/retrieve/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .faiss_indexer import (
from .base import IndexConfig, Indexer, SearchConfig, Searcher
from .faiss import (
FaissFlatIndexConfig,
FaissFlatIndexer,
FaissIVFIndexConfig,
Expand All @@ -7,12 +8,10 @@
FaissIVFPQIndexer,
FaissPQIndexConfig,
FaissPQIndexer,
FaissSearchConfig,
FaissSearcher,
)
from .faiss_searcher import FaissSearchConfig, FaissSearcher
from .indexer import IndexConfig, Indexer
from .searcher import SearchConfig, Searcher
from .sparse_indexer import SparseIndexConfig, SparseIndexer
from .sparse_searcher import SparseSearchConfig, SparseSearcher
from .sparse import SparseIndexConfig, SparseIndexer, SparseSearchConfig, SparseSearcher

__all__ = [
"FaissFlatIndexConfig",
Expand Down
9 changes: 9 additions & 0 deletions lightning_ir/retrieve/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .indexer import IndexConfig, Indexer
from .searcher import SearchConfig, Searcher

__all__ = [
"IndexConfig",
"Indexer",
"SearchConfig",
"Searcher",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List

import torch

if TYPE_CHECKING:
from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
from ..data import IndexBatch
from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
from ...data import IndexBatch


class Indexer(ABC):
Expand All @@ -24,7 +24,7 @@ def __init__(
self.index_dir = index_dir
self.index_config = index_config
self.bi_encoder_config = bi_encoder_config
self.doc_ids = []
self.doc_ids: List[str] = []
self.doc_lengths = array.array("I")
self.num_embeddings = 0
self.num_docs = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import torch

from ..bi_encoder.model import BiEncoderEmbedding
from ...bi_encoder.model import BiEncoderEmbedding

if TYPE_CHECKING:
from ..bi_encoder import BiEncoderModule, BiEncoderOutput
from ...bi_encoder import BiEncoderModule, BiEncoderOutput


class Searcher(ABC):
Expand Down
24 changes: 24 additions & 0 deletions lightning_ir/retrieve/faiss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from .faiss_indexer import (
FaissFlatIndexConfig,
FaissFlatIndexer,
FaissIVFIndexConfig,
FaissIVFIndexer,
FaissIVFPQIndexConfig,
FaissIVFPQIndexer,
FaissPQIndexConfig,
FaissPQIndexer,
)
from .faiss_searcher import FaissSearchConfig, FaissSearcher

__all__ = [
"FaissFlatIndexConfig",
"FaissFlatIndexer",
"FaissIVFIndexConfig",
"FaissIVFIndexer",
"FaissIVFPQIndexConfig",
"FaissIVFPQIndexer",
"FaissPQIndexConfig",
"FaissPQIndexer",
"FaissSearchConfig",
"FaissSearcher",
]
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import torch

from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
from ..data import IndexBatch
from .indexer import IndexConfig, Indexer
from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
from ...data import IndexBatch
from ..base import IndexConfig, Indexer


class FaissIndexer(Indexer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import torch

from ..bi_encoder.model import BiEncoderEmbedding
from .searcher import SearchConfig, Searcher
from ...bi_encoder.model import BiEncoderEmbedding
from ..base import SearchConfig, Searcher

if TYPE_CHECKING:
from ..bi_encoder import BiEncoderModule
from ...bi_encoder import BiEncoderModule


class FaissSearcher(Searcher):
Expand Down
2 changes: 2 additions & 0 deletions lightning_ir/retrieve/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .sparse_indexer import SparseIndexConfig, SparseIndexer
from .sparse_searcher import SparseSearchConfig, SparseSearcher
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import torch

from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
from ..data import IndexBatch
from .indexer import IndexConfig, Indexer
from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
from ...data import IndexBatch
from ..base import IndexConfig, Indexer


class SparseIndexer(Indexer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import torch

from .searcher import SearchConfig, Searcher
from ..base import SearchConfig, Searcher
from .sparse_indexer import SparseIndexConfig

if TYPE_CHECKING:
from ..bi_encoder import BiEncoderEmbedding, BiEncoderModule
from ...bi_encoder import BiEncoderEmbedding, BiEncoderModule


class SparseIndex:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from lightning_ir.retrieve import (
FaissFlatIndexConfig,
FaissSearchConfig,
IndexConfig,
SearchConfig,
SparseIndexConfig,
SparseSearchConfig,
)
from lightning_ir.retrieve.indexer import IndexConfig

from .conftest import CORPUS_DIR, DATA_DIR

Expand Down

0 comments on commit 7f21aef

Please sign in to comment.