-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexamine_network.py
104 lines (78 loc) · 2.61 KB
/
examine_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Kevin Heleodoro - Examine the network's architecture
# ----- Import Statements -------- #
import cv2
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from main import MyNetwork, print_border, load_data
# ----- Global Variables --------- #
# ----- Class Definitions -------- #
# ----- Function Definitions ----- #
# Analyze the given layer of the network
def analyze_layer(layer, num_filters=10):
    """Print a layer's description and weight shape, then plot its filters.

    Each subplot shows ``layer.weight[i][0]``: filter ``i``'s weights for the
    layer's first input channel. The grid has four columns and as many rows
    as needed (3x4 for the default 10 filters).

    Args:
        layer: Module with a ``weight`` tensor indexed as
            ``weight[filter][input_channel]`` (e.g. ``torch.nn.Conv2d``).
        num_filters: Maximum number of filters to plot (default 10). Clamped
            to the number of filters the layer actually has.
    """
    print(f"Layer: {layer}")
    print(f"Layer weight shape: {layer.weight.shape}")

    # Clamp so a layer with fewer than num_filters filters does not raise
    # IndexError, and size the grid to fit however many are plotted.
    count = min(num_filters, layer.weight.shape[0])
    cols = 4
    rows = (count + cols - 1) // cols

    # Visualize the filters using pyplot
    plt.figure()
    print("Visualizing the filters...")
    for i in range(count):
        # .cpu() so this also works when the weights live on a GPU;
        # .numpy() only accepts CPU tensors.
        weights = layer.weight[i][0].detach().cpu().numpy()
        plt.subplot(rows, cols, i + 1)
        plt.imshow(weights)
        # axis("off") already hides ticks and labels.
        plt.axis("off")
    plt.show()
# Show effects of the first layer using OpenCV filter 2D function
def show_effects(layer, image, num_filters=10):
    """Filter ``image`` with a layer's kernels via OpenCV and plot the results.

    For each filter ``i`` two neighbouring subplots are drawn: the kernel
    ``layer.weight[i][0]`` (its slice for the first input channel) and the
    image filtered by that kernel with ``cv2.filter2D``. The grid has four
    columns, i.e. two kernel/result pairs per row (5x4 for 10 filters).

    Args:
        layer: Module with a ``weight`` tensor indexed as
            ``weight[filter][input_channel]`` (e.g. ``torch.nn.Conv2d``).
        image: Image tensor. Singleton dimensions are squeezed away, so a
            single-channel image such as ``(1, H, W)`` is presumably
            expected -- confirm against the data loader.
        num_filters: Maximum number of filters to apply (default 10). Clamped
            to the number of filters the layer actually has.
    """
    # .numpy() only works on CPU tensors that are detached from autograd.
    image_array = image.detach().cpu().squeeze().numpy()

    # Clamp to the available filters and size the grid to fit the pairs.
    count = min(num_filters, layer.weight.shape[0])
    cols = 4
    rows = (2 * count + cols - 1) // cols

    print("Iterating through filters in layer...")
    plt.figure()
    with torch.no_grad():
        for i in range(count):
            kernel = layer.weight[i][0].detach().cpu().numpy()
            # cv2.filter2D computes a correlation (the kernel is not flipped),
            # the same operation PyTorch's convolution layers perform.
            filtered = cv2.filter2D(image_array, -1, kernel)

            # Left: the kernel itself.
            plt.subplot(rows, cols, 2 * i + 1)
            plt.imshow(kernel, cmap="gray")
            plt.axis("off")

            # Right: the image after applying that kernel.
            plt.subplot(rows, cols, 2 * i + 2)
            plt.imshow(filtered, cmap="gray")
            plt.axis("off")
    print("Filters applied successfully!")
    plt.tight_layout()
    plt.show()
# ----- Main Code ---------------- #
# Examine the network's architecture
def main():
    """Load the saved network, print its architecture and inspect conv1.

    Steps: load the training data, restore the weights stored at
    ``results/main/mnist_model.pth`` into ``MyNetwork``, print the model,
    plot the first convolution layer's filters, and show their effect on
    the first image of the training dataset.
    """
    print_border()
    print("Examine the network's architecture")

    # Load training data (the second loader returned is not needed here)
    print_border()
    print("Loading training data...")
    train_loader, _ = load_data()

    # Load the network
    print_border()
    print("Loading the network...")
    network = MyNetwork()
    model_path = "results/main/mnist_model.pth"
    # map_location="cpu" lets a checkpoint saved from a GPU load on a
    # CPU-only machine.
    network.load_state_dict(torch.load(model_path, map_location="cpu"))

    # Print the network architecture
    print_border()
    print(f"Model loaded from: {model_path}")
    print(network)

    # Analyze first layer
    print_border()
    print("Analyzing the first layer...")
    analyze_layer(network.conv1)

    # Apply filters to the first training image; dataset items are
    # (image, label) pairs, so [0][0] is the image.
    print_border()
    print("Applying first layer filters to the first training image...")
    image_to_filter = train_loader.dataset[0][0]
    show_effects(network.conv1, image_to_filter)

    print("Terminating the program")


if __name__ == "__main__":
    main()