-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMNIST.py
29 lines (22 loc) · 1.03 KB
/
MNIST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from sklearn import datasets
from sklearn import neighbors, metrics
from sklearn.neighbors import KNeighborsClassifier as kNN
import matplotlib.pyplot as plt
# Load scikit-learn's bundled 8x8 handwritten-digit dataset.
Digits = datasets.load_digits()

# Preview: keep only the (image, label) pairs whose label is 7 or 3 and draw
# the first four of them on the top row of a 2x4 subplot grid.
Imglabels = [x for x in zip(Digits.images, Digits.target) if x[1] in (7, 3)]
for ind, (image, label) in enumerate(Imglabels[:4]):
    plt.subplot(2, 4, ind + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    # Show the image's actual digit, not the literal word "label".
    plt.title('Training: %i' % label)
plt.show()

num = len(Digits.images)  # number of samples in the dataset
# Flatten each 2-D image into a single row so kNN sees one feature vector
# per sample.
imgs = Digits.images.reshape((num, -1))
labs = Digits.target

# Deterministic, unshuffled split: first 70% of samples train, rest test.
split = int(num * .7)
y_trainset, x_trainset = labs[:split], imgs[:split]
y_testset, x_testset = labs[split:], imgs[split:]

# 3-nearest-neighbour classifier on the raw flattened pixel values.
neighbor = kNN(n_neighbors=3)
neighbor.fit(x_trainset, y_trainset)
new_val = neighbor.predict(x_testset)

print("kNN classifier reports: %s:\n%s\n"
      % (neighbor, metrics.classification_report(y_testset, new_val)))
print("Confusion matrix is:", metrics.confusion_matrix(y_testset, new_val))