This repository has been archived by the owner on Nov 5, 2024. It is now read-only.

Commit: working?
hk21702 committed Apr 11, 2024
1 parent 2cbb238 commit f241fd4
Showing 4 changed files with 108 additions and 44 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -4,4 +4,5 @@ physionet.org
wget-log
data
__pycache__
scratch.ipynb
scratch.ipynb
wandb
51 changes: 34 additions & 17 deletions MIMICDataset.py
@@ -3,46 +3,63 @@
from torch.utils.data import Dataset
import ast


class MimicDataset(Dataset):
"""Mimic Dataset - a dataset of diagnostic reports and their corresponding icd9 labels"""

def __init__(self, dataset, labels, csv_file_labels):
def __init__(
self,
text,
labels,
classes: list,
class2id: dict | None = None,
id2class: dict | None = None,
):
"""
Arguments:
csv_file (string): Path to the csv file with data and annotations
"""
self.dataset = dataset
self.text = text

# create a dictionary of icd codes and their corresponding index in the list
self.labels = labels
self.icd_labels = csv_file_labels
self.icd_labels = self.icd_labels["icd_code"].tolist()
self.icd_labels_dict = {self.icd_labels[i]: i for i in range(len(self.icd_labels))}

self.icd_labels = classes

self.class2id = (
{class_: id for id, class_ in enumerate(classes)}
if class2id is None
else class2id
)

self.id2class = (
{id: class_ for class_, id in self.class2id.items()}
if id2class is None
else id2class
)

self.icd_size = len(self.icd_labels)


def __len__(self):
return len(self.dataset)
return len(self.text)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

# get the tokenized data
item = self.dataset[idx]
item = self.text[idx]

icd_codes = ast.literal_eval(self.labels[idx])

# create the label tensor
label_tensor = torch.zeros(self.icd_size)
for code in icd_codes:
label_tensor[self.icd_labels_dict[code]] = 1
label_tensor[self.class2id[code]] = 1.0

item["labels"] = label_tensor

# squeeze item to remove the extra dimension
item["input_ids"] = item["input_ids"].squeeze()
item["attention_mask"] = item["attention_mask"].squeeze()
return item

return item
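
For orientation, a minimal usage sketch of the updated MimicDataset constructor follows; the file names, column names, and tokenizer choice are assumptions for illustration, not part of the commit.

# Hypothetical usage of the new MimicDataset signature (illustration only).
import pandas as pd
from transformers import AutoTokenizer
from MIMICDataset import MimicDataset

train_df = pd.read_csv("train.csv")            # assumed columns: "text", "label" (label = stringified list of ICD codes)
code_labels = pd.read_csv("code_labels.csv")   # assumed column: "icd_code"

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=True)
# One tokenizer output per report; return_tensors="pt" adds the batch dimension
# that __getitem__ later squeezes away.
encodings = [tokenizer(t, truncation=True, return_tensors="pt") for t in train_df["text"]]

classes = [c for c in code_labels["icd_code"] if c]
dataset = MimicDataset(encodings, train_df["label"].tolist(), classes)
sample = dataset[0]
print(len(dataset), sample["labels"].shape)    # multi-hot label vector of length len(classes)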
5 changes: 4 additions & 1 deletion requirements.txt
@@ -11,4 +11,7 @@ transformers~=4.39.3
wandb~=0.16.5
scikit-learn ~=1.4.2
accelerate
datasets
datasets
evaluate
swifter
flash-attn
93 changes: 68 additions & 25 deletions training_loop.py
@@ -2,7 +2,7 @@
import argparse
from pprint import pprint

import datasets
import evaluate
import numpy as np
import pandas as pd
import torch
@@ -11,41 +11,58 @@
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from transformers import (
AutoTokenizer,
AutoConfig,
EvalPrediction,
OPTForSequenceClassification,
Trainer,
TrainingArguments,
DataCollatorWithPadding,
)

MODEL = "facebook/opt-350m"

def train(args: argparse.Namespace):
# for some reason the Trainer has issues passing parameters to the model_init function, so these variables need to be global
global num_labels
global tokenizer
datasets.logging.set_verbosity_info()
global classes
global class2id
global id2class
global clf_metrics

global sigmoid

print("Loading datasets")
data_files = {
"train": args.train_path,
"validation": args.val_path,
"test": args.test_path,
}
# train_dataset = pd.read_csv(args.train_dataset)
# val_dataset = pd.read_csv(args.val_dataset)
# test_dataset = pd.read_csv(args.test_dataset)

code_labels = pd.read_csv(args.code_labels)
"""train_set = pd.read_csv(args.train_path)
val_set = pd.read_csv(args.val_path)
test_set = pd.read_csv(args.test_path)"""

dataset = load_dataset("csv", data_files=data_files, cache_dir=args.cache_dir)

# Create class dictionaries
classes = [class_ for class_ in code_labels["icd_code"] if class_]
class2id = {class_: id for id, class_ in enumerate(classes)}
id2class = {id: class_ for class_, id in class2id.items()}

clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall", "roc_auc"])

print("Tokenizing datasets. Loading from cache if available.")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b", use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

torch.cuda.set_device(0)
torch.cuda.current_device()
sigmoid = torch.nn.Sigmoid()

# Run a dummy tokenization first to circumvent a bug with the cache hash changing
# tokenizer("Some", "test")

dataset = dataset.map(tokenize, batched=True, load_from_cache_file=True, num_proc=8)
dataset = dataset.map(tokenize, load_from_cache_file=True, batched=True, num_proc=8)

# make sure to update num_train_epochs for actual training
# note - save strategy and evaluation strategy need to match
@@ -56,31 +73,30 @@ def train(args: argparse.Namespace):
save_strategy="epoch",
save_steps=args.save_interval,
learning_rate=2e-5,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
num_train_epochs=1,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="f1",
)

num_labels = len(code_labels)
pprint(f"num_labels: {num_labels}")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

pprint(f"num_labels: {len(code_labels)}")

if args.wandb_key:
pprint("Using wandb")
wandb.login(key=args.wandb_key)
training_args.report_to = "wandb"
training_args.report_to = ["wandb"]

if args.fresh_start:
pprint("Fresh Start")
hyperparameter_search(
trainer = hyperparameter_search(
model_init,
training_args,
dataset,
tokenizer,
compute_metrics,
n_trials=10,
data_collator,
n_trials=2,
)

else:
@@ -94,6 +110,7 @@ def train(args: argparse.Namespace):
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
tokenizer=tokenizer,
data_collator=data_collator,
)

trainer.train(resume_from_checkpoint=True)
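
As a quick aside, here is a self-contained sketch of what the DataCollatorWithPadding wired in above does at batch time; the toy sentences are made up and only illustrate the dynamic-padding behaviour.

# Standalone sketch of dynamic padding with DataCollatorWithPadding (toy inputs, not from the commit).
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=True)
collator = DataCollatorWithPadding(tokenizer=tokenizer)

features = [tokenizer("short note"), tokenizer("a somewhat longer clinical note")]
batch = collator(features)
# Both sequences are padded to the longest one in the batch, so the tokenize()
# step above can skip fixed-length padding entirely.
print(batch["input_ids"].shape, batch["attention_mask"].shape)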
@@ -102,18 +119,30 @@ def train(args: argparse.Namespace):
trainer.save_state()


def multi_labels_to_ids(labels: list[str]) -> list[float]:
ids = [0.0] * len(class2id) # BCELoss requires float as target type
for label in labels:
ids[class2id[label]] = 1.0
return ids


def tokenize(example):
return tokenizer(
result = tokenizer(
example["text"],
add_special_tokens=True,
)
result["label"] = [multi_labels_to_ids(eval(label)) for label in example["label"]]

return result


def model_init():
"""Model init for use for hyperparameter_search"""
return OPTForSequenceClassification.from_pretrained(
"facebook/opt-350m",
num_labels=num_labels,
MODEL,
num_labels=len(classes),
id2label=id2class,
label2id=class2id,
problem_type="multi_label_classification",
)
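
For reference, a minimal sketch of the loss behaviour implied by problem_type="multi_label_classification": Transformers routes sequence-classification heads with this setting through BCEWithLogitsLoss, which is why the label vectors produced by multi_labels_to_ids above are floats; the numbers below are made up.

# Why multi-hot label vectors must be floats for multi-label classification (illustration only).
import torch

logits = torch.tensor([[1.2, -0.7, 0.3]])    # raw scores for 3 ICD codes (made-up)
targets = torch.tensor([[1.0, 0.0, 1.0]])    # multi-hot float labels, as built by multi_labels_to_ids
loss = torch.nn.BCEWithLogitsLoss()(logits, targets)
probs = torch.sigmoid(logits)                # per-code probabilities, thresholded at eval time
print(loss.item(), (probs > 0.5).int())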

@@ -132,7 +161,7 @@ def wandb_hp_space(trial):
"metric": {"name": "loss", "goal": "minimize"},
"parameters": {
"learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
"per_device_train_batch_size": {"values": [16, 32, 64, 128]},
"per_device_train_batch_size": {"values": [16, 32, 64]},
},
}
"""
@@ -141,7 +170,7 @@ def wandb_hp_space(trial):
"metric": {"name": "loss", "goal": "minimize"},
"parameters": {
"learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
"per_device_train_batch_size": {"values": [16, 32, 64, 128]},
"per_device_train_batch_size": {"values": [16, 32, 64]},
},
}

@@ -152,6 +181,7 @@ def hyperparameter_search(
dataset,
tokenizer,
compute_metrics,
data_collator,
n_trials: int = 10,
):
"""
@@ -168,7 +198,7 @@
n_trials (int, optional): The number of hyperparameter search trials. Defaults to 10.
Returns:
BestRun: The best run from the hyperparameter search.
Trainer: the trainer with its arguments updated to the best run's hyperparameters
"""
pprint(f"Doing hyperparameter search with {n_trials} trials")
@@ -180,6 +210,7 @@
eval_dataset=dataset["validation"],
tokenizer=tokenizer,
compute_metrics=compute_metrics,
data_collator=data_collator,
)

best_run = trainer.hyperparameter_search(
@@ -188,7 +219,12 @@
direction="minimize",
backend="wandb",
)
return best_run

pprint(best_run)

for n, v in best_run.hyperparameters.items():
setattr(trainer.args, n, v)
return trainer


def multi_label_metrics(predictions, labels, threshold=0.5):
@@ -210,9 +246,16 @@ def multi_label_metrics(predictions, labels, threshold=0.5):


def compute_metrics(p: EvalPrediction):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
"""preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
result = multi_label_metrics(predictions=preds, labels=p.label_ids)
return result
return result"""

predictions, labels = p
predictions = sigmoid(predictions)
predictions = (predictions > 0.5).astype(int).reshape(-1)
return clf_metrics.compute(
predictions=predictions, references=labels.astype(int).reshape(-1)
)


# TODO: leaving this code here for now but we should remove it before final submission
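
To make the evaluation step above concrete, here is a standalone illustration of the sigmoid-then-threshold-then-flatten logic, using the sklearn metrics already imported at the top of the file rather than the evaluate.combine pipeline; the logits and labels are made up, and the NumPy logits are converted to a tensor before the sigmoid is applied.

# Standalone illustration of the sigmoid -> threshold -> flatten evaluation step (made-up numbers).
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score

logits = np.array([[2.0, -1.0, 0.5], [-0.3, 1.5, -2.0]])   # (batch, num_codes)
labels = np.array([[1, 0, 1], [0, 1, 0]])

probs = torch.sigmoid(torch.from_numpy(logits)).numpy()
preds = (probs > 0.5).astype(int).reshape(-1)               # one binary decision per (example, code) pair
refs = labels.astype(int).reshape(-1)

print(accuracy_score(refs, preds), f1_score(refs, preds))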
