Add IMDB LSTM dataset example
braun-steven committed Dec 7, 2017
1 parent dc05d62 commit dd57ebe
Showing 2 changed files with 65 additions and 0 deletions.
64 changes: 64 additions & 0 deletions docs/examples/classifying-imdb.md
@@ -0,0 +1,64 @@
# The IMDB Dataset

From [ai.stanford.edu](http://ai.stanford.edu/~amaas/data/sentiment/):

> This is a dataset for binary sentiment classification containing substantially more data than previous benchmark datasets. We provide a set of 25,000 highly polar movie reviews for training, and 25,000 for testing. There is additional unlabeled data for use as well. Raw text and already processed bag of words formats are provided. See the README file contained in the release for more details.

The full IMDB dataset in the ARFF format can be found [here](https://github.com/Waikato/wekaDeeplearning4j/blob/develop/package/src/test/resources/nominal/imdb.arff).
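
If you only want to inspect the data before training, the ARFF file can be loaded with Weka's standard `DataSource` API. The following is a minimal sketch; the local path is an assumption and has to point to wherever you downloaded the file:

```java
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class InspectImdb {
    public static void main(String[] args) throws Exception {
        // Hypothetical path to the downloaded ARFF file
        DataSource source = new DataSource("path/to/imdb.arff");
        Instances data = source.getDataSet();

        // The dataset contains the review text and a nominal sentiment label;
        // here we assume the label is the last attribute
        data.setClassIndex(data.numAttributes() - 1);

        System.out.println("Loaded " + data.numInstances()
            + " reviews with class attribute " + data.classAttribute().name());
    }
}
```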

## Java
The following code builds a network consisting of an LSTM layer and an RnnOutputLayer. The IMDB reviews are loaded and each review is mapped to a sequence of vectors in the embedding space defined by the Google News model. Furthermore, gradient clipping at a value of 1.0 is applied to prevent exploding gradients.

```java
// Download e.g. the SLIM Google News model from
// https://github.com/eyaler/word2vec-slim/raw/master/GoogleNews-vectors-negative300-SLIM.bin.gz
final File modelSlim = new File("path/to/google/news/model");

// Setup hyperparameters
final int truncateLength = 80;
final int batchSize = 64;
final int seed = 1;
final int numEpochs = 10;
final int tbpttLength = 20;
final double l2 = 1e-5;
final double gradientThreshold = 1.0;
final double learningRate = 0.02;

// Setup the iterator
TextEmbeddingInstanceIterator tii = new TextEmbeddingInstanceIterator();
tii.setWordVectorLocation(modelSlim);
tii.setTruncateLength(truncateLength);
tii.setTrainBatchSize(batchSize);

// Initialize the classifier
RnnSequenceClassifier clf = new RnnSequenceClassifier();
clf.setSeed(seed);
clf.setNumEpochs(numEpochs);
clf.setInstanceIterator(tii);
clf.settBPTTbackwardLength(tbpttLength);
clf.settBPTTforwardLength(tbpttLength);

// Define the layers
LSTM lstm = new LSTM();
lstm.setNOut(64);
lstm.setActivationFunction(new ActivationTanH());

RnnOutputLayer rnnOut = new RnnOutputLayer();

// Network config
NeuralNetConfiguration nnc = new NeuralNetConfiguration();
nnc.setL2(l2);
nnc.setUseRegularization(true);
nnc.setGradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue);
nnc.setGradientNormalizationThreshold(gradientThreshold);
nnc.setLearningRate(learningRate);

// Config classifier
clf.setLayers(lstm, rnnOut);
clf.setNeuralNetConfiguration(nnc);
// Load the dataset and set the nominal sentiment label as the class attribute
Instances data = new Instances(new FileReader("src/test/resources/nominal/imdb.arff"));
data.setClassIndex(1);
clf.buildClassifier(data);
```
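
To get a first impression of the model's performance, the classifier can be evaluated with Weka's standard `Evaluation` API. This is only a minimal sketch, reusing `clf`, `data` and `seed` from the example above; note that a full 10-fold cross-validation retrains the network for every fold and is therefore expensive:

```java
import java.util.Random;
import weka.classifiers.Evaluation;

// Cross-validate the configuration defined above (the model is retrained per fold)
Evaluation eval = new Evaluation(data);
eval.crossValidateModel(clf, data, 10, new Random(seed));
System.out.println(eval.toSummaryString());
```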

1 change: 1 addition & 0 deletions mkdocs.yml
@@ -14,4 +14,5 @@ pages:
- Examples:
- 'Classifying the Iris Dataset': 'examples/classifying-iris.md'
- 'Classifying the MNIST Dataset': 'examples/classifying-mnist.md'
- 'Classifying the IMDB Dataset': 'examples/classifying-imdb.md'
- Troubleshooting: troubleshooting.md
