Commit fc3e75c

update notebook

holgerroth committed Feb 23, 2024
1 parent 2d4c15e commit fc3e75c
Showing 1 changed file with 10 additions and 75 deletions.

examples/advanced/bionemo/task_fitting/task_fitting.ipynb
@@ -7,7 +7,10 @@
 "source": [
 "# Federated Protein Embeddings and Task Model Fitting with BioNeMo\n",
 "\n",
-"This example notebook shows how to obtain protein learned representations in the form of embeddings using the ESM-1nv pre-trained model. The model is trained with NVIDIA's BioNeMo framework for Large Language Model training and inference. For more details, please visit NVIDIA BioNeMo Service at https://www.nvidia.com/en-us/gpu-cloud/bionemo.\n",
+"This example notebook shows how to obtain protein learned representations in the form of embeddings using the ESM-1nv pre-trained model in a federated learning (FL) setting. The model is trained with NVIDIA's BioNeMo framework for Large Language Model training and inference. For more details, please visit NVIDIA BioNeMo Service at https://www.nvidia.com/en-us/gpu-cloud/bionemo.\n",
+"\n",
+"This example is based on the NVIDIA BioNeMo Service [example](https://github.com/NVIDIA/BioNeMo/blob/main/examples/service/notebooks/task-fitting-predictor.ipynb)\n",
+"but runs inference locally (on the FL clients) instead of using BioNeMo's cloud API.\n",
 "\n",
 "This notebook will walk you through the task fitting workflow in the following sections:\n",
 "\n",
@@ -22,19 +25,7 @@
 "metadata": {},
 "source": [
 "### Install requirements\n",
-"Let's start by installing and importing library dependencies. We'll use requests to interact with the BioNeMo service, BioPython to parse FASTA sequences into SeqRecord objects, scikit-learn for classification tasks, and matplotlib for visualization."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "76e2fb1b",
-"metadata": {},
-"outputs": [],
-"source": [
-"#!pip install -r requirements.txt\n",
-"#!pip install -e /home/hroth/Code/nvflare/bionemo_nvflare\n",
-"#!pip install biopython scikit-learn matplotlib"
+"Please follow the instructions [here](./README.md) before running the notebook."
 ]
 },
 {
Expand All @@ -43,7 +34,7 @@
"metadata": {},
"source": [
"### Obtaining the protein embeddings using the BioNeMo ESM-1nv model\n",
"Using BioNeMo, users can obtain numerical vector representations of protein sequences called embeddings. Protein embeddings can then be used for visualization or making downstream predictions.\n",
"Using BioNeMo, each FL client can obtain numerical vector representations of protein sequences called embeddings. Protein embeddings can then be used for visualization or making downstream predictions.\n",
"\n",
"Here we are interested in training a neural network to predict subcellular location from an embedding.\n",
"\n",
@@ -205,7 +196,7 @@
 "metadata": {},
 "source": [
 "### Inspecting the embeddings and labels\n",
-"Embeddings returned from the BioNeMo server are vectors of fixed size for each input sequence. In other words, if we input 10 sequences, we will obtain a matrix `10xD`, where `D` is the size of the embedding (in the case of ESM-1nv, `D=768`). At a glance, these real-valued vector embeddings don't show any obvious features (see the printout in the next cell). But these vectors do contain information that can be used in downstream models to reveal properties of the protein, for example the subcellular location as we'll explore below."
+"Embeddings returned from the BioNeMo model are vectors of fixed size for each input sequence. In other words, if we input 10 sequences, we will obtain a `10xD` matrix, where `D` is the size of the embedding (for ESM-1nv, `D=768`). At a glance, these real-valued vector embeddings don't show any obvious features (see the printout in the next cell). But these vectors do contain information that can be used in downstream models to reveal properties of the protein, for example the subcellular location, as we'll explore below."
 ]
 },
 {
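The shape claim in the cell above is easy to sanity-check. The sketch below (not part of the notebook) stands in random vectors for real ESM-1nv embeddings, since any array of shape `(n_sequences, 768)` behaves identically for the downstream classifier:

```python
import numpy as np

D = 768      # ESM-1nv embedding size
n_seqs = 10  # number of input sequences

# Hypothetical stand-in for BioNeMo inference: random vectors with the
# same shape as real embeddings (one D-dimensional row per sequence).
rng = np.random.default_rng(seed=0)
embeddings = rng.normal(size=(n_seqs, D))

print(embeddings.shape)   # (10, 768)
print(embeddings[0, :5])  # real embeddings also look featureless at a glance
```
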
@@ -256,7 +247,9 @@
 "metadata": {},
 "source": [
 "### Training an MLP to predict subcellular location\n",
-"To be able to classify proteins for their subcellular location, we train a simple scikit-learn Multi-layer Perceptron (MPL) classifier. The MLP model uses a network of hidden layers to fit the input embedding vectors to the model classes (the cellular locations above). In the call below, we define the MLP to use the Adam optimizer with a network of 32 hidden layers, defining a random state (or seed) for reproducibility, and trained for a maximum of 500 iterations.\n",
+"To classify proteins by their subcellular location, we train a simple scikit-learn Multi-layer Perceptron (MLP) classifier using Federated Averaging ([FedAvg](https://arxiv.org/abs/1602.05629)). The MLP model uses a network of hidden layers to fit the input embedding vectors to the model classes (the cellular locations above). In the simulation below, we define the MLP to use the Adam optimizer with hidden layers of sizes (512, 256, 128), set a random state (or seed) for reproducibility, and train for 30 rounds of FedAvg (see [config_fed_server.json](./jobs/fedavg/app/config/config_fed_server.json)).\n",
+"\n",
+"The same configuration can also simulate local training, where each client trains only on its own data, by setting `os.environ[\"SIM_LOCAL\"] = \"True\"`. Our [BioNeMoMLPLearner](./jobs/fedavg/app/custom/bionemo_mlp_learner.py) will then ignore the global weights coming from the server.\n",
 "\n",
 "### Local training"
 ]
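For reference, the local training step each client performs can be sketched with scikit-learn alone. This is a toy illustration under assumptions (synthetic data in place of the real 768-dimensional ESM-1nv embeddings, dimensions shrunk so it runs quickly), not the example's actual learner code:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

# Toy stand-ins for (embedding, subcellular-location) pairs; real inputs
# would be 768-dimensional ESM-1nv embedding vectors.
X, y = make_classification(n_samples=200, n_features=64, n_informative=16,
                           n_classes=3, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Same optimizer and layer sizes as described above; max_iter kept small
# here, so scikit-learn may emit a ConvergenceWarning on this toy data.
clf = MLPClassifier(hidden_layer_sizes=(512, 256, 128), solver="adam",
                    max_iter=50, random_state=0)
clf.fit(X_train, y_train)
print("held-out accuracy:", clf.score(X_test, y_test))
```

In the federated setting, each client runs a step like this locally and sends its updated weights to the server instead of keeping them private to one `fit` call.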
@@ -311,64 +304,6 @@
 "run_status = simulator.run()\n",
 "print(\"Simulator finished with run_status\", run_status)"
 ]
-},
-{
-"cell_type": "markdown",
-"id": "a68dfe9a",
-"metadata": {},
-"source": [
-"## Finetuning ESM1nv\n",
-"#### Federated Learning"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "f1fedaba",
-"metadata": {},
-"outputs": [],
-"source": [
-"# data preprocessing\n",
-"#!python jobs/fedavg_finetune_esm1nv/app/custom/downstream_flip.py do_training=False"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "85816617",
-"metadata": {
-"scrolled": true
-},
-"outputs": [],
-"source": [
-"# DEBUG\n",
-"import os\n",
-"os.environ[\"SIM_LOCAL\"] = \"False\"\n",
-"from nvflare import SimulatorRunner \n",
-"n_clients = 1\n",
-"split_alpha = 100.0\n",
-"\n",
-"simulator = SimulatorRunner(\n",
-" #job_folder=\"jobs/fedavg_finetune_esm1nv\",\n",
-" #workspace=f\"/tmp/nvflare/bionemo/fedavg_finetune_esm1nv_alpha{split_alpha}\",\n",
-" #workspace=f\"/tmp/nvflare/bionemo/local_site1_finetune_esm1nv_alpha{split_alpha}_unfreeze_encoder4\",\n",
-" job_folder=\"jobs/fedavg_finetune_esm2nv\",\n",
-" #workspace=f\"/tmp/nvflare/bionemo/local_site1_finetune_esm2nv_alpha{split_alpha}_freeze_encoder2\",\n",
-" workspace=f\"/tmp/nvflare/bionemo/local_site1_finetune_esm2nv_alpha{split_alpha}_unfreeze_encoder1_large_ds\",\n",
-" n_clients=n_clients,\n",
-" threads=n_clients\n",
-")\n",
-"run_status = simulator.run()\n",
-"print(\"Simulator finished with run_status\", run_status)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "a6ff0f8e",
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
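Between the local-training rounds run by the simulator above, the FedAvg server aggregates client weights as a data-size-weighted average. A minimal sketch of that aggregation step (a hypothetical helper for illustration, not NVFlare's implementation):

```python
import numpy as np

def fedavg(client_weights, client_sizes):
    """Weighted average of per-client parameter lists, weighted by
    each client's local sample count (the FedAvg aggregation rule)."""
    total = sum(client_sizes)
    return [
        sum(w[i] * (n / total) for w, n in zip(client_weights, client_sizes))
        for i in range(len(client_weights[0]))
    ]

# Two hypothetical clients, each contributing one 2x2 weight matrix.
w1 = [np.ones((2, 2))]    # client 1: all ones, 3 local samples
w2 = [np.zeros((2, 2))]   # client 2: all zeros, 1 local sample
avg = fedavg([w1, w2], client_sizes=[3, 1])
print(avg[0])  # every entry is 1*(3/4) + 0*(1/4) = 0.75
```
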
"metadata": {
