diff --git a/tools/extractor.py b/tools/extractor.py
new file mode 100644
index 0000000..8477e3d
--- /dev/null
+++ b/tools/extractor.py
@@ -0,0 +1,104 @@
+"""
+Author: Zehui Lin
+Date: 2024-12-11 16:09:53
+LastEditors: Zehui Lin
+LastEditTime: 2024-12-11 16:19:54
+FilePath: /USFM/tools/extractor.py
+Description: Extract per-image feature vectors with a pretrained USFM ViT.
+
+Using USFM as a feature extractor
+
+"""
+
+import os
+
+import cv2
+import numpy as np
+import torch
+from torch.autograd import Variable
+from tqdm import tqdm
+
+
+def prepare_inputs():
+    """Load and normalize every readable image of the toy training set.
+
+    Each image is averaged over its channel axis, resized to 224x224, scaled
+    by 1/255, standardized with its own mean/std and repeated to 3 channels.
+
+    Returns:
+        list of float arrays of shape (224, 224, 3), ordered by file name.
+    """
+    dataset_path = "datasets/Seg/toy_dataset/training_set/image"
+    # os.listdir order is arbitrary; sort so that row i of the saved feature
+    # matrix always corresponds to the same file.
+    dataset_list = sorted(os.listdir(dataset_path))
+
+    input_list = []
+    for filename in tqdm(dataset_list, ncols=100, desc="Prepare Inputs"):
+        image = cv2.imread(os.path.join(dataset_path, filename))
+        if image is None:
+            # cv2.imread returns None for anything it cannot decode
+            # (sub-directories, hidden files, corrupt images); skip those
+            # rather than crash in np.mean.
+            continue
+        image = np.mean(image, axis=2)
+
+        image = cv2.resize(image, (224, 224))
+        image = image / 255.0
+        # A constant image has zero std; divide by 1 there to avoid NaNs.
+        std = image.std()
+        image = (image - image.mean()) / (std if std > 0 else 1.0)
+        image_data = np.repeat(image[:, :, None], 3, axis=-1)
+
+        input_list.append(image_data)
+
+    return input_list
+
+
+def main():
+    """Extract USFM features for the prepared inputs and save them to disk.
+
+    The features of all images are stacked row-wise and written to
+    ``tools/saved_feature.npy``.
+    """
+    import logging
+
+    import yaml
+    from omegaconf import OmegaConf
+
+    from usdsgen.modules.backbone.vision_transformer import build_vit
+
+    with open("configs/model/Cls/vit.yaml") as f:
+        cfg = OmegaConf.create(yaml.load(f, Loader=yaml.FullLoader))
+    cfg.model.model_cfg.num_classes = 6
+    cfg.model.model_cfg.backbone.pretrained = "./assets/FMweight/USFM_latest.pth"
+
+    logger = logging.getLogger("logger")
+    logger.setLevel(logging.INFO)
+    logger.addHandler(logging.StreamHandler())
+
+    # Keep model and inputs on one device; fall back to CPU without CUDA.
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    model = build_vit(cfg.model.model_cfg, logger)
+    model.to(device)
+    model.eval()
+
+    input_list = prepare_inputs()
+    feature_list = []
+    # Inference only: no_grad avoids building an autograd graph per image.
+    with torch.no_grad():
+        for image in tqdm(input_list, ncols=100, desc="Extract Features"):
+            x = torch.from_numpy(image).float().to(device)
+            x = x.permute(2, 0, 1).unsqueeze(0)
+            feature = model.forward_features(x)  # 1, 768
+            feature_list.append(feature.cpu().numpy().flatten())
+
+    feature_list = np.array(feature_list)
+    np.save(
+        "tools/saved_feature.npy",
+        feature_list,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/usdsgen/modules/backbone/vision_transformer.py b/usdsgen/modules/backbone/vision_transformer.py
index 4692189..61faf5a 100644
--- a/usdsgen/modules/backbone/vision_transformer.py
+++ b/usdsgen/modules/backbone/vision_transformer.py
@@ -461,7 +461,7 @@ def make(**config):
 
 def build_vit(model_cfg, logger):
     model = VisionTransformer(norm_layer=partial(nn.LayerNorm, eps=1e-6), **model_cfg)
-    if model_cfg.pretrained:
+    if model_cfg.backbone.pretrained:
         load_pretrained(model_cfg, model, logger)
 
     return model
diff --git a/usdsgen/utils/modelutils.py b/usdsgen/utils/modelutils.py
index 63f50cc..eabdcc5 100644
--- a/usdsgen/utils/modelutils.py
+++ b/usdsgen/utils/modelutils.py
@@ -79,8 +79,10 @@ def reduce_tensor(tensor):
 
 
 def load_pretrained(model_cfg, model, logger):
-    logger.info(f">>>>>>>>>> Fine-tuned from {model_cfg.pretrained} ..........")
-    checkpoint = torch.load(model_cfg.pretrained, map_location="cpu")
+    logger.info(
+        f">>>>>>>>>> Fine-tuned from {model_cfg.backbone.pretrained} .........."
+    )
+    checkpoint = torch.load(model_cfg.backbone.pretrained, map_location="cpu")
     checkpoint_model = checkpoint
 
     if "swin" in model_cfg.type.lower():
@@ -97,7 +99,7 @@ def load_pretrained(model_cfg, model, logger):
 
     del checkpoint_model
     torch.cuda.empty_cache()
-    logger.info(f">>>>>>>>>> loaded successfully '{model_cfg.pretrained}'")
+    logger.info(f">>>>>>>>>> loaded successfully '{model_cfg.backbone.pretrained}'")
 
 
 def remap_pretrained_keys_swin(model, checkpoint_model, logger):