forked from enrignagna/CovidNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkernel_visualization.py
32 lines (29 loc) · 1.09 KB
/
kernel_visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import matplotlib.pyplot as plt
import numpy as np
# Public entry point: extract a layer's kernels and display them.
def visualize(model, layer_name):
    """Extract the kernels of ``layer_name`` from ``model`` and plot them.

    Args:
        model: Object passed through to ``__extract_filter``; it must
            provide ``get_layer(name)``.
        layer_name: Name of the layer whose kernels are shown.

    Returns:
        None. The kernels are displayed by ``__vis_filter``.
    """
    __vis_filter(__extract_filter(model, layer_name))
# Extract the kernels used by a specific layer of a specific model.
def __extract_filter(model, layer_name):
    """Return the first (at most six) kernels of a layer, scaled to [0, 1].

    Args:
        model: Object providing ``get_layer(name)``; the returned layer must
            provide ``get_weights()`` (presumably a Keras model -- confirm
            against the caller).
        layer_name: Name of the layer to read the weights from.

    Returns:
        The first entry of ``get_weights()``, sliced to at most six items
        along its fourth axis and min-max normalised over the whole slice.
        If every selected weight has the same value, an all-zero array is
        returned instead of NaNs.
    """
    weights = model.get_layer(layer_name).get_weights()[0]
    # Keep at most the first six kernels along the last axis.
    weights = weights[:, :, :, :6]

    # Min-max normalisation to the 0-1 range.
    lowest, highest = weights.min(), weights.max()
    span = highest - lowest
    if span == 0:
        # Constant weights: avoid 0/0 (NaN); the numerator is already zero,
        # so dividing by 1 yields a clean all-zero result.
        span = 1
    return (weights - lowest) / span
# Display the kernels.
def __vis_filter(filters):
    """Show kernels as a grid of grayscale images.

    The grid has one column per kernel (at most six, taken along the fourth
    axis of ``filters``) and one row per slice along the third axis (at most
    three).

    Args:
        filters: 4-D array indexed as ``filters[:, :, channel, kernel]``;
            presumably already scaled to [0, 1] -- confirm against the caller.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    n_filters = min(6, filters.shape[3])
    # Do not assume three input channels: a grayscale first layer has one.
    n_channels = min(3, filters.shape[2])

    # squeeze=False keeps ``axes`` two-dimensional even when the grid has a
    # single row or column, so one indexing scheme covers every case.
    _, axes = plt.subplots(
        n_channels,
        n_filters,
        figsize=(1.5 * n_filters, 3),
        squeeze=False,
    )
    for col in range(n_filters):
        kernel = filters[:, :, :, col]
        for row in range(n_channels):
            cell = axes[row, col]
            cell.imshow(kernel[:, :, row], cmap="gray")
            cell.axis("off")
    plt.tight_layout()
    plt.show()