From 6ecde1a936370d1ed1c19638668a9f4010dab0a4 Mon Sep 17 00:00:00 2001
From: Oliver Rieger
Date: Thu, 5 Dec 2024 14:49:42 +0100
Subject: [PATCH] Add a first version of the training script for a draft
 classifier

---
 pyproject.toml                                |   1 +
 .../document_types/draft_classification.py    | 169 ++++++++++++++++++
 .../document_types/draft_classification.toml  |  27 +++
 3 files changed, 197 insertions(+)
 create mode 100644 research/document_types/draft_classification.py
 create mode 100644 research/document_types/draft_classification.toml

diff --git a/pyproject.toml b/pyproject.toml
index 466e08e..d491545 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,6 +8,7 @@ dependencies = [
     "boto3>=1.35.69",
     "httpx>=0.27.2",
     "iterative-stratification>=0.1.9",
+    "jupyterlab>=4.3.1",
     "lingua-language-detector>=2.0.2",
     "matplotlib>=3.9.2",
     "mlflow>=2.18.0",
diff --git a/research/document_types/draft_classification.py b/research/document_types/draft_classification.py
new file mode 100644
index 0000000..2e38535
--- /dev/null
+++ b/research/document_types/draft_classification.py
@@ -0,0 +1,169 @@
+import logging
+import os
+import pathlib
+import pickle as pkl
+import sys
+
+import dotenv
+import mlflow
+import numpy as np
+import pandas as pd
+import toml
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import (
+    classification_report,
+    confusion_matrix,
+    precision_recall_fscore_support,
+)
+from sklearn.model_selection import StratifiedKFold, cross_validate, train_test_split
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+# Make the repository root importable before pulling in project modules;
+# resolving via __file__ keeps the script independent of the caller's CWD.
+REPOSITORY_ROOT = pathlib.Path(__file__).resolve().parents[2]
+sys.path.append(str(REPOSITORY_ROOT))
+
+from research.lib import data_access, embeddings
+
+dotenv.load_dotenv()
+
+CONFIG_FILE = pathlib.Path(__file__).parent / "draft_classification.toml"
+with open(CONFIG_FILE) as f:
+    config = toml.load(f)
+
+os.environ["MLFLOW_TRACKING_URI"] = config["tracking"]["tracking_uri"]
+DATA_FILE_NAME = config["data"]["data_file_name"]
+PREPROCESSED_DATA_FILE = REPOSITORY_ROOT / "data" / "dataframes" / DATA_FILE_NAME
+DOCUMENT_SOURCES = config["data"]["document_sources"]
+LANGUAGES = config["data"]["languages"]
+FROM_YEAR = config["data"]["from_year"]
+DOC_TYPES = config["data"]["doc_types"]
+EMBEDDING_MODEL = config["training"]["embedding_model"]
+CV_FOLDS = config["training"]["cv_folds"]
+TEST_SIZE = config["training"]["test_size"]
+RANDOM_STATE = config["training"]["random_state"]
+
+
+def remove_rows_with_missing_text(df: pd.DataFrame) -> pd.DataFrame:
+    """Remove rows whose plain-text content is empty."""
+    empty_index = df["document_content_plain"] == ""
+    empty_count = int(empty_index.sum())
+    print(f"Number of dropped empty texts: {empty_count} ({100 * empty_count / len(df):.1f}%)")
+    return df.loc[~empty_index]
+
+
+def create_embeddings(
+    df: pd.DataFrame,
+    model,
+    column_to_embed: str,
+    cache_directory: pathlib.Path = REPOSITORY_ROOT / "data" / "embeddings-cache",
+):
+    """Tokenize a text column and embed it, using the on-disk embedding cache."""
+    tokens = df[column_to_embed].progress_map(model.tokenize)
+    with embeddings.use_cache(
+        model,
+        tqdm=tqdm,
+        cache_directory=cache_directory,
+    ) as get_embeddings:
+        embedded = get_embeddings(tokens.tolist())
+    print(embedded.shape)
+    return embedded
+
+
+def save_model(model, model_file_name: pathlib.Path) -> None:
+    """Pickle the fitted model and log it as an MLflow artifact."""
+    model_file_name.parent.mkdir(parents=True, exist_ok=True)
+    with open(model_file_name, "wb") as f:
+        pkl.dump(model, f)
+    mlflow.log_artifact(str(model_file_name))
+
+
+if __name__ == "__main__":
+
+    ### Logging parameters ###
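+    # MLflow allows only one active run per process; an interrupted earlier
+    # session can leave a stale run behind, so close it before starting fresh.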
+    mlflow.set_experiment(experiment_name="Draft Classifier")
+    if run := mlflow.active_run():
+        logging.warning("Run = %s is already active, closing it.", run.info.run_name)
+        mlflow.end_run()
+    run = mlflow.start_run()
+    print("Starting run:", run.info.run_name)
+    mlflow.log_param("input_file", DATA_FILE_NAME)
+    mlflow.log_param("document_sources", sorted(DOCUMENT_SOURCES))
+    mlflow.log_param("languages", LANGUAGES)
+    mlflow.log_param("from_year", FROM_YEAR)
+    mlflow.log_param("doc_types", sorted(map(str, DOC_TYPES)))
+    mlflow.log_param("embedding_model", EMBEDDING_MODEL)
+    mlflow.log_param("cv_folds", CV_FOLDS)
+    mlflow.log_param("test_size", TEST_SIZE)
+    mlflow.log_param("random_state", RANDOM_STATE)
+
+    mlflow.sklearn.autolog()
+    tqdm.pandas()
+
+    #### Load data ####
+    data_access.ensure_dataframe_is_available(PREPROCESSED_DATA_FILE)
+    df_input = data_access.load_consultation_documents(
+        PREPROCESSED_DATA_FILE,
+        only_document_sources=DOCUMENT_SOURCES,
+        only_languages=LANGUAGES,
+        only_doc_types=DOC_TYPES,
+        starting_year=FROM_YEAR,
+        mlflow=mlflow,
+    )
+
+    ### Preprocessing ###
+    df_input = remove_rows_with_missing_text(df_input)
+    # Set the binary target variable.
+    df_input.loc[:, "is_draft"] = (df_input.loc[:, "document_type"] == "DRAFT").astype(bool)
+    # Create embeddings.
+    embedding_model = embeddings.create_embedding_model(EMBEDDING_MODEL)
+    mlflow.log_param("embedding_model.max_input_tokens", embedding_model.max_input_tokens)
+    embeddings_doc_content_plain = create_embeddings(
+        df_input, embedding_model, column_to_embed="document_content_plain"
+    )
+
+    #### Train-Test Split ####
+    X = embeddings_doc_content_plain
+    y = df_input["is_draft"]
+    # Stratify on the target so train and test keep the same draft ratio.
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=y
+    )
+    mlflow.log_param("train_samples_count", len(X_train))
+    mlflow.log_param("test_samples_count", len(X_test))
+    print("Train:", X_train.shape)
+    print("Test:", X_test.shape)
+
+    #### Model ####
+    classifier = make_pipeline(
+        StandardScaler(),
+        # PCA(n_components=200, random_state=RANDOM_STATE),
+        LogisticRegression(max_iter=1000),
+        # SGDClassifier(loss="modified_huber", max_iter=1000),
+        # GradientBoostingClassifier(random_state=RANDOM_STATE),
+        # SVC(kernel="linear"),
+    )
+
+    #### CV ####
+    cv_splitter = StratifiedKFold(
+        n_splits=CV_FOLDS,
+        shuffle=True,
+        random_state=RANDOM_STATE,
+    )
+    scores_docs = cross_validate(
+        classifier,
+        X=X_train,
+        y=y_train,
+        cv=cv_splitter,
+        scoring={
+            "precision": "precision_weighted",
+            "recall": "recall_weighted",
+            "f1": "f1_weighted",
+        },
+    )
+    avg_scores_docs = {k: np.round(np.mean(v), 4) for k, v in scores_docs.items()}
+    mlflow.log_metrics({f"{k}_docs": v for k, v in avg_scores_docs.items() if k.startswith("test_")})
+
+    print("Per-document CV scores:")
+    print(f"Precision: {avg_scores_docs['test_precision']:.4f} (+/- {np.std(scores_docs['test_precision']):.4f})")
+    print(f"Recall: {avg_scores_docs['test_recall']:.4f} (+/- {np.std(scores_docs['test_recall']):.4f})")
+    print(f"F1: {avg_scores_docs['test_f1']:.4f} (+/- {np.std(scores_docs['test_f1']):.4f})")
+    print(scores_docs)
+
+    ## Fit on the whole training set and evaluate on the held-out test set ##
+    classifier.fit(X_train, y_train)
+    test_ground_truth_docs = y_test
+    test_predictions_docs = classifier.predict(X_test)
+    print("docs: ground truth", test_ground_truth_docs.shape)
+    print("docs: predictions", test_predictions_docs.shape)
+
+    print(classification_report(test_ground_truth_docs, test_predictions_docs))
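+    # normalize="true" scales each row by its true-class support, so the
+    # diagonal of the confusion matrix reads as per-class recall.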
+    cm = confusion_matrix(
+        test_ground_truth_docs, test_predictions_docs, normalize="true"
+    )
+    print(cm)
+    print(precision_recall_fscore_support(test_ground_truth_docs, test_predictions_docs))
+
+    ## Save model ##
+    model_file = REPOSITORY_ROOT / "models" / "draft_classification.pkl"
+    save_model(classifier, model_file_name=model_file)
+
+    # End the run we started explicitly above.
+    mlflow.end_run()
diff --git a/research/document_types/draft_classification.toml b/research/document_types/draft_classification.toml
new file mode 100644
index 0000000..701cbb0
--- /dev/null
+++ b/research/document_types/draft_classification.toml
@@ -0,0 +1,27 @@
+[data]
+data_file_name = "consultation-documents-preprocessed-2024-11-26.parquet"
+document_sources = ["fedlex"]  #, "openparldata"]
+languages = ["de"]  # "fr", "it", ]
+from_year = 2000  # 2010
+doc_types = [
+    "LETTER",
+    "DRAFT",
+    "RECIPIENT_LIST",
+    "REPORT",
+    "FINAL_REPORT",
+    "OPINION",
+    "VARIOUS_TEXT",
+    "SYNOPTIC_TABLE",
+    "SURVEY",
+    "RESPONSE_FORM",
+    # None,  # Many openparldata documents don't have a known type
+]
+
+[training]
+embedding_model = "openai/text-embedding-3-large"  # "openai/text-embedding-ada-002"
+cv_folds = 5
+test_size = 0.1
+random_state = 42
+
+[tracking]
+tracking_uri = "sqlite:///mlruns.db"
\ No newline at end of file
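
A minimal sketch of how the pickled pipeline could be reloaded for inference,
assuming the `research.lib.embeddings` helpers behave as they do in the script
above; `texts` is a hypothetical list of new document bodies, and the working
directory is assumed to be the repository root:

    import pathlib
    import pickle as pkl

    from tqdm import tqdm

    from research.lib import embeddings

    repository_root = pathlib.Path(".").resolve()

    # Load the fitted StandardScaler + LogisticRegression pipeline.
    with open(repository_root / "models" / "draft_classification.pkl", "rb") as f:
        classifier = pkl.load(f)

    # Embed new texts with the same model and cache used during training.
    model = embeddings.create_embedding_model("openai/text-embedding-3-large")
    tokens = [model.tokenize(text) for text in texts]
    with embeddings.use_cache(
        model,
        tqdm=tqdm,
        cache_directory=repository_root / "data" / "embeddings-cache",
    ) as get_embeddings:
        X_new = get_embeddings(tokens)

    # One boolean per document: True means the classifier predicts "DRAFT".
    is_draft = classifier.predict(X_new)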