Skip to content

Commit

Permalink
removed hard coded class number
Browse files Browse the repository at this point in the history
Signed-off-by: elitap <[email protected]>
  • Loading branch information
elitap committed Aug 23, 2023
1 parent 1e7f3cc commit 6b2da77
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
3 changes: 2 additions & 1 deletion training/example_train_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ python main_2pt5d.py --max_epochs 100 --val_every 1 --optim_lr 0.000005 \
--logdir finetune_ckpt_example --point_prompt --label_prompt --distributed --seed 12346 \
--iterative_training_warm_up_epoch 50 --reuse_img_embedding \
--label_prompt_warm_up_epoch 25 \
--checkpoint ./runs/9s_2dembed_model.pt
--checkpoint ./runs/9s_2dembed_model.pt \
--num_classes 105
11 changes: 10 additions & 1 deletion training/main_2pt5d.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@
parser.add_argument("--skip_bk", action="store_true", help="skip background (0) during training")
parser.add_argument("--patch_embed_3d", action="store_true", help="using 3d patch embedding layer")

parser.add_argument("--num_classes", default=105, type=int, help="number of output classes")


def start_tb(log_dir):
cmd = ["tensorboard", "--logdir", log_dir]
Expand All @@ -123,6 +125,10 @@ def main():
args = parser.parse_args()
args.amp = not args.noamp
args.logdir = "./runs/" + args.logdir

if args.num_classes == 0:
warnings.warn("consider setting the correct number of classes")

# start_tb(args.logdir)
if args.seed > -1:
set_determinism(seed=args.seed)
Expand All @@ -135,6 +141,9 @@ def main():
main_worker(gpu=0, args=args)





def main_worker(gpu, args):
if args.distributed:
torch.multiprocessing.set_start_method("fork", force=True)
Expand Down Expand Up @@ -162,7 +171,7 @@ def main_worker(gpu, args):

dice_loss = DiceCELoss(sigmoid=True)

post_label = AsDiscrete(to_onehot=105)
post_label = AsDiscrete(to_onehot=args.num_classes)
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
dice_acc = DiceMetric(include_background=False, reduction=MetricReduction.MEAN, get_not_nans=True)

Expand Down
8 changes: 4 additions & 4 deletions training/trainer_2pt5d.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def prepare_sam_training_input(inputs, labels, args, model):
unique_labels = unique_labels[: args.num_prompt]

# add 4 background labels to every batch
background_labels = list(set([i for i in range(1, 105)]) - set(unique_labels.cpu().numpy()))
background_labels = list(set([i for i in range(1, args.num_classes)]) - set(unique_labels.cpu().numpy()))
random.shuffle(background_labels)
unique_labels = torch.cat([unique_labels, torch.tensor(background_labels[:4]).cuda(args.rank)])

Expand Down Expand Up @@ -375,7 +375,7 @@ def train_epoch_iterative(model, loader, optimizer, scaler, epoch, loss_func, ar


def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)

# preprocess make the size of lable same as high_res_logit
batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
Expand All @@ -400,7 +400,7 @@ def prepare_sam_test_input(inputs, labels, args, previous_pred=None):

def prepare_sam_val_input_cp_only(inputs, labels, args):
# Don't exclude background in val but will ignore it in metric calculation
unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)

# preprocess make the size of lable same as high_res_logit
batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
Expand Down Expand Up @@ -460,7 +460,7 @@ def val_epoch(model, loader, epoch, acc_func, args, iterative=False, post_label=
acc_batch = compute_dice(y_pred=y_pred, y=target)
acc_sum, not_nans = (
torch.nansum(acc_batch).item(),
104 - torch.sum(torch.isnan(acc_batch).float()).item(),
(args.num_classes-1) - torch.sum(torch.isnan(acc_batch).float()).item(),
)
acc_sum_total += acc_sum
not_nans_total += not_nans
Expand Down

0 comments on commit 6b2da77

Please sign in to comment.