Skip to content

Commit

Permalink
Debug.
Browse files Browse the repository at this point in the history
  • Loading branch information
zmgong committed Jul 13, 2024
1 parent 0eda268 commit fe91901
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 16 deletions.
9 changes: 5 additions & 4 deletions bioscanclip/epoch/train_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def train_epoch(activate_wandb, total_epochs, epoch, dataloader, model, optimizer, criterion, device, open_clip_ver=False, rank=None, check_cuda_memory=False):
if rank is not None and rank == 0:
pbar = tqdm(enumerate(dataloader), total=len(dataloader))
else:
pbar = enumerate(dataloader)
# if rank is not None and rank == 0:
# pbar = tqdm(enumerate(dataloader), total=len(dataloader))
# else:
# pbar = enumerate(dataloader)
pbar = tqdm(enumerate(dataloader), total=len(dataloader))
epoch_loss = 0.0
total_step = len(dataloader)

Expand Down
4 changes: 1 addition & 3 deletions bioscanclip/util/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,6 @@ def species_list_to_labels(species_list, species_to_others):
return array_of_dicts

class INSECTDataset(Dataset):


def __init__(self, path_to_att_splits_mat, path_to_res_101_mat, image_hdf5_path, dna_transforms, species_to_others,
split, for_training=False, cl_label=False, **kwargs) -> None:
super().__init__()
Expand Down Expand Up @@ -819,7 +817,7 @@ def load_insect_dataloader(args, world_size=None, rank=None, num_workers=8, load
args.insect_data.path_to_att_splits_mat, args.insect_data.path_to_res_101_mat,
species_to_others=specie_to_other_labels, split="train_loc",
image_hdf5_path=args.insect_data.path_to_image_hdf5,
dna_transforms=sequence_pipeline, for_training=True
dna_transforms=sequence_pipeline, for_training=True, cl_label=True
)

train_dataset_for_key = INSECTDataset(
Expand Down
2 changes: 0 additions & 2 deletions scripts/train_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ def main_process(rank: int, world_size: int, args):
args.save_inference = False
args.save_ckpt = False



current_datetime = datetime.datetime.now()
formatted_datetime = current_datetime.strftime("%Y-%m-%d_%H%M%S")
args = copy.deepcopy(args)
Expand Down
9 changes: 2 additions & 7 deletions scripts/train_cl_on_insect_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,10 @@ def convert_acc_dict_to_wandb_dict(acc_dict):

def main_process(rank: int, world_size: int, args):

# # Set up for debug, delete when you see it!

args.debug_flag = False
# # Set up for debug, delete when you see it!

# special set up for train on INSECT dataset
args.model_config.batch_size = 350
args.model_config.batch_size = 400
args.model_config.epochs = 80
args.model_config.evaluation_period = 75
args.model_config.evaluation_period = 40


if args.debug_flag:
Expand Down

0 comments on commit fe91901

Please sign in to comment.