diff --git a/parameters.yaml b/parameters.yaml index 04d9b22..013fc26 100644 --- a/parameters.yaml +++ b/parameters.yaml @@ -3,8 +3,7 @@ LEARNING_RATE_2: 0.05 LEARNING_RATE_3: 0.05 LEARNING_RATE_4: 0.05 LEARNING_RATE_5: 0.03 -NUMBER_OF_CLASSES: 5 -EPOCHS_1: 5 +EPOCHS_1: 1 EPOCHS_2: 5 EPOCHS_3: 5 EPOCHS_4: 5 diff --git a/src/swahiliNewsClassifier/entity/entities.py b/src/swahiliNewsClassifier/entity/entities.py index db86462..14579a6 100644 --- a/src/swahiliNewsClassifier/entity/entities.py +++ b/src/swahiliNewsClassifier/entity/entities.py @@ -56,9 +56,7 @@ class ModelTrainingAndEvaluationConfig: epochs_5 (int): Number of epochs for the fourth phase of classifier training. This defines the number of complete passes through the training dataset in the final phase. training_data (Path): Path to the training data CSV file. This file contains the text data and corresponding labels for training and validation. - - number_of_classes (int): Number of target classes in the classification task. This defines the number of unique labels in the dataset. - + root_dir (Path): Root directory for storing model artifacts. This directory is used to save trained models, logs, and other artifacts. mlflow_tracking_uri (str): URI for the MLflow tracking server. This is used to log and track experiments with MLflow. @@ -83,7 +81,6 @@ class ModelTrainingAndEvaluationConfig: epochs_4: int epochs_5: int training_data: Path - number_of_classes: int root_dir: Path mlflow_tracking_uri: str mlflow_repo_name: str