-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTensorFlowImageClassifier.java
136 lines (126 loc) · 5.5 KB
/
TensorFlowImageClassifier.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package org.faceit.demo;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Trace;
import android.support.v4.os.EnvironmentCompat;
import android.util.Log;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;
import org.faceit.demo.Classifier.Recognition;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/**
 * {@link Classifier} backed by a TensorFlow graph loaded through
 * {@link TensorFlowInferenceInterface}.
 *
 * <p>Instances are built with {@link #create}. The pixel, input and output buffers are
 * allocated once there and overwritten by every {@link #recognizeImage} call, so one
 * instance must not be used for concurrent classifications.
 */
public class TensorFlowImageClassifier implements Classifier {
    /** Maximum number of recognitions returned by {@link #recognizeImage}. */
    private static final int MAX_RESULTS = 1;
    private static final String TAG = "TensorFlowImageClassifier";
    /** Output scores must be strictly greater than this to be reported. */
    private static final float THRESHOLD = 0.1f;
    /** URI prefix that marks a path inside the APK assets; stripped before AssetManager.open. */
    private static final String ASSET_PREFIX = "file:///android_asset/";

    /** Orders recognitions by descending confidence, so the queue head is the best match. */
    private static final Comparator<Recognition> BY_CONFIDENCE_DESC = new Comparator<Recognition>() {
        public int compare(Recognition lhs, Recognition rhs) {
            return Float.compare(rhs.getConfidence().floatValue(), lhs.getConfidence().floatValue());
        }
    };

    private float[] floatValues;
    private int imageMean;
    private float imageStd;
    private TensorFlowInferenceInterface inferenceInterface;
    private String inputName;
    private int inputSize;
    private int[] intValues;
    private Vector<String> labels = new Vector<>();
    private boolean logStats = false;
    private String outputName;
    private String[] outputNames;
    private float[] outputs;

    private TensorFlowImageClassifier() {
    }

    /**
     * Loads the label file and the model graph, then allocates the inference buffers.
     *
     * @param assetManager  used to open both the label file and the model
     * @param modelFilename model path, passed unchanged to {@link TensorFlowInferenceInterface}
     * @param labelFilename label file with one label per line; a leading
     *                      {@code file:///android_asset/} prefix is stripped if present
     * @param inputSize     width and height in pixels of the square image fed to the graph
     * @param imageMean     value subtracted from each 0-255 colour channel
     * @param imageStd      divisor applied to each channel after mean subtraction
     * @param inputName     name of the graph's input node
     * @param outputName    name of the graph's output node; dimension 1 of its shape is
     *                      used as the number of classes
     * @return a ready-to-use classifier
     * @throws RuntimeException if the label file cannot be opened or read
     */
    public static Classifier create(AssetManager assetManager, String modelFilename, String labelFilename, int inputSize, int imageMean, float imageStd, String inputName, String outputName) {
        TensorFlowImageClassifier c = new TensorFlowImageClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        String actualFilename = stripAssetPrefix(labelFilename);
        Log.i(TAG, "Reading labels from: " + actualFilename);
        try {
            readLabels(assetManager, actualFilename, c.labels);
        } catch (IOException e) {
            // Previously a failing readLine() made the loop spin forever; now it surfaces.
            throw new RuntimeException("Problem reading label file!", e);
        }

        c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
        // NOTE(review): assumes the output op has shape [batch, numClasses]; confirm for new models.
        int numClasses = (int) c.inferenceInterface.graphOperation(outputName).output(0).shape().size(1);
        Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);

        c.inputSize = inputSize;
        c.imageMean = imageMean;
        c.imageStd = imageStd;
        c.outputNames = new String[] {outputName};
        c.intValues = new int[inputSize * inputSize];
        // Three floats (R, G, B) per pixel, matching the feed shape used in recognizeImage.
        c.floatValues = new float[inputSize * inputSize * 3];
        c.outputs = new float[numClasses];
        return c;
    }

    /** Returns {@code path} without a leading {@link #ASSET_PREFIX}, or unchanged if it has none. */
    private static String stripAssetPrefix(String path) {
        return path.startsWith(ASSET_PREFIX) ? path.substring(ASSET_PREFIX.length()) : path;
    }

    /** Appends every line of the asset at {@code path} to {@code out}, always closing the reader. */
    private static void readLabels(AssetManager assetManager, String path, List<String> out) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(assetManager.open(path), "UTF-8"));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                out.add(line);
            }
        } finally {
            br.close();
        }
    }

    /**
     * Runs the model on {@code bitmap} and returns the best-scoring classes.
     *
     * <p>The bitmap's pixels are copied into a buffer of {@code inputSize * inputSize} ints.
     * A larger bitmap makes {@code getPixels} throw. NOTE(review): presumably callers pass an
     * exactly {@code inputSize x inputSize} bitmap; confirm.
     *
     * @param bitmap image to classify
     * @return up to {@link #MAX_RESULTS} recognitions whose score exceeds {@link #THRESHOLD},
     *         highest confidence first; empty if none qualify
     */
    public List<Recognition> recognizeImage(Bitmap bitmap) {
        Trace.beginSection("recognizeImage");

        Trace.beginSection("preprocessBitmap");
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < intValues.length; i++) {
            int val = intValues[i];
            // Pixels are packed ARGB ints: extract R, G, B (alpha dropped) and normalise each.
            floatValues[i * 3] = ((float) (((val >> 16) & 0xFF) - imageMean)) / imageStd;
            floatValues[i * 3 + 1] = ((float) (((val >> 8) & 0xFF) - imageMean)) / imageStd;
            floatValues[i * 3 + 2] = ((float) ((val & 0xFF) - imageMean)) / imageStd;
        }
        Trace.endSection();

        Trace.beginSection("feed");
        inferenceInterface.feed(inputName, floatValues, 1, (long) inputSize, (long) inputSize, 3);
        Trace.endSection();

        Trace.beginSection("run");
        inferenceInterface.run(outputNames, logStats);
        Trace.endSection();

        Trace.beginSection("fetch");
        inferenceInterface.fetch(outputName, outputs);
        Trace.endSection();

        PriorityQueue<Recognition> pq = new PriorityQueue<>(3, BY_CONFIDENCE_DESC);
        for (int i = 0; i < outputs.length; i++) {
            if (outputs[i] > THRESHOLD) {
                // Fall back to a placeholder title when the label file has fewer entries than outputs.
                String title = i < labels.size() ? labels.get(i) : EnvironmentCompat.MEDIA_UNKNOWN;
                pq.add(new Recognition(String.valueOf(i), title, Float.valueOf(outputs[i]), null));
            }
        }

        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        ArrayList<Recognition> recognitions = new ArrayList<>(recognitionsSize);
        for (int i = 0; i < recognitionsSize; i++) {
            recognitions.add(pq.poll());
        }
        Trace.endSection();
        return recognitions;
    }

    /** Sets whether subsequent inference runs collect statistics (forwarded to {@code run}). */
    public void enableStatLogging(boolean logStats) {
        this.logStats = logStats;
    }

    /** Returns the statistics string reported by the underlying inference interface. */
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    /** Closes the underlying inference interface. */
    public void close() {
        inferenceInterface.close();
    }
}