Skip to content

Commit

Permalink
Add interface to display confusion matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
johnisom committed Dec 13, 2023
1 parent f50b6dd commit 5a1cddf
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 19 deletions.
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from threading import Thread
from src.fires_info import get_fires_dataframe
from src.location_info import get_counties_geodf,get_fips_codes_dataframe
from src.prediction import get_fips_encoder, get_fips_model, get_lonlat_model, joblib_objects_unpacked
from src.prediction import get_fips_encoder, get_fips_model, get_lonlat_model, joblib_objects_present

# Load the database data and pygris geographical data in the background
def fires_info_thread_target():
Expand All @@ -20,7 +20,7 @@ def location_info_thread_target():
fires_info_thread.start()
location_info_thread.start()

enable_ml = joblib_objects_unpacked()
enable_ml = joblib_objects_present()
if enable_ml:
# Load the ML prediction models in the background
def prediction_thread_target():
Expand Down
41 changes: 28 additions & 13 deletions src/gui/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .custom_widgets import NotebookFrame, Title, Subtitle, DefaultEntry, DatetimeEntry
from ..location_info import get_fips_codes_dataframe, are_coordinates_inside_usa
from ..prediction import run_fips_model_prediction, run_lonlat_model_prediction
from ..bindings import plot_lonlat_confusion_matrix, plot_fipscode_confusion_matrix
import datetime

class PredictionsFrame(NotebookFrame):
Expand Down Expand Up @@ -47,9 +48,9 @@ def __init__(self, *args, **kwargs):
self.set_up_state_county_variables()

title = Title(self, text='Predictions')
subframe = ttk.LabelFrame(self, text='Predict cause of fire by giving some parameters')
subframe = ttk.LabelFrame(self, text='Predict cause of fire by giving some parameters', padding=5)

input_frame = ttk.Frame(subframe, padding=5)
input_frame = ttk.Frame(subframe)
discovery_datetime_label = ttk.Label(input_frame, text='Discovery datetime: ')
self.discovery_datetime_entry = DatetimeEntry(input_frame)
contained_datetime_label = ttk.Label(input_frame, text='Contained datetime: ')
Expand All @@ -71,8 +72,7 @@ def __init__(self, *args, **kwargs):
latitude_label = ttk.Label(location_input_frame, text='Latitude: ')
self.latitude_entry = DefaultEntry(location_input_frame, default_text='42.124', validate='key', validatecommand=(self.register(PredictionsFrame.check_latitude), '%P'))
predict_button = ttk.Button(input_frame, text='Predict fire cause', command=self.run_prediction)
display_frame = ttk.Frame(subframe, padding=5)
display_frame_subtitle = Subtitle(display_frame, text='Predicted Category Results')
display_frame = ttk.LabelFrame(subframe, padding=5, text='Predicted Category Results')
predicted_category_label = ttk.Label(display_frame, text='Top predicted category:')
self.predicted_category_var = StringVar()
predicted_category = ttk.Label(display_frame, textvariable=self.predicted_category_var, font=('Helvetica', 12))
Expand All @@ -99,11 +99,14 @@ def __init__(self, *args, **kwargs):
self.probability_3_percent_var = StringVar()
probability_3_percent = ttk.Label(probabilities_frame, textvariable=self.probability_3_percent_var)
probability_3_percent_label = ttk.Label(probabilities_frame, text='%')
confusion_matrix_buttons_frame = ttk.LabelFrame(subframe, text="ML Models' Confusion Matrices")
fips_confusion_matrix_button = ttk.Button(confusion_matrix_buttons_frame, text='State/County Model', command=self.show_fips_confusion_matrix)
lonlat_confusion_matrix_button = ttk.Button(confusion_matrix_buttons_frame, text='Longitude/Latitude Model', command=self.show_lonlat_confusion_matrix)

# Set items on the grid
title.grid(row=0, column=0, sticky=(N, E, W))
subframe.grid(row=1, column=0, sticky=(N, S, E, W))
input_frame.grid(row=0, column=0, sticky=NSEW)
subframe.grid(row=1, column=0, sticky=NSEW)
input_frame.grid(row=0, column=0, columnspan=2, sticky=NSEW)
discovery_datetime_label.grid(row=0, column=0, sticky=E)
self.discovery_datetime_entry.grid(row=0, column=1, sticky=W)
contained_datetime_label.grid(row=1, column=0, sticky=E)
Expand All @@ -122,10 +125,9 @@ def __init__(self, *args, **kwargs):
self.latitude_entry.grid(row=1, column=4, sticky=W)
predict_button.grid(row=6, column=0, columnspan=4, pady=5, sticky=NSEW)
display_frame.grid(row=1, column=0, sticky=NSEW)
display_frame_subtitle.grid(row=0, column=0, columnspan=3, sticky=(N, E, W))
predicted_category_label.grid(row=1, column=1, sticky=E)
predicted_category.grid(row=2, column=1, sticky=E)
probabilities_frame.grid(row=1, column=2, rowspan=4, sticky=(N, S, E))
predicted_category_label.grid(row=0, column=0, sticky=E)
predicted_category.grid(row=0, column=1, sticky=W)
probabilities_frame.grid(row=1, column=0, rowspan=2, columnspan=3, sticky=NSEW)
probabilities_label.grid(row=0, column=0, columnspan=5, sticky=(N, E, W))
probability_1_label.grid(row=1, column=0, sticky=E)
probability_1_cause.grid(row=1, column=1, sticky=E)
Expand All @@ -142,27 +144,34 @@ def __init__(self, *args, **kwargs):
probability_3_colon_label.grid(row=3, column=2)
probability_3_percent.grid(row=3, column=3, sticky=W)
probability_3_percent_label.grid(row=3, column=4, sticky=W)
confusion_matrix_buttons_frame.grid(row=1, column=1, sticky=NSEW)
fips_confusion_matrix_button.grid(row=0, column=0, sticky=NSEW, padx=4, pady=4)
lonlat_confusion_matrix_button.grid(row=1, column=0, sticky=NSEW, padx=4, pady=4)

# Configure the grid
self.rowconfigure((0,), weight=1)
self.rowconfigure((1,), weight=5)
self.columnconfigure((0,), weight=1)
subframe.rowconfigure((0,), weight=2)
subframe.rowconfigure((1,), weight=1)
subframe.columnconfigure((0,), weight=1)
subframe.columnconfigure((0,), weight=2)
subframe.columnconfigure((1,), weight=1)
input_frame.rowconfigure((0, 1, 2, 3, 4, 5), weight=1)
input_frame.rowconfigure((6,), weight=2)
input_frame.columnconfigure((0, 1, 2, 3), weight=1)
location_input_frame.rowconfigure((0, 1), weight=1)
location_input_frame.columnconfigure((0, 1, 3, 4), weight=5)
location_input_frame.columnconfigure((2,), weight=1)
display_frame.rowconfigure((0,), weight=2)
display_frame.rowconfigure((1, 2, 3, 4), weight=1)
display_frame.columnconfigure((0, 1, 2), weight=1)
display_frame.rowconfigure((1, 2), weight=1)
display_frame.columnconfigure((0, 1), weight=1)
display_frame.columnconfigure((2,), weight=2)
probabilities_frame.rowconfigure((0,), weight=3)
probabilities_frame.rowconfigure((1, 2, 3), weight=2)
probabilities_frame.columnconfigure((0, 2, 3, 4), weight=1)
probabilities_frame.columnconfigure((1), weight=2)
confusion_matrix_buttons_frame.rowconfigure((0, 1), weight=1)
confusion_matrix_buttons_frame.columnconfigure((0,), weight=1)

def set_up_state_county_variables(self):
self.fips_codes_df = get_fips_codes_dataframe()
Expand Down Expand Up @@ -261,3 +270,9 @@ def display_predicted_info(self, probabilities):
self.probability_1_percent_var.set(f'{probabilities[0][1] * 100:.2f}')
self.probability_2_percent_var.set(f'{probabilities[1][1] * 100:.2f}')
self.probability_3_percent_var.set(f'{probabilities[2][1] * 100:.2f}')

def show_fips_confusion_matrix(self):
plot_lonlat_confusion_matrix()

def show_lonlat_confusion_matrix(self):
plot_fipscode_confusion_matrix()
4 changes: 2 additions & 2 deletions src/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def plot_lonlat_model_confusion_matrix():
df = fires_df[['fire_size', 'longitude', 'latitude', 'discovery_datetime', 'contained_datetime', 'stat_cause_code']].dropna()
X = df.drop('stat_cause_code', axis=1)
y = df['stat_cause_code'] - 1
fig, ax = plt.subplots(figsize=[12, 8])
fig, ax = plt.subplots(figsize=[10, 9])
ConfusionMatrixDisplay.from_predictions(y_true=y, y_pred=model.predict(X), normalize='true', xticks_rotation='vertical', display_labels=STAT_CAUSE_CODE_TO_DESCR.values(), values_format='.2f', ax=ax)
ax.set_title('Confusion Matrix for the Longitude/Latitude prediction model.')
fig.tight_layout()
Expand All @@ -81,7 +81,7 @@ def plot_fipscode_model_confusion_matrix():
df.loc[:, ['combined_fips_code']] = encoder.transform(df[['combined_fips_code']])[0].astype(int)
X = df.drop('stat_cause_code', axis=1)
y = df['stat_cause_code'] - 1
fig, ax = plt.subplots(figsize=[10, 10])
fig, ax = plt.subplots(figsize=[10, 9])
ConfusionMatrixDisplay.from_predictions(y_true=y, y_pred=model.predict(X), normalize='true', xticks_rotation='vertical', display_labels=STAT_CAUSE_CODE_TO_DESCR.values(), values_format='.2f', ax=ax)
ax.set_title('Confusion Matrix for the State/County prediction model.')
fig.tight_layout()
Expand Down
4 changes: 2 additions & 2 deletions src/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
11: 'Powerline', 12: 'Structure', 13: 'Missing/Undefined'
}

def joblib_objects_unpacked():
def joblib_objects_present():
return FIPS_MODEL_PATH.is_file() and FIPS_ENCODER_PATH.is_file() and LONLAT_MODEL_PATH.is_file()

_fips_encoder = None
Expand Down Expand Up @@ -48,7 +48,7 @@ def get_lonlat_model():
def run_fips_model_prediction(fire_size, combined_fips_code, discovery_datetime, contained_datetime):
encoder = get_fips_encoder()
classifier = get_fips_model()
encoded_fips_code = encoder.transform([[combined_fips_code]])[0].astype(int)[0]
encoded_fips_code = encoder.transform(pd.DataFrame(data=[[combined_fips_code]], columns=['combined_fips_code']))[0].astype(int)[0]
df = pd.DataFrame(
data=[[fire_size, encoded_fips_code, discovery_datetime.timestamp(), contained_datetime.timestamp()]],
columns=['fire_size', 'combined_fips_code', 'discovery_datetime', 'contained_datetime']
Expand Down

0 comments on commit 5a1cddf

Please sign in to comment.