diff --git a/lightning_ir/__init__.py b/lightning_ir/__init__.py index fbb5176..14fe378 100644 --- a/lightning_ir/__init__.py +++ b/lightning_ir/__init__.py @@ -68,7 +68,7 @@ RankNet, SupervisedMarginMSE, ) -from .main import LightningIRTrainer +from .main import LightningIRTrainer, LightningIRWandbLogger from .models import ( ColConfig, ColModel, @@ -177,6 +177,7 @@ "LightningIRTokenizer", "LightningIRTokenizerClassFactory", "LightningIRTrainer", + "LightningIRWandbLogger", "LinearLRSchedulerWithLinearWarmup", "LocalizedContrastiveEstimation", "LR_SCHEDULERS",