-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
156 lines (122 loc) · 4.55 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import argparse
from pathlib import Path
from PIL import Image
import torch
from cli_utils import json_file_type
from model_utils import rebuild_model_from_checkpoint
from data_utils import process_image
from device_utils import get_device
def positive_int_type(value):
    """
    Argparse ``type`` callable that accepts only strictly positive integers.

    Args:
        value (str): the raw command line token.

    Returns:
        int: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer or is < 1,
            so argparse reports a clean usage error instead of the script
            failing later during inference.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid positive integer value: {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number


def get_cli_arguments():
    """
    Retrieve and parse the command line arguments provided by the user.

    Optional arguments the user omits fall back to their defaults. Missing
    positional arguments (``input``, ``checkpoint``) make argparse print the
    usage message and exit.

    Returns:
        argparse.Namespace: the parsed arguments, with attributes ``input``,
        ``checkpoint``, ``top_k``, ``category_names`` and ``gpu``.
    """
    parser = argparse.ArgumentParser(
        # Implicit concatenation (rather than backslash continuations inside
        # the literal) keeps source indentation out of the raw-formatted text.
        description=(
            "A script that takes an image and a checkpoint of a model, "
            "then returns the top k most likely classes along with the "
            "probabilities"
        ),
        epilog=(
            "Examples:\n"
            "* Predict flower name from an image:\n"
            "\tpython predict.py ./image_06743.jpg ./checkpoint.pth\n\n"
            "* Return top 3 most likely classes:\n"
            "\tpython predict.py ./image_06743.jpg ./checkpoint.pth --top_k 3\n\n"
            "* Use a mapping of categories to real names:\n"
            "\tpython predict.py ./image_06743.jpg ./checkpoint.pth "
            "--category_names cat_to_name.json\n\n"
            "* Use GPU for inference:\n"
            # The checkpoint positional is required, so the example needs it.
            "\tpython predict.py ./image_06743.jpg ./checkpoint.pth --gpu"
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "input",
        type=Path,
        help="path to the image to classify",
    )
    parser.add_argument(
        "checkpoint",
        type=Path,
        help="path to the checkpoint of the model to use for inference",
    )
    parser.add_argument(
        "--top_k",
        # Reject 0 / negatives at parse time instead of failing in topk().
        type=positive_int_type,
        default=5,
        help="the top k most probable classes",
    )
    parser.add_argument(
        "--category_names",
        type=json_file_type,
        help="the path to the file containing the mapping of categories to "
        "real names",
    )
    parser.add_argument(
        "--gpu",
        action="store_true",
        help="Use the gpu for inference",
    )
    return parser.parse_args()
def predict(model, image, top_k, device):
    """
    Predict the most likely classes of an image batch with a trained model.

    Args:
        model (nn.Module): the trained network. Its output is exponentiated
            to obtain probabilities, so it is assumed to emit
            log-probabilities (e.g. a final LogSoftmax); confirm against the
            checkpoint's architecture.
        image (torch.Tensor): the input batch (``main`` passes one processed
            image with a leading dimension added by ``unsqueeze(0)``).
        top_k (int): how many classes to return. Values larger than the
            number of model outputs are clamped to that number.
        device (torch.device): the device to run inference on.

    Returns:
        tuple: ``(top_p, top_idx)`` tensors from ``topk`` over the last
        dimension: the top probabilities in descending order and the
        corresponding output indices.
    """
    # nn.Module.to moves the parameters in place and returns the module.
    model = model.to(device)
    image = image.to(device)
    model.eval()
    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks run.
        log_ps = model(image)
    ps = torch.exp(log_ps)
    # topk raises if k exceeds the size of the dimension; clamp instead.
    k = min(top_k, ps.shape[-1])
    top_p, top_idx = ps.topk(k)
    return top_p, top_idx
def convert_idx_to_real_class_name(top_idx, class_to_idx, cat_to_name):
    """
    Translate predicted output indices into human-readable class labels.

    Each index is first mapped back to its dataset class by inverting
    ``class_to_idx``; that class is then looked up in ``cat_to_name``.

    Args:
        top_idx (List): predicted class indices, e.g. from ``predict``.
        class_to_idx (Dict): maps each class number to its index in the
            dataset.
        cat_to_name (Dict): maps each class number to its label.

    Returns:
        List: one label per entry of ``top_idx``, in the same order.

    Raises:
        KeyError: if an index is not among ``class_to_idx``'s values, or a
            class is missing from ``cat_to_name``.
    """
    # Invert the mapping so indices can be looked up directly.
    index_to_category = dict(zip(class_to_idx.values(), class_to_idx.keys()))
    # Resolve every index before naming any of them.
    categories = [index_to_category[index] for index in top_idx]
    return [cat_to_name[category] for category in categories]
def main():
    """
    Command line entry point: classify one image with a saved checkpoint and
    print a table of the top-k classes with their probabilities.
    """
    cli_args = get_cli_arguments()
    model = rebuild_model_from_checkpoint(cli_args.checkpoint)
    class_to_idx = model.class_to_idx
    device = get_device(cli_args.gpu)

    with Image.open(cli_args.input) as im:
        image = process_image(train=False)(im)
    # Add a leading dimension so the single image forms a batch of one.
    image = image.unsqueeze(0)

    probs, top_idx = predict(model, image, cli_args.top_k, device)
    # Only one image was submitted: keep its row as plain Python lists.
    probs, top_idx = probs[0].tolist(), top_idx[0].tolist()

    if cli_args.category_names:
        labels = convert_idx_to_real_class_name(
            top_idx, class_to_idx, cli_args.category_names
        )
    else:
        # The network's output indices are the values of class_to_idx, not
        # the class labels; translate them back so the "class name" column
        # shows the dataset class rather than a raw output index.
        idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
        labels = [idx_to_class[idx] for idx in top_idx]

    separator = "-" * 43
    print(separator)
    print(f"| {'class name':25s} | {'probability':11} |")
    print(separator)
    for label, prob in zip(labels, probs):
        print(f"| {str(label):25s} | {prob:11.2%} |")
    print(separator)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()