Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Feature: Using USFM as a feature extractor #29

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions tools/extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
Author: Zehui Lin
Date: 2024-12-11 16:09:53
LastEditors: Zehui Lin
LastEditTime: 2024-12-11 16:19:54
FilePath: /USFM/tools/extractor.py
Description: 写点东西描述一下这个文件叭~

Using USFM as a feature extractor

"""

import os

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from tqdm import tqdm


def prepare_inputs():

dataset_path = "datasets/Seg/toy_dataset/training_set/image"
dataset_list = os.listdir(dataset_path)

input_list = []
bar = tqdm(total=len(dataset_list), ncols=100, desc="Prepare Inputs")
for filename in dataset_list:
bar.update(1)

image = cv2.imread(os.path.join(dataset_path, filename))
image = np.mean(image, axis=2)

image = cv2.resize(image, (224, 224))
image = image / 255.0
image = (image - image.mean()) / image.std()
image_data = np.repeat(image[:, :, None], 3, axis=-1)

input_list.append(image_data)

return input_list


def main():

import logging

import yaml
from omegaconf import OmegaConf

from usdsgen.modules.backbone.vision_transformer import build_vit

with open("configs/model/Cls/vit.yaml") as f:
cfg = OmegaConf.create(yaml.load(f, Loader=yaml.FullLoader))
cfg.model.model_cfg.num_classes = 6
cfg.model.model_cfg.backbone.pretrained = "./assets/FMweight/USFM_latest.pth"

logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

model = build_vit(cfg.model.model_cfg, logger)
model.to("cuda:0")
model.eval()

input_list = prepare_inputs()
feature_list = []
bar = tqdm(total=len(input_list), ncols=100, desc="Extract Features")
for image in input_list:
bar.update(1)
x = Variable(torch.from_numpy(image).float().cuda())
x = x.permute(2, 0, 1).unsqueeze(0)
feature = model.forward_features(x) # 1, 768
feature = feature.cpu().data.numpy().flatten()
feature_list.append(feature)

feature_list = np.array(feature_list)
np.save(
"tools/saved_feature.npy",
feature_list,
)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion usdsgen/modules/backbone/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def make(**config):

def build_vit(model_cfg, logger):
model = VisionTransformer(norm_layer=partial(nn.LayerNorm, eps=1e-6), **model_cfg)
if model_cfg.pretrained:
if model_cfg.backbone.pretrained:
load_pretrained(model_cfg, model, logger)
return model

Expand Down
8 changes: 5 additions & 3 deletions usdsgen/utils/modelutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def reduce_tensor(tensor):


def load_pretrained(model_cfg, model, logger):
logger.info(f">>>>>>>>>> Fine-tuned from {model_cfg.pretrained} ..........")
checkpoint = torch.load(model_cfg.pretrained, map_location="cpu")
logger.info(
f">>>>>>>>>> Fine-tuned from {model_cfg.backbone.pretrained} .........."
)
checkpoint = torch.load(model_cfg.backbone.pretrained, map_location="cpu")
checkpoint_model = checkpoint

if "swin" in model_cfg.type.lower():
Expand All @@ -97,7 +99,7 @@ def load_pretrained(model_cfg, model, logger):

del checkpoint_model
torch.cuda.empty_cache()
logger.info(f">>>>>>>>>> loaded successfully '{model_cfg.pretrained}'")
logger.info(f">>>>>>>>>> loaded successfully '{model_cfg.backbone.pretrained}'")


def remap_pretrained_keys_swin(model, checkpoint_model, logger):
Expand Down