-
Notifications
You must be signed in to change notification settings - Fork 205
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dc05d62
commit dd57ebe
Showing
2 changed files
with
65 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# The IMDB Dataset | ||
|
||
As of [ai.stanford.edu](http://ai.stanford.edu/~amaas/data/sentiment/): | ||
|
||
> This is a dataset for binary sentiment classification containing substantially more data than previous benchmark datasets. We provide a set of 25,000 highly polar movie reviews for training, and 25,000 for testing. There is additional unlabeled data for use as well. Raw text and already processed bag of words formats are provided. See the README file contained in the release for more details. | ||
|
||
The full IMDB dataset in the ARFF format can be found [here](https://github.com/Waikato/wekaDeeplearning4j/blob/develop/package/src/test/resources/nominal/imdb.arff). | ||
|
||
## Java | ||
The following code builds a network consisting of an LSTM layer and an RnnOutputLayer, loading imdb reviews and mapping them into a sequence of vectors in the embedding space that is defined by the Google News model. Furthermore, gradient clipping at a value of 1.0 is applied to prevent the network from exploding gradients. | ||
|
||
```java | ||
// Download e.g the SLIM Google News model from | ||
// https://github.com/eyaler/word2vec-slim/raw/master/GoogleNews-vectors-negative300-SLIM.bin.gz | ||
final File modelSlim = new File("path/to/google/news/model"); | ||
|
||
// Setup hyperparameters | ||
final int truncateLength = 80; | ||
final int batchSize = 64; | ||
final int seed = 1; | ||
final int numEpochs = 10; | ||
final int tbpttLength = 20; | ||
final double l2 = 1e-5; | ||
final double gradientThreshold = 1.0; | ||
final double learningRate = 0.02; | ||
|
||
// Setup the iterator | ||
TextEmbeddingInstanceIterator tii = new TextEmbeddingInstanceIterator(); | ||
tii.setWordVectorLocation(modelSlim); | ||
tii.setTruncateLength(truncateLength); | ||
tii.setTrainBatchSize(batchSize); | ||
|
||
// Initialize the classifier | ||
RnnSequenceClassifier clf = new RnnSequenceClassifier(); | ||
clf.setSeed(seed); | ||
clf.setNumEpochs(numEpochs); | ||
clf.setInstanceIterator(tii); | ||
clf.settBPTTbackwardLength(tbpttLength); | ||
clf.settBPTTforwardLength(tbpttLength); | ||
|
||
// Define the layers | ||
LSTM lstm = new LSTM(); | ||
lstm.setNOut(64); | ||
lstm.setActivationFunction(new ActivationTanH()); | ||
|
||
RnnOutputLayer rnnOut = new RnnOutputLayer(); | ||
|
||
// Network config | ||
NeuralNetConfiguration nnc = new NeuralNetConfiguration(); | ||
nnc.setL2(l2); | ||
nnc.setUseRegularization(true); | ||
nnc.setGradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue); | ||
nnc.setGradientNormalizationThreshold(gradientThreshold); | ||
nnc.setLearningRate(learningRate); | ||
|
||
// Config classifier | ||
clf.setLayers(lstm, rnnOut); | ||
clf.setNeuralNetConfiguration(nnc); | ||
Instances data = new Instances(new FileReader("src/test/resources/nominal/imdb.arff")); | ||
data.setClassIndex(1); | ||
clf.buildClassifier(data); | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters