forked from bytedeco/javacv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CaffeGooglenet.java
105 lines (88 loc) · 3.96 KB
/
CaffeGooglenet.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
/*
* JavaCV version of OpenCV caffe_googlenet.cpp
* https://github.com/ludv1x/opencv_contrib/blob/master/modules/dnn/samples/caffe_googlenet.cpp
*
* Paolo Bolettieri <[email protected]>
*/
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import org.bytedeco.opencv.opencv_core.*;
import org.bytedeco.opencv.opencv_dnn.*;
import org.bytedeco.opencv.opencv_imgproc.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_dnn.*;
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
public class CaffeGooglenet {

    /**
     * Finds the best class for the network output blob, i.e. the class with the maximal probability.
     *
     * <p>The blob is reshaped to a single-channel, single-row matrix and {@code minMaxLoc} locates the
     * maximum; because there is only one row, the column of the maximum is the class index.
     *
     * @param probBlob  network output blob holding one score per class
     * @param classId   receives the location of the maximum; {@code classId.x()} is the class index
     * @param classProb one-element array that receives the maximal value in {@code classProb[0]}
     */
    public static void getMaxClass(Mat probBlob, Point classId, double[] classProb) {
        // reshape(1, 1): one channel, one row (1 x N, N = number of classes)
        Mat probMat = probBlob.reshape(1, 1);
        minMaxLoc(probMat, null, classProb, null, classId, null);
    }

    /**
     * Reads human-readable class labels from {@code synset_words.txt} in the working directory.
     *
     * <p>Each line is expected to look like {@code "<synset id> <label>"}; everything after the first
     * space is kept as the label (a line without a space is kept whole). The file is decoded as UTF-8
     * so the result does not depend on the platform default charset. If the file cannot be read, an
     * error and its cause are printed and the JVM exits with status -1.
     *
     * @return the labels in file order ({@code main} looks them up by class index)
     */
    public static List<String> readClassNames() {
        String filename = "synset_words.txt";
        List<String> classNames = new ArrayList<>();
        try (BufferedReader br = Files.newBufferedReader(Paths.get(filename), StandardCharsets.UTF_8)) {
            String line;
            while ((line = br.readLine()) != null) {
                classNames.add(line.substring(line.indexOf(' ') + 1));
            }
        } catch (IOException ex) {
            System.err.println("File with classes labels not found " + filename);
            // Keep the cause: reading can also fail for reasons other than a missing file.
            System.err.println("Cause: " + ex);
            System.exit(-1);
        }
        return classNames;
    }

    /**
     * Classifies one image with the Caffe GoogLeNet model and prints the best class and its score.
     *
     * <p>Expects {@code bvlc_googlenet.prototxt}, {@code bvlc_googlenet.caffemodel} and
     * {@code synset_words.txt} in the working directory. Exits with status -1 if the model or the
     * image cannot be loaded.
     *
     * @param args optional image path as the first argument; defaults to {@code space_shuttle.jpg}
     */
    public static void main(String[] args) throws Exception {
        String modelTxt = "bvlc_googlenet.prototxt";
        String modelBin = "bvlc_googlenet.caffemodel";
        String imageFile = (args.length > 0) ? args[0] : "space_shuttle.jpg";

        //! [Initialize network]
        Net net = null;
        try { // Try to import the Caffe GoogLeNet model
            net = readNetFromCaffe(modelTxt, modelBin);
        } catch (Exception e) { // the importer can throw; a failed load is reported just below
            e.printStackTrace();
        }
        if (net == null || net.empty()) {
            System.err.println("Can't load network by using the following files: ");
            System.err.println("prototxt: " + modelTxt);
            System.err.println("caffemodel: " + modelBin);
            System.err.println("bvlc_googlenet.caffemodel can be downloaded here:");
            System.err.println("http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel");
            System.exit(-1);
        }
        //! [Initialize network]

        //! [Prepare blob]
        Mat img = imread(imageFile);
        if (img.empty()) {
            System.err.println("Can't read image from the file: " + imageFile);
            System.exit(-1);
        }
        resize(img, img, new Size(224, 224)); // GoogLeNet takes 224x224 input
        // imread decodes colour images in BGR channel order; blobFromImage(img) uses its default
        // channel handling (NOTE(review): confirm the swapRB default of the bundled OpenCV version).
        // NOTE(review): OpenCV's reference settings for bvlc_googlenet also subtract the per-channel
        // mean (104, 117, 123); consider passing it as the mean argument of setInput/blobFromImage.
        Mat inputBlob = blobFromImage(img); // convert the image to a 4-dimensional blob
        //! [Prepare blob]

        //! [Set input blob]
        net.setInput(inputBlob, "data", 1.0, null); // feed the blob to the "data" input layer
        //! [Set input blob]

        //! [Make forward pass]
        Mat prob = net.forward("prob"); // compute the output of the "prob" layer
        //! [Make forward pass]

        //! [Gather output]
        Point classId = new Point();
        double[] classProb = new double[1];
        getMaxClass(prob, classId, classProb); // find the best class
        //! [Gather output]

        //! [Print results]
        List<String> classNames = readClassNames();
        int bestClass = classId.x();
        if (bestClass >= 0 && bestClass < classNames.size()) {
            System.out.println("Best class: #" + bestClass + " '" + classNames.get(bestClass) + "'");
        } else {
            // Label file shorter than the network output: report the index instead of throwing.
            System.out.println("Best class: #" + bestClass);
        }
        System.out.println("Probability: " + classProb[0] * 100 + "%");
        //! [Print results]
    } // main
}