Skip to content

Commit

Permalink
Fix training for all Camembert flavors
Browse files Browse the repository at this point in the history
  • Loading branch information
tomseimandi committed Mar 22, 2024
1 parent 72e0b56 commit 9797686
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
10 changes: 10 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
OneHotCategoricalCamembertWrapper,
EmbeddedCategoricalCamembertWrapper,
)
from camembert.camembert_model import (
CustomCamembertModel,
OneHotCategoricalCamembertModel,
EmbeddedCategoricalCamembertModel,
)


FRAMEWORK_CLASSES = {
Expand All @@ -28,29 +33,34 @@
"trainer": FastTextTrainer,
"evaluator": FastTextEvaluator,
"wrapper": FastTextWrapper,
"model": None,
},
"pytorch": {
"preprocessor": PytorchPreprocessor,
"trainer": PytorchTrainer,
"evaluator": PytorchEvaluator,
"wrapper": None,
"model": None,
},
"camembert": {
"preprocessor": CamembertPreprocessor,
"trainer": CustomCamembertTrainer,
"evaluator": CamembertEvaluator,
"wrapper": CustomCamembertWrapper,
"model": CustomCamembertModel,
},
"camembert_one_hot": {
"preprocessor": CamembertPreprocessor,
"trainer": OneHotCamembertTrainer,
"evaluator": CamembertEvaluator,
"wrapper": OneHotCategoricalCamembertWrapper,
"model": OneHotCategoricalCamembertModel,
},
"camembert_embedded": {
"preprocessor": CamembertPreprocessor,
"trainer": EmbeddedCamembertTrainer,
"evaluator": CamembertEvaluator,
"wrapper": EmbeddedCategoricalCamembertWrapper,
"model": EmbeddedCategoricalCamembertModel,
},
}
3 changes: 1 addition & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from constants import FRAMEWORK_CLASSES
from utils.mappings import mappings
from camembert.camembert_model import CustomCamembertModel
from camembert.custom_pipeline import CustomPipeline
from tests.test_main import run_test
from utils.data import get_sirene_4_data, get_test_data, get_sirene_3_data
Expand Down Expand Up @@ -371,7 +370,7 @@ def main(

pipe = CustomPipeline(
framework="pt",
model=CustomCamembertModel.from_pretrained(
model=framework_classes["model"].from_pretrained(
model_output_dir,
num_labels=len(mappings.get("APE_NIV5")),
categorical_features=categorical_features,
Expand Down

0 comments on commit 9797686

Please sign in to comment.