-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_fungus_class.py
46 lines (44 loc) · 1.99 KB
/
predict_fungus_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import pickle
from scipy.ndimage import imread
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
def photo_transform(im_name):
    '''
    Return the centred 2-D FFT magnitude spectrum of a grayscale image as one feature row.

    The image is transformed with a 2-D fast Fourier transform. The zero-frequency
    component is shifted to the centre of the spectrum, and the magnitude of each
    complex coefficient is taken. The result is flattened into a single row, because
    the PCA / classifier ``transform``/``predict`` calls expect a 2-D
    (n_samples, n_features) input.

    Note: despite the "psd" naming used elsewhere in this file, the values are the
    magnitudes |F|, not the squared power spectrum |F|**2.

    Input: image path, such as "./Fungi Pics/JU-15B.tif", or an already-loaded 2-D
        grayscale NumPy array (lets callers use their own image loader).
    Output: NumPy array of shape (1, height * width) holding the spectrum magnitudes.
    Raises: ValueError if an array is passed that is not 2-D.
    '''
    if isinstance(im_name, np.ndarray):
        im = im_name
        if im.ndim != 2:
            raise ValueError(
                "photo_transform expects a 2-D grayscale array, got shape {}".format(im.shape))
    else:
        # flatten=True collapses colour channels into a single gray layer.
        # NOTE(review): scipy.ndimage.imread was removed in SciPy 1.2; on newer SciPy
        # the module-level import fails, so pass a pre-loaded array or pin SciPy.
        im = imread(im_name, flatten=True)
    spectrum = np.fft.fftshift(np.fft.fft2(im))  # zero frequency moved to the centre
    magnitude = np.abs(spectrum)  # complex coefficients -> real magnitudes
    return magnitude.reshape(1, -1)  # one sample row for sklearn-style estimators
def photo_pred(flat_psd, model_path="hyphal_image_RF_classifier3.obj"):
    '''
    Predict whether the fungus in a transformed image belongs to Mucormycota.

    Input:
        flat_psd: single-row feature array, as returned by photo_transform.
        model_path: path to the pickled, fitted Random Forest classifier. Defaults to
            the model file this script has always used, resolved relative to the
            current working directory.
    Output: "Mucormycota" for class label 1, "Not Mucormycota" for class label 2.
    Raises:
        ValueError: if the classifier does not return exactly one prediction, or
            returns a label other than 1 or 2 (previously this silently returned None).
    '''
    # NOTE: an earlier PCA + SVM pipeline (hyphal_image_pca.obj /
    # hyphal_image_classifier.obj) was disabled here; only the Random Forest is used.
    # SECURITY: pickle.load can execute arbitrary code, so only load model files you trust.
    with open(model_path, "rb") as model_file:  # handle is closed even if unpickling fails
        clf = pickle.load(model_file)
    labels = np.ravel(clf.predict(flat_psd))
    if labels.size != 1:
        raise ValueError("expected exactly one prediction, got {}".format(labels.size))
    label = labels.item()
    class_names = {1: "Mucormycota", 2: "Not Mucormycota"}
    try:
        return class_names[label]
    except KeyError:
        raise ValueError("unexpected class label from classifier: {!r}".format(label)) from None
def main(im_name):
    '''
    Classify one image file and print the predicted class name.

    Input: path to the image to classify.
    Output: none; the result of photo_pred is printed to stdout.
    '''
    print(photo_pred(photo_transform(im_name)))
if __name__ == "__main__":
    # Ask interactively for the image to classify; input() already returns a str.
    im_name = input("Path name of photo:")
    main(im_name)