Skip to content

Commit

Permalink
split lightning_utils into schedulers and callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Dec 16, 2024
1 parent 17d4eb7 commit 044ab99
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 18 deletions.
22 changes: 10 additions & 12 deletions lightning_ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
BiEncoderTokenizer,
ScoringFunction,
)
from .callbacks import IndexCallback, RankCallback, RegisterLocalDatasetCallback, ReRankCallback, SearchCallback
from .cross_encoder import (
CrossEncoderConfig,
CrossEncoderModel,
Expand All @@ -42,18 +43,6 @@
TrainBatch,
TupleDataset,
)
from .lightning_utils import (
ConstantLRSchedulerWithLinearWarmup,
GenericConstantSchedulerWithLinearWarmup,
GenericConstantSchedulerWithQuadraticWarmup,
GenericLinearSchedulerWithLinearWarmup,
IndexCallback,
LinearLRSchedulerWithLinearWarmup,
RankCallback,
ReRankCallback,
SearchCallback,
WarmupLRScheduler,
)
from .loss import (
ApproxMRR,
ApproxNDCG,
Expand Down Expand Up @@ -102,6 +91,14 @@
SparseSearchConfig,
SparseSearcher,
)
from .schedulers import (
ConstantLRSchedulerWithLinearWarmup,
GenericConstantSchedulerWithLinearWarmup,
GenericConstantSchedulerWithQuadraticWarmup,
GenericLinearSchedulerWithLinearWarmup,
LinearLRSchedulerWithLinearWarmup,
WarmupLRScheduler,
)

AutoConfig.register(BiEncoderConfig.model_type, BiEncoderConfig)
AutoModel.register(BiEncoderConfig, BiEncoderModel)
Expand Down Expand Up @@ -188,6 +185,7 @@
"RankBatch",
"RankNet",
"RankSample",
"RegisterLocalDatasetCallback",
"ReRankCallback",
"RunDataset",
"ScoreBasedInBatchCrossEntropy",
Expand Down
9 changes: 9 additions & 0 deletions lightning_ir/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .callbacks import IndexCallback, RankCallback, RegisterLocalDatasetCallback, ReRankCallback, SearchCallback

__all__ = [
"IndexCallback",
"RankCallback",
"RegisterLocalDatasetCallback",
"ReRankCallback",
"SearchCallback",
]
File renamed without changes.
2 changes: 1 addition & 1 deletion lightning_ir/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import override

import lightning_ir # noqa: F401
from lightning_ir.lightning_utils.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler
from lightning_ir.schedulers.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler

if torch.cuda.is_available():
torch.set_float32_matmul_precision("medium")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from .callbacks import IndexCallback, RankCallback, ReRankCallback, SearchCallback
"""
Module containing utility classes and functions for PyTorch Lightning.
This module provides callbacks .
"""

from .lr_schedulers import ConstantLRSchedulerWithLinearWarmup, LinearLRSchedulerWithLinearWarmup, WarmupLRScheduler
from .schedulers import (
GenericConstantSchedulerWithLinearWarmup,
Expand All @@ -11,10 +16,6 @@
"GenericConstantSchedulerWithLinearWarmup",
"GenericConstantSchedulerWithQuadraticWarmup",
"GenericLinearSchedulerWithLinearWarmup",
"IndexCallback",
"LinearLRSchedulerWithLinearWarmup",
"RankCallback",
"ReRankCallback",
"SearchCallback",
"WarmupLRScheduler",
]
File renamed without changes.
File renamed without changes.

0 comments on commit 044ab99

Please sign in to comment.