Draft document classifier #8
base: main
fef60cd to 3127e5a
This is now rebased on main where #6 has been merged.
6ecde1a to a34fe53
All dependencies such as the research libraries and code are now merged in
    print(f"Number of dropped empty texts: {empty_count} ({100 * empty_count / len(df_input):.1f}%)")
    return df.loc[~empty_index]


def create_embeddings(column_to_embed: str, cache_directory: pathlib.Path = REPOSITORY_ROOT / "data" / "embeddings-cache"):
This function relies on global variables (`df_input` and `embedding_model`). Instead of using global state, let's pass the necessary data in as arguments. Something like this:
- def create_embeddings(column_to_embed: str, cache_directory: pathlib.Path = REPOSITORY_ROOT / "data" / "embeddings-cache"):
+ def create_embeddings(strings_to_embed: pd.Series, embedding_model: embeddings.EmbeddingModel, cache_directory: pathlib.Path = REPOSITORY_ROOT / "data" / "embeddings-cache") -> np.ndarray:
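To make the suggestion concrete, here is a minimal sketch of the argument-passing version. The `EmbeddingModel` protocol, its `embed` method, and the toy `LengthEmbeddingModel` are hypothetical stand-ins, since the real `embeddings.EmbeddingModel` interface isn't visible in this diff; caching via `cache_directory` is left out for brevity.

```python
from typing import Protocol

import numpy as np
import pandas as pd


class EmbeddingModel(Protocol):
    """Hypothetical stand-in for embeddings.EmbeddingModel (real interface not shown here)."""

    def embed(self, texts: list[str]) -> np.ndarray: ...


class LengthEmbeddingModel:
    """Toy model for illustration: embeds each text as [character count, word count]."""

    def embed(self, texts: list[str]) -> np.ndarray:
        return np.array([[len(t), len(t.split())] for t in texts], dtype=float)


def create_embeddings(strings_to_embed: pd.Series, embedding_model: EmbeddingModel) -> np.ndarray:
    """Embed the given strings using only what is passed in, with no global state."""
    return embedding_model.embed(strings_to_embed.tolist())


texts = pd.Series(["first draft", "final"])
result = create_embeddings(texts, LengthEmbeddingModel())
print(result.shape)
```

With the data and the model passed in explicitly, the function can be called from a test or another script without first setting up `df_input` and `embedding_model` at module level.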
    print(embeddings_doc_content_plain.shape)
    return embeddings_doc_content_plain


def save_model(model_file_name):
This is similar - let's pass `classifier` in as an argument.
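A rough sketch of the shape this could take. How the script actually serializes the model isn't shown in this diff, so `pickle` is used here purely as an assumption, and the dictionary is a toy placeholder for the real classifier.

```python
import pathlib
import pickle
import tempfile
from typing import Any


def save_model(classifier: Any, model_file_name: pathlib.Path) -> None:
    """Serialize the classifier that was passed in, instead of reading a global."""
    with open(model_file_name, "wb") as f:
        pickle.dump(classifier, f)


with tempfile.TemporaryDirectory() as tmp:
    model_path = pathlib.Path(tmp) / "classifier.pkl"
    toy_classifier = {"threshold": 0.5}  # placeholder object, not a real model
    save_model(toy_classifier, model_path)
    with open(model_path, "rb") as f:
        restored = pickle.load(f)
    print(restored)
```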
from research.lib import data_access, embeddings


REPOSITORY_ROOT = (pathlib.Path().cwd() / ".." / "..").resolve()
sys.path.append(str(REPOSITORY_ROOT))
Is this necessary? Since the imports above (`from research.lib ...`) apparently work?
Note: I had to run the script like this:
    cd research/document_types
    PYTHONPATH=../../ uv run draft_classification.py
so that a) the `research.lib` imports on line 22 work and b) the relative path to `draft_classification.toml` works.
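One way to drop the dependency on the current working directory would be to derive both paths from the script's own location. A minimal sketch, assuming the script lives two directories below the repository root (as in `research/document_types/`); the helper names `repository_root` and `config_path_for` are hypothetical.

```python
import pathlib


def repository_root(script_path: pathlib.PurePath, levels_up: int = 2) -> pathlib.PurePath:
    """Return the directory `levels_up` levels above the script's own directory."""
    return script_path.parents[levels_up]


def config_path_for(script_path: pathlib.PurePath) -> pathlib.PurePath:
    """Locate the TOML file next to the script, regardless of the working directory."""
    return script_path.parent / "draft_classification.toml"


# In the real script this would be pathlib.Path(__file__).resolve().
# A fixed example path is used here so the sketch runs anywhere.
script = pathlib.PurePosixPath("/repo/research/document_types/draft_classification.py")
print(repository_root(script))
print(config_path_for(script))
```

Note that this only addresses the path to the TOML file. In the snippet under review, `sys.path.append` runs after the `from research.lib` imports, so it cannot be what makes them resolve; the repository root still has to be on the import path before those lines execute, for example via `PYTHONPATH` as above.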
REPOSITORY_ROOT = (pathlib.Path().cwd() / ".." / "..").resolve()
sys.path.append(str(REPOSITORY_ROOT))


dotenv.load_dotenv()
Side effects such as this should be in the `if __name__ == "__main__":` block, otherwise they're triggered even when this module is imported by something else.
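A minimal sketch of the pattern, with a harmless stand-in for the real side effects. The variable name `DRAFT_CLASSIFIER_EXAMPLE` and the function `apply_side_effects` are hypothetical and only there to make the effect observable.

```python
import os

# Hypothetical variable name, used only to make the side effect observable.
EXAMPLE_VARIABLE = "DRAFT_CLASSIFIER_EXAMPLE"


def apply_side_effects() -> None:
    """Stand-in for dotenv.load_dotenv(), config loading, tracking setup and similar."""
    os.environ[EXAMPLE_VARIABLE] = "configured"


def main() -> None:
    apply_side_effects()
    print(f"{EXAMPLE_VARIABLE}={os.environ[EXAMPLE_VARIABLE]}")


if __name__ == "__main__":
    # Runs only when the file is executed as a script, not when it is imported.
    main()
```

Importing this module only defines the functions; the environment is touched only when `main()` is called.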
with open('draft_classification.toml', 'r') as f:
    config = toml.load(f)


os.environ["MLFLOW_TRACKING_URI"] = config["tracking"]["tracking_uri"]
This is also a side effect that should not be triggered at import time. In addition, we can use https://mlflow.org/docs/1.22.0/python_api/mlflow.html#mlflow.set_tracking_uri instead of manipulating environment variables. (I know it was like that in the original notebook but we can improve things in this new script 🙂)
with open('draft_classification.toml', 'r') as f:
    config = toml.load(f)
- with open('draft_classification.toml', 'r') as f:
-     config = toml.load(f)
+ config = toml.load("draft_classification.toml")
However, it's also a side effect that would preferably happen in `if __name__ == "__main__"` ...
### Preprocessing ###
df_input = remove_rows_with_missing_text(df_input)
# set target variable
df_input.loc[:, "is_draft"] = (df_input.loc[:, "document_type"] == "DRAFT").astype(bool)
Oddly, I'm getting this warning on this line. I don't know why.
SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
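One plausible cause, judging from the code above: `remove_rows_with_missing_text` returns `df.loc[~empty_index]`, which pandas may treat as a slice of the original frame, so the later column assignment is flagged even though it already uses `.loc`. A minimal sketch of the usual fix, returning an explicit copy. The column name `doc_content_plain`, the emptiness check, and the sample rows are assumptions for illustration.

```python
import pandas as pd


def remove_rows_with_missing_text(df: pd.DataFrame, text_column: str = "doc_content_plain") -> pd.DataFrame:
    """Drop rows whose text is missing or blank and return an independent copy."""
    empty_index = df[text_column].isna() | (df[text_column].str.strip() == "")
    # .copy() detaches the result from `df`, so later assignments are unambiguous.
    return df.loc[~empty_index].copy()


raw = pd.DataFrame(
    {
        "doc_content_plain": ["Entwurf des Gesetzes", "", None, "Bericht"],
        "document_type": ["DRAFT", "LETTER", "DRAFT", "REPORT"],
    }
)
df_input = remove_rows_with_missing_text(raw)
# The comparison already yields booleans, so .astype(bool) is not needed.
df_input["is_draft"] = df_input["document_type"] == "DRAFT"
print(df_input["is_draft"].tolist())
```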
… topic classification
I've placed two somewhat unrelated things into this branch because they both depend on #6 and on some research-y code from our old, exploratory repository.
- research/consultation_topics/label_exploration.ipynb takes a look at the consultation topic labels we have.
- research/document_types/VM_draft_classifier.ipynb is a bare-bones document type classifier, on Fedlex data only, making "is DRAFT/is not a DRAFT" predictions only.

We can split these up again if this branch stops making sense.
Handing this over to you now, @orieger 🙂