diff --git a/lightning_ir/__init__.py b/lightning_ir/__init__.py index 5ae6b5b..c454fe8 100644 --- a/lightning_ir/__init__.py +++ b/lightning_ir/__init__.py @@ -61,6 +61,7 @@ RankNet, SupervisedMarginMSE, ) +from .main import LightningIRTrainer from .models import ColConfig, ColModel, SpladeConfig, SpladeModel, XTRConfig, XTRModel from .retrieve import ( FaissFlatIndexConfig, @@ -146,6 +147,7 @@ "LightningIRModule", "LightningIROutput", "LightningIRTokenizer", + "LightningIRTrainer", "LinearLRSchedulerWithLinearWarmup", "LocalizedContrastiveEstimation", "LR_SCHEDULERS",