Skip to content

Commit

Permalink
[update] document code
Browse files Browse the repository at this point in the history
  • Loading branch information
myatmyintzuthin committed Nov 1, 2022
1 parent 189f4a0 commit ed55226
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 20 deletions.
4 changes: 2 additions & 2 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from models.vgg import VGG
from models.mobilenetv2 import MobileNetV2

class convert_model():
class ConvertModel:
def __init__(self, model: str, variant: str, width_multi: float, num_class: str) -> None:
self.model = model
self.variant = variant
Expand Down Expand Up @@ -75,5 +75,5 @@ def initialize_weights(self):

opt = parser.parse_args()

prepare_model = convert_model(opt.model, opt.variant, opt.width_multi, opt.num_class)
prepare_model = ConvertModel(opt.model, opt.variant, opt.width_multi, opt.num_class)
prepare_model.initialize_weights()
7 changes: 6 additions & 1 deletion core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from rich.progress import track

def train_step(model, dataloader, loss_fn, optimizer, device):

'''process for each training step
'''
model.train()

train_loss, train_acc = 0, 0
Expand Down Expand Up @@ -38,6 +39,8 @@ def train_step(model, dataloader, loss_fn, optimizer, device):

def test_step(model, dataloader, loss_fn, device):

''' process for each testing step
'''
model.eval()

test_loss, test_acc = 0, 0
Expand Down Expand Up @@ -68,6 +71,8 @@ def test_step(model, dataloader, loss_fn, device):

def train(model, train_dataloader, test_dataloader, optimizer, loss_fn, scheduler, epochs, log, device):

''' training loop
'''
results = {'train_loss': [], 'train_acc': [],
'test_loss': [], 'test_acc': []}

Expand Down
6 changes: 6 additions & 0 deletions dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


def download_dataset(data_path, image_path, log):
''' download dataset and unzip it
'''
if image_path.is_dir():
log.info(f"{image_path} directory exists.")
else:
Expand All @@ -27,6 +29,8 @@ def download_dataset(data_path, image_path, log):


def find_classes(directory: str):
''' find classes in dataset
'''
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(f'Couldn\'t find any classes in {directory}')
Expand Down Expand Up @@ -61,6 +65,8 @@ def __getitem__(self, index: int):
class CustomDataloader():
def __init__(self,data_dir: str, img_path: str, BATCH_SIZE: int, log: str, num_worker: int, shuffle: bool = True) -> None:

''' custom dataloader class
'''
self.batchsize = BATCH_SIZE
self.shuffle = shuffle
self.num_worker = num_worker
Expand Down
12 changes: 8 additions & 4 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import convert


class Evaluation():
class Evaluation:
def __init__(self, model_path, test_loader, config) -> None:

self.model_path = model_path
Expand All @@ -20,13 +20,16 @@ def __init__(self, model_path, test_loader, config) -> None:
'cuda' if torch.cuda.is_available() else 'cpu')

def run(self):
''' run process
'''
model = self.prepare_model()
classification_report = self.evaluate(model)
return classification_report

def prepare_model(self):

choose_model = convert.convert_model(
''' load model
'''
choose_model = convert.ConvertModel(
self.model_name, self.variant, self.width_multi, len(self.class_name))
model = choose_model.load_model()
ckpt = torch.load(self.model_path)
Expand All @@ -35,7 +38,8 @@ def prepare_model(self):
return model

def evaluate(self, model):

''' model evaluation
'''
model.eval()
actual_label, pred_label = [],[]
with torch.inference_mode():
Expand Down
12 changes: 8 additions & 4 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import utils.utils as utils
import convert

class Inference():
class Inference:
def __init__(self, opt) -> None:

self.config = utils.yaml_parser(opt.cfg)
Expand All @@ -33,12 +33,15 @@ def __init__(self, opt) -> None:
])

def run(self):
''' run process
'''
model = self.prepare_model()
self.inference(model)

def prepare_model(self):

choose_model = convert.convert_model(
''' load model
'''
choose_model = convert.ConvertModel(
self.model_name, self.variant, self.width_multi, len(self.class_name))
model = choose_model.load_model()
ckpt = torch.load(self.model_path)
Expand All @@ -47,7 +50,8 @@ def prepare_model(self):
return model

def inference(self, model):

''' model inference
'''
model.eval()
image_files = glob(os.path.join(self.image_path, '*.jpg'))

Expand Down
5 changes: 4 additions & 1 deletion metrics/metrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def log_table(rich_table):
return Text.from_ansi(capture.get())

def plot_confusion(actual, pred, class_name, save_path, save=True):
''' confusion matrix calculation
'''

actual = np.array(actual)
pred = np.array(pred)
Expand Down Expand Up @@ -50,7 +52,8 @@ def plot_confusion(actual, pred, class_name, save_path, save=True):
return cm

def classification_report(cm, labels):

''' precision, recall, F1-score, accuracy calculation
'''
Pre, Rc, F1 = [],[],[]
# calculate tp,tn,fp,fn
total = np.sum(cm)
Expand Down
22 changes: 14 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
torch.cuda.manual_seed(42)


class Training():
class Training:
def __init__(self, opt) -> None:
self.config = utils.yaml_parser(opt.cfg)
self.eval_opt = opt.eval
Expand Down Expand Up @@ -45,7 +45,9 @@ def __init__(self, opt) -> None:
'cuda' if torch.cuda.is_available() else 'cpu')

def run(self):

''' run process
train -> evaluation
'''
if not self.eval_opt:

exp_name = self.model_name+self.variant
Expand All @@ -65,28 +67,31 @@ def run(self):
self.eval(self.eval_model, dataloader, eval_log)

def dataloader(self, cfg, log):

''' data preparation
'''
data_path = Path(cfg['root'])
image_path = data_path/cfg['dataset_name']
dataloader = dataset.CustomDataloader(
data_path, image_path, self.BATCH_SIZE, log, self.num_worker, shuffle=True)
return dataloader

def prepare_model(self):

choose_model = convert.convert_model(
''' load model
'''
choose_model = convert.ConvertModel(
self.model_name, self.variant, self.width_multi, self.num_class)
model = choose_model.load_model()

ckpt = torch.load(self.pretrained_path)
model = utils.load_ckpt(model, ckpt)

model = model.to(self.device)
# summary(model, input_size=(self.BATCH_SIZE, 3, 224, 224))
summary(model, input_size=(self.BATCH_SIZE, 3, 224, 224))
return model

def train(self, model, dataloader):

''' model training
'''
# dataloader
train_loader, valid_loader = dataloader.train_dataloader(), dataloader.valid_dataloader()
# lost function
Expand Down Expand Up @@ -115,7 +120,8 @@ def train(self, model, dataloader):
torch.save(model.state_dict(), self.save_model)

def eval(self, save_model, dataloader, log):
# evaluate model
''' model evaluation
'''
log.info("Evaluation Starts....")
test_loader = dataloader.test_dataloader()
evaluation = evaluate.Evaluation(save_model, test_loader, self.config)
Expand Down
6 changes: 6 additions & 0 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from datetime import datetime

def setup_logger(log_name):
'''
setup rich logger
'''
logger = logging.getLogger(__name__)

rh = RichHandler()
Expand Down Expand Up @@ -89,6 +92,9 @@ def plot_curves(results, save_path):
plt.savefig(save_path)

def load_ckpt(model, ckpt):
'''
load trained model
'''
model_state_dict = model.state_dict()
load_dict = {}
for key_model, v in model_state_dict.items():
Expand Down

0 comments on commit ed55226

Please sign in to comment.