-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
283 additions
and
148 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,154 +1,65 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import sys | ||
import argparse | ||
from tqdm import tqdm | ||
|
||
from torchvision import transforms | ||
from torch.utils.data import DataLoader, random_split | ||
from torch.utils.tensorboard import SummaryWriter | ||
from datetime import datetime | ||
from torchvision import datasets, transforms | ||
from lib.images import IMAGE_FOLDER | ||
from model import dataset, LeafClassifier, class_to_idx | ||
|
||
""" | ||
TODO: | ||
- visualize output | ||
""" | ||
# Set hyperparameters | ||
lr = 0.01 | ||
batch_size = 256 | ||
epochs = 10 | ||
|
||
def predict(val_loader): | ||
correct = 0 | ||
total = 0 | ||
|
||
# Load Dataset | ||
transform = transforms.Compose([ | ||
transforms.Resize((256, 256)), | ||
transforms.ToTensor(), | ||
# image = (image - mean) / std, range [-1, 1] | ||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | ||
]) | ||
|
||
dataset = datasets.ImageFolder(root=IMAGE_FOLDER, transform=transform) | ||
|
||
NUM_CLASSES = len(dataset.classes) | ||
# Get the mapping of class names to their corresponding labels | ||
class_to_idx = dataset.class_to_idx | ||
|
||
print(f"Number of classes: {NUM_CLASSES}") | ||
print(f"Class to index mapping: {class_to_idx}") | ||
print("\n", dataset) | ||
|
||
train_size = int(0.8 * len(dataset)) | ||
val_size = len(dataset) - train_size | ||
train_data, val_data = random_split(dataset, [train_size, val_size]) | ||
|
||
train_loader = DataLoader(train_data, batch_size, shuffle=True) | ||
val_loader = DataLoader(val_data, batch_size, shuffle=False) | ||
print(f"Dataset has been loaded - batch size:{batch_size}", end=" ") | ||
|
||
print(f"train_data: {len(train_loader)}, val_data: {len(val_loader)}") | ||
|
||
# Define model | ||
class LeafClassifier(nn.Module): | ||
""" | ||
Leaf Classifier that classifies the type of disease specified in the leaf | ||
Inputs: leaf Images, shape of (256, 256, 3) | ||
Labels: types of disease specified in the leaf | ||
""" | ||
def __init__(self): | ||
super(LeafClassifier, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 32, 3, 1) | ||
self.conv2 = nn.Conv2d(32, 64, 3, 1) | ||
self.conv3 = nn.Conv2d(64, 128, 3, 1) | ||
self.conv4 = nn.Conv2d(128, 256, 3, 1) | ||
|
||
self.pool = nn.MaxPool2d(2, 2) | ||
|
||
self.fc1 = nn.Linear(256 * 14 * 14, 512) | ||
self.fc2 = nn.Linear(512, NUM_CLASSES) | ||
|
||
self.dropout = nn.Dropout(0.5) | ||
|
||
def forward(self, x): | ||
x = torch.relu(self.conv1(x)) | ||
x = self.pool(x) | ||
x = torch.relu(self.conv2(x)) | ||
x = self.pool(x) | ||
x = torch.relu(self.conv3(x)) | ||
x = self.pool(x) | ||
x = torch.relu(self.conv4(x)) | ||
x = self.pool(x) | ||
|
||
# Flatten tensor | ||
x = x.view(-1, 256 * 14 * 14) | ||
|
||
x = torch.relu(self.fc1(x)) | ||
x = self.dropout(x) | ||
x = self.fc2(x) | ||
return x | ||
|
||
# Loss and Optimizer | ||
model = LeafClassifier() | ||
criterion = nn.CrossEntropyLoss() | ||
optimizer = optim.Adam(model.parameters(), lr=lr) | ||
|
||
# | ||
def train_one_epoch(epoch_index, tb_writer): | ||
running_loss = 0.0 | ||
last_loss = 0.0 | ||
|
||
for i, (inputs, labels) in enumerate(train_loader): | ||
optimizer.zero_grad() | ||
outputs = model(inputs) | ||
loss = criterion(outputs, labels) | ||
# print(inputs.shape, outputs.shape, labels.shape) | ||
loss.backward() | ||
optimizer.step() | ||
running_loss += loss.item() | ||
print(f"\tBatch [{i + 1}] - loss {running_loss}") | ||
if i % batch_size == batch_size - 1: | ||
last_loss = running_loss / batch_size | ||
print(f"batch {i + 1} loss: {last_loss}") | ||
tb_x = epoch_index * len(train_loader) + i + 1 | ||
tb_writer.add_scalar("Loss/train", last_loss, tb_x) | ||
running_loss = 0 | ||
|
||
last_loss = running_loss / len(train_loader) | ||
tb_x = epoch_index * len(train_loader) + i + 1 | ||
tb_writer.add_scalar("Loss/train", last_loss, tb_x) | ||
return last_loss | ||
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | ||
writer = SummaryWriter(f"./runs/leaf_trainer_{timestamp}") | ||
epoch_number = 0 | ||
best_vloss = 1_000_000. | ||
|
||
# | ||
print(f"\nStart training model - epochs:{epochs}") | ||
for epoch in range(epochs): | ||
print(f"Epoch [{epoch + 1}/{epochs}]") | ||
|
||
# train | ||
model.train(True) | ||
avg_loss = train_one_epoch(epoch_number, writer) | ||
|
||
# validation | ||
model.eval() | ||
running_vloss = 0.0 | ||
with torch.no_grad(): | ||
for i, (vinputs, vlabels) in enumerate(val_loader): | ||
for i, (vinputs, vlabels) in enumerate(tqdm(val_loader, desc="Validation Progress")): | ||
voutputs = model(vinputs) | ||
vloss = criterion(voutputs, vlabels) | ||
running_vloss += vloss | ||
avg_vloss = running_vloss / (i + 1) | ||
print(f"LOSS train {avg_loss:.4f} val {avg_vloss:.4f}") | ||
|
||
# Tensorboard | ||
writer.add_scalars("Training vs. Validation Loss", | ||
{"Training": avg_loss, "Validation": avg_vloss}, | ||
epoch_number + 1) | ||
writer.flush() | ||
|
||
if avg_vloss < best_vloss: | ||
best_vloss = avg_vloss | ||
torch.save(model.state_dict(), f"./model_{timestamp}_{epoch_number}") | ||
epoch_number += 1 | ||
_, predicted = torch.max(voutputs, 1) | ||
total += vlabels.size(0) | ||
correct += (predicted == vlabels).sum().item() | ||
|
||
accuracy = 100 * correct / total | ||
print(f'Accuracy of the model on the validation set: {accuracy:.2f}%') | ||
|
||
|
||
def main(file_path): | ||
try: | ||
if (file_path): | ||
train_size = int(0.8 * len(dataset)) | ||
val_size = len(dataset) - train_size | ||
gen = torch.Generator().manual_seed(42) | ||
_, val_data = random_split(dataset, [train_size, val_size], generator=gen) | ||
val_loader = DataLoader(val_data, batch_size=256, shuffle=False) | ||
predict(val_loader) | ||
else: | ||
print(f"Usage: {sys.argv[0]} [path_to_image]") | ||
except Exception as e: | ||
print(e) | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser( | ||
description="A program to classify a type of leaf from validation set." | ||
) | ||
|
||
parser.add_argument( | ||
"folder_path", | ||
type=str, | ||
nargs="?", | ||
help="Image folder path.", | ||
default="images" | ||
) | ||
|
||
args = parser.parse_args() | ||
model = LeafClassifier() | ||
model.load_state_dict(torch.load("./model_20240823_170439_4", weights_only=True)) | ||
model.eval() | ||
|
||
transform = transforms.Compose([ | ||
transforms.Resize((256, 256)), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=[0.5,], std=[0.5,]) | ||
]) | ||
|
||
idx2class = {v: k for k, v in class_to_idx.items()} | ||
main(args.folder_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from torchvision import datasets, transforms | ||
from lib.images import IMAGE_FOLDER | ||
|
||
# Load Dataset | ||
transform = transforms.Compose([ | ||
transforms.Resize((256, 256)), | ||
transforms.ToTensor(), | ||
# image = (image - mean) / std, range [-1, 1] | ||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | ||
]) | ||
|
||
dataset = datasets.ImageFolder(root=IMAGE_FOLDER, transform=transform) | ||
|
||
NUM_CLASSES = len(dataset.classes) | ||
class_to_idx = dataset.class_to_idx | ||
|
||
# print(f"Number of classes: {NUM_CLASSES}") | ||
# print(f"Class to index mapping: {class_to_idx}") | ||
|
||
# Define model | ||
class LeafClassifier(nn.Module): | ||
""" | ||
Leaf Classifier that classifies the type of disease specified in the leaf | ||
Inputs: leaf Images, shape of (256, 256, 3) | ||
Labels: types of disease specified in the leaf | ||
""" | ||
def __init__(self): | ||
super(LeafClassifier, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 32, 3, 1) | ||
self.conv2 = nn.Conv2d(32, 64, 3, 1) | ||
self.conv3 = nn.Conv2d(64, 128, 3, 1) | ||
self.conv4 = nn.Conv2d(128, 256, 3, 1) | ||
|
||
self.pool = nn.MaxPool2d(2, 2) | ||
|
||
self.fc1 = nn.Linear(256 * 14 * 14, 512) | ||
self.fc2 = nn.Linear(512, NUM_CLASSES) | ||
|
||
self.dropout = nn.Dropout(0.5) | ||
|
||
def forward(self, x): | ||
x = torch.relu(self.conv1(x)) | ||
x = self.pool(x) | ||
x = torch.relu(self.conv2(x)) | ||
x = self.pool(x) | ||
x = torch.relu(self.conv3(x)) | ||
x = self.pool(x) | ||
x = torch.relu(self.conv4(x)) | ||
x = self.pool(x) | ||
|
||
# Flatten tensor | ||
x = x.view(-1, 256 * 14 * 14) | ||
|
||
x = torch.relu(self.fc1(x)) | ||
x = self.dropout(x) | ||
x = self.fc2(x) | ||
return x |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import torch | ||
import sys | ||
import argparse | ||
|
||
from PIL import Image | ||
from torchvision import transforms | ||
from model import LeafClassifier, class_to_idx | ||
|
||
def predict(file_path): | ||
test_image = Image.open(file_path) | ||
test_image = transform(test_image) | ||
test_image = test_image.unsqueeze(0) | ||
|
||
with torch.no_grad(): | ||
output = model(test_image) | ||
_, predicted = torch.max(output, 1) | ||
|
||
print(f"Class predicted: {idx2class[predicted.item()]}") | ||
|
||
def main(file_path): | ||
try: | ||
if (file_path): | ||
predict(sys.argv[1]) | ||
else: | ||
print(f"Usage: {sys.argv[0]} [path_to_image]") | ||
except Exception as e: | ||
print(e) | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser( | ||
description="A program to predict a type of disease specified in the leaf." | ||
) | ||
|
||
parser.add_argument( | ||
"file_path", | ||
type=str, | ||
nargs="?", | ||
help="Image file path." | ||
) | ||
|
||
args = parser.parse_args() | ||
model = LeafClassifier() | ||
model.load_state_dict(torch.load("./model_20240823_170439_4", weights_only=True)) | ||
model.eval() | ||
|
||
transform = transforms.Compose([ | ||
transforms.Resize((256, 256)), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=[0.5,], std=[0.5,]) | ||
]) | ||
|
||
idx2class = {v: k for k, v in class_to_idx.items()} | ||
main(args.file_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,5 @@ torchvision | |
opencv-python | ||
plantcv==3.14.3 | ||
plotnine==0.8.0 | ||
mizani==0.7.3 | ||
mizani==0.7.3 | ||
tqdm |
Binary file added
BIN
+748 Bytes
...n Loss_Training/events.out.tfevents.1724172007.paul-f4Br4s8.clusters.42paris.fr.3101446.1
Binary file not shown.
Binary file added
BIN
+748 Bytes
...Loss_Validation/events.out.tfevents.1724172007.paul-f4Br4s8.clusters.42paris.fr.3101446.2
Binary file not shown.
Binary file added
BIN
+575 Bytes
...20240820_183402/events.out.tfevents.1724171642.paul-f4Br4s8.clusters.42paris.fr.3101446.0
Binary file not shown.
Binary file added
BIN
+748 Bytes
...dation Loss_Training/events.out.tfevents.1724426106.woolinettes-MacBook-Pro.local.97663.1
Binary file not shown.
Binary file added
BIN
+748 Bytes
...tion Loss_Validation/events.out.tfevents.1724426106.woolinettes-MacBook-Pro.local.97663.2
Binary file not shown.
Binary file added
BIN
+575 Bytes
...iner_20240823_170439/events.out.tfevents.1724425479.woolinettes-MacBook-Pro.local.97663.0
Binary file not shown.
Oops, something went wrong.