Skip to content

Commit

Permalink
updated dataset configs
Browse files Browse the repository at this point in the history
  • Loading branch information
faizan1234567 committed Dec 19, 2024
1 parent a4b5cf1 commit d088a52
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 8 deletions.
3 changes: 2 additions & 1 deletion conf/dataset/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ irl_pc: /media/faizan/2418a65c-11bc-4f64-bf23-ecd60d93cf53/faizan/Faizan_thesis/
laptop_pc: "E:/Brats23 Data/Dataset/BraTS23_mapped/dataset"
colab: "/gdrive/MyDrive/BraTS2023/"
sines_pc: /drive/faizanai.rrl/datasets
dataset_folder: # specify your dataset path here if any of the above does not relate to you
dataset_folder: # specify your dataset path here if any of the above does not relate to you
version: brats2023
2 changes: 1 addition & 1 deletion conf/model/default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
roi: 128
infer_overlap: 0.6
architecture: scfe_net
architecture: segres_net
2 changes: 1 addition & 1 deletion conf/training/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ solver_name: Adam
loss_type: dice
pretrained: False
resume: False
exp_name: scfe_net_runs
exp_name: segres_net_runs
seed: 123
colab: False
irl: True
Expand Down
9 changes: 5 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from thesis.models.v3.model import SCFENet
except ModuleNotFoundError:
print('model not available, please train with other models')
sys.exit(1)
# sys.exit(1)

from functools import partial

Expand Down Expand Up @@ -143,7 +143,7 @@ def test(args, data_loader, model):
def main(cfg: DictConfig):
# Select model

device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Efficient training
torch.backends.cudnn.benchmark = True
Expand Down Expand Up @@ -267,13 +267,14 @@ def main(cfg: DictConfig):
batch_size = cfg.test.batch
workers = cfg.test.workers
dataset_folder = cfg.dataset.irl_pc
dataset_version = cfg.dataset.version

# Load checkpoints
model.load_state_dict(torch.load(cfg.test.weights, weights_only=True))
model.load_state_dict(torch.load(cfg.test.weights, weights_only=True, map_location=device))
model.eval()

# Load dataset
test_loader = get_datasets(dataset_folder=dataset_folder, mode="test", target_size=(128, 128, 128))
test_loader = get_datasets(dataset_folder=dataset_folder, mode="test", target_size=(128, 128, 128), version=dataset_version)
test_loader = torch.utils.data.DataLoader(test_loader,
batch_size=batch_size,
shuffle=False, num_workers=workers,
Expand Down
2 changes: 1 addition & 1 deletion utils/all_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def save_seg_csv(csv, args):
columns = ['id', 'et_dice', 'tc_dice', 'wt_dice', 'et_hd', 'tc_hd', 'wt_hd', 'et_sens', 'tc_sens', 'wt_sens', 'et_spec', 'tc_spec', 'wt_spec']
save_path = os.path.join(args.training.exp_name, "csv")
os.makedirs(save_path, exist_ok= True)
csv_path = os.path.join(save_path, "test_metrics.csv")
csv_path = os.path.join(save_path, f"{args.model.architecture}_test_metrics.csv")
val_metrics.to_csv(csv_path, index=False, columns=columns)
except KeyboardInterrupt:
print("Save CSV File Error!")
Expand Down

0 comments on commit d088a52

Please sign in to comment.