-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
39 lines (32 loc) · 1.15 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import numpy as np
import torch
import cv2
from timm import create_model
from torchvision import transforms
from sklearn.model_selection import train_test_split, LeaveOneOut
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm # For progress bars
import torch
from timm import create_model
import timm
import torch.nn as nn
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load images and extract class codes from file names
def load_images_and_labels(folder, extensions=(".jpg", ".png")):
    """Load the images in *folder* and derive a class label from each filename.

    The label is the part of the filename before the first underscore, so
    ``"ABC12_frame3.jpg"`` yields ``"ABC12"``. A filename with no underscore
    yields the whole filename (extension included) as its label.

    Args:
        folder: Directory to scan (non-recursively).
        extensions: Filename suffix, or tuple of suffixes, to accept. Matching
            is case-insensitive. Defaults to ``(".jpg", ".png")``.

    Returns:
        A ``(images, labels)`` tuple of equal-length lists. ``images`` holds
        each decoded image converted from OpenCV's BGR channel order to RGB.
        ``labels`` holds the matching label strings. Files are visited in
        sorted filename order, so the result is reproducible across runs.

    Files whose suffix matches but which OpenCV cannot decode are skipped,
    and a ``UserWarning`` is emitted for each one.
    """
    import warnings

    if isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character.
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    images = []
    labels = []
    # os.listdir returns entries in arbitrary order; sort so runs are repeatable.
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(suffixes):
            continue
        image = cv2.imread(os.path.join(folder, filename))
        # cv2.imread signals an unreadable/corrupt file by returning None
        # rather than raising; skip it instead of crashing in cvtColor.
        if image is None:
            warnings.warn(f"Skipping unreadable image: {filename}")
            continue
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        labels.append(filename.split("_", 1)[0])
    return images, labels
# Folder of cropped, filtered video frames; the class code is the filename prefix.
DATA_DIR = "/workspaces/gorilla_watch/video_data/bristol/cropped_frames_filtered"

images, labels = load_images_and_labels(DATA_DIR)
image_count = len(images)
print(image_count)
print(labels)