Skip to content

Commit

Permalink
Remove trailing whitespaces and add final new line for each module
Browse files Browse the repository at this point in the history
  • Loading branch information
mcttn22 committed Jan 7, 2024
1 parent aff6401 commit 5033881
Show file tree
Hide file tree
Showing 31 changed files with 293 additions and 291 deletions.
2 changes: 1 addition & 1 deletion school_project/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Main package of A-level Computer Science NEA Programming Project."""

__all__ = ['models', 'frames', 'test']
__all__ = ['models', 'frames', 'test']
58 changes: 29 additions & 29 deletions school_project/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,30 @@

import pympler.tracker as tracker

from school_project.frames import (HyperParameterFrame, TrainingFrame,
LoadModelFrame, TestMNISTFrame,
from school_project.frames import (HyperParameterFrame, TrainingFrame,
LoadModelFrame, TestMNISTFrame,
TestCatRecognitionFrame, TestXORFrame)

class SchoolProjectFrame(tk.Frame):
"""Main frame of school project."""
def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
"""Initialise school project pages.
Args:
root (tk.Tk): the widget object that contains this widget.
width (int): the pixel width of the frame.
height (int): the pixel height of the frame.
bg (str): the hex value or name of the frame's background colour.
Raises:
TypeError: if root, width or height are not of the correct type.
"""
super().__init__(master=root, width=width, height=height, bg=bg)
self.root = root.title("School Project")
self.WIDTH = width
self.HEIGHT = height
self.BG = bg

# Setup school project frame variables
self.hyper_parameter_frame: HyperParameterFrame
self.training_frame: TrainingFrame
Expand Down Expand Up @@ -118,14 +118,14 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
command=self.enter_home_frame)

# Setup home frame
self.home_frame = tk.Frame(master=self,
width=self.WIDTH,
self.home_frame = tk.Frame(master=self,
width=self.WIDTH,
height=self.HEIGHT,
bg=self.BG)
self.title_label = tk.Label(
master=self.home_frame,
bg=self.BG,
font=('Arial', 20),
font=('Arial', 20),
text="A-level Computer Science NEA Programming Project"
)
self.about_label = tk.Label(
Expand Down Expand Up @@ -162,7 +162,7 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
font=tkf.Font(size=12),
text="Load Model",
command=self.enter_load_model_frame)

# Grid home frame widgets
self.title_label.grid(row=0, column=0, columnspan=4, pady=(10,0))
self.about_label.grid(row=1, column=0, columnspan=4, pady=(10,50))
Expand All @@ -172,19 +172,19 @@ def __init__(self, root: tk.Tk, width: int, height: int, bg: str) -> None:
self.load_model_button.grid(row=4, column=2)

self.home_frame.pack()

# Setup frame attributes
self.grid_propagate(flag=False)
self.pack_propagate(flag=False)

@staticmethod
def setup_database() -> tuple[sqlite3.Connection, sqlite3.Cursor]:
"""Create a connection to the pretrained_models database file and
"""Create a connection to the pretrained_models database file and
setup base table if needed.
Returns:
a tuple of the database connection and the cursor for it.
"""
connection = sqlite3.connect(
database='school_project/saved_models.db'
Expand Down Expand Up @@ -232,14 +232,14 @@ def enter_load_model_frame(self) -> None:
)
self.load_model_frame.pack()

# Don't give option to test loaded model if no models have been saved
# Don't give option to test loaded model if no models have been saved
# for the dataset.
if len(self.load_model_frame.model_options) > 0:
self.test_loaded_model_button.pack()
self.delete_loaded_model_button.pack(pady=(5,0))

self.exit_load_model_frame_button.pack(pady=(5,0))

def exit_hyper_parameter_frame(self) -> None:
"""Unpack hyper-parameter frame and pack home frame."""
self.hyper_parameter_frame.pack_forget()
Expand Down Expand Up @@ -269,7 +269,7 @@ def enter_training_frame(self) -> None:
self.exit_hyper_parameter_frame_button.pack_forget()
self.training_frame = TrainingFrame(
root=self,
width=self.WIDTH,
width=self.WIDTH,
height=self.HEIGHT,
bg=self.BG,
model=self.model,
Expand All @@ -282,7 +282,7 @@ def enter_training_frame(self) -> None:
def manage_training(self, train_thread: threading.Thread) -> None:
"""Wait for model training thread to finish,
then plot training losses on training frame.
Args:
train_thread (threading.Thread):
the thread running the model's train() method.
Expand All @@ -308,7 +308,7 @@ def test_created_model(self) -> None:
self.training_frame.pack_forget()
self.test_created_model_button.pack_forget()
if self.hyper_parameter_frame.dataset == "MNIST":
self.test_frame = TestMNISTFrame(
self.test_frame = TestMNISTFrame(
root=self,
width=self.WIDTH,
height=self.HEIGHT,
Expand All @@ -319,7 +319,7 @@ def test_created_model(self) -> None:
elif self.hyper_parameter_frame.dataset == "Cat Recognition":
self.test_frame = TestCatRecognitionFrame(
root=self,
width=self.WIDTH,
width=self.WIDTH,
height=self.HEIGHT,
bg=self.BG,
use_gpu=self.hyper_parameter_frame.use_gpu,
Expand All @@ -335,7 +335,7 @@ def test_created_model(self) -> None:
self.manage_testing(test_thread=self.test_frame.test_thread)

def test_loaded_model(self) -> None:
"""Load saved model from load model frame, unpack load model frame,
"""Load saved model from load model frame, unpack load model frame,
pack test frame for the dataset and begin managing the test thread."""
self.saving_model = False
try:
Expand All @@ -347,7 +347,7 @@ def test_loaded_model(self) -> None:
self.delete_loaded_model_button.pack_forget()
self.exit_load_model_frame_button.pack_forget()
if self.load_model_frame.dataset == "MNIST":
self.test_frame = TestMNISTFrame(
self.test_frame = TestMNISTFrame(
root=self,
width=self.WIDTH,
height=self.HEIGHT,
Expand Down Expand Up @@ -376,13 +376,13 @@ def test_loaded_model(self) -> None:
def manage_testing(self, test_thread: threading.Thread) -> None:
"""Wait for model test thread to finish,
then plot results on test frame.
Args:
test_thread (threading.Thread):
the thread running the model's predict() method.
Raises:
TypeError: if test_thread is not of type threading.Thread.
"""
if not test_thread.is_alive():
self.test_frame.plot_results(model=self.model)
Expand All @@ -395,7 +395,7 @@ def manage_testing(self, test_thread: threading.Thread) -> None:
self.after(1_000, self.manage_testing, test_thread)

def save_model(self) -> None:
"""Save the model, save the model information to the database, then
"""Save the model, save the model information to the database, then
enter the home frame."""
model_name = self.save_model_name_entry.get()

Expand Down Expand Up @@ -480,19 +480,19 @@ def enter_home_frame(self) -> None:
self.home_frame.pack()
summary_tracker.create_summary() # BUG: Object summary seems to reduce
# memory leak greatly

def main() -> None:
"""Entrypoint of project."""
root = tk.Tk()
school_project_frame = SchoolProjectFrame(root=root, width=1280,
height=835, bg='white')
school_project_frame.pack(side='top', fill='both', expand=True)
root.mainloop()

# Stop model training when GUI closes
if school_project_frame.model != None:
if school_project_frame.model is not None:
school_project_frame.model.set_running(value=False)

if __name__ == "__main__":
summary_tracker = tracker.SummaryTracker() # Setup object tracker
main()
main()
2 changes: 1 addition & 1 deletion school_project/frames/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .load_model import LoadModelFrame
from .test_model import TestMNISTFrame, TestCatRecognitionFrame, TestXORFrame

__all__ = ['create_model', 'load_model', 'test_model']
__all__ = ['create_model', 'load_model', 'test_model']
48 changes: 25 additions & 23 deletions school_project/frames/create_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

class HyperParameterFrame(tk.Frame):
"""Frame for hyper-parameter page."""
def __init__(self, root: tk.Tk, width: int,
def __init__(self, root: tk.Tk, width: int,
height: int, bg: str, dataset: str) -> None:
"""Initialise hyper-parameter frame widgets.
Args:
root (tk.Tk): the widget object that contains this widget.
width (int): the pixel width of the frame.
Expand All @@ -24,21 +24,21 @@ def __init__(self, root: tk.Tk, width: int,
('MNIST', 'Cat Recognition' or 'XOR')
Raises:
TypeError: if root, width or height are not of the correct type.
"""
super().__init__(master=root, width=width, height=height, bg=bg)
self.root = root
self.WIDTH = width
self.HEIGHT = height
self.BG = bg

# Setup hyper-parameter frame variables
self.dataset = dataset
self.use_gpu: bool
self.default_hyper_parameters = self.load_default_hyper_parameters(
dataset=dataset
)

# Setup widgets
self.title_label = tk.Label(master=self,
bg=self.BG,
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(self, root: tk.Tk, width: int,
self.model_status_label = tk.Label(master=self,
bg=self.BG,
font=('Arial', 15))

# Pack widgets
self.title_label.grid(row=0, column=0, columnspan=3)
self.about_label.grid(row=1, column=0, columnspan=3)
Expand All @@ -129,13 +129,13 @@ def __init__(self, root: tk.Tk, width: int,
self.use_gpu_check_button.grid(row=3, column=2, pady=(30, 0))
self.model_status_label.grid(row=5, column=0,
columnspan=3, pady=50)

def load_default_hyper_parameters(self, dataset: str) -> dict[
str,
str,
str | int | list[int] | float
]:
"""Load the dataset's default hyper-parameters from the json file.
Args:
dataset (str): the name of the dataset to load hyper-parameters
for. ('MNIST', 'Cat Recognition' or 'XOR')
Expand All @@ -144,7 +144,7 @@ def load_default_hyper_parameters(self, dataset: str) -> dict[
"""
with open('school_project/frames/hyper-parameter-defaults.json') as f:
return json.load(f)[dataset]

def create_model(self) -> object:
"""Create and return a Model using the hyper-parameters set.
Expand All @@ -171,10 +171,12 @@ def create_model(self) -> object:
from school_project.models.cpu.cat_recognition import CatRecognitionModel as Model
elif self.dataset == "XOR":
from school_project.models.cpu.xor import XORModel as Model
model = Model(hidden_layers_shape = [int(neuron_count) for neuron_count in hidden_layers_shape_input],
train_dataset_size = self.train_dataset_size_scale.get(),
learning_rate = self.learning_rate_scale.get(),
use_relu = self.use_relu_check_button_var.get())
model = Model(
hidden_layers_shape = [int(neuron_count) for neuron_count in hidden_layers_shape_input],
train_dataset_size = self.train_dataset_size_scale.get(),
learning_rate = self.learning_rate_scale.get(),
use_relu = self.use_relu_check_button_var.get()
)
model.create_model_values()

else:
Expand All @@ -197,14 +199,14 @@ def create_model(self) -> object:
)
raise ImportError
return model

class TrainingFrame(tk.Frame):
"""Frame for training page."""
def __init__(self, root: tk.Tk, width: int,
height: int, bg: str,
model: object, epoch_count: int) -> None:
"""Initialise training frame widgets.
Args:
root (tk.Tk): the widget object that contains this widget.
width (int): the pixel width of the frame.
Expand All @@ -214,14 +216,14 @@ def __init__(self, root: tk.Tk, width: int,
epoch_count (int): the number of training epochs.
Raises:
TypeError: if root, width or height are not of the correct type.
"""
super().__init__(master=root, width=width, height=height, bg=bg)
self.root = root
self.WIDTH = width
self.HEIGHT = height
self.BG = bg

# Setup widgets
self.model_status_label = tk.Label(master=self,
bg=self.BG,
Expand All @@ -234,11 +236,11 @@ def __init__(self, root: tk.Tk, width: int,
figure=self.loss_figure,
master=self
)

# Pack widgets
self.model_status_label.pack(pady=(30,0))
self.training_progress_label.pack(pady=30)

# Start training thread
self.model_status_label.configure(
text="Training weights and biases...",
Expand All @@ -252,10 +254,10 @@ def __init__(self, root: tk.Tk, width: int,

def plot_losses(self, model: object) -> None:
"""Plot losses of Model training.
Args:
model (object): the Model object thats been trained.
"""
self.model_status_label.configure(
text=f"Weights and biases trained in {model.training_time}s",
Expand All @@ -267,4 +269,4 @@ def plot_losses(self, model: object) -> None:
graph.set_xlabel("Epochs")
graph.set_ylabel("Loss Value")
graph.plot(np.squeeze(model.train_losses))
self.loss_canvas.get_tk_widget().pack()
self.loss_canvas.get_tk_widget().pack()
Loading

0 comments on commit 5033881

Please sign in to comment.