A TensorFlow implementation of Google's QANet. The starter code is from this repo. Contributions are welcome!
- python 3.6
- numpy 1.14.2
- tensorflow 1.8.0
- tensor2tensor 1.8.0
- colorama 0.3.9
- nltk 3.2.5
- six 1.11.0
- tqdm 4.24.0
To install the requirements in a new conda environment, download the SQuAD dataset and pretrained GloVe embeddings, and preprocess the data, run
bash get_started.sh
There are 3 modes: train, show_examples, and official_eval. Before entering each mode, make sure to activate the qanet environment and enter the code/ directory; to do so, run
source activate qanet
cd code
To start training, run the following, replacing <EXPERIMENT NAME> with an experiment name such as qanet:
python main.py --experiment_name=<EXPERIMENT NAME> --mode=train
Hyperparameters are stored as flags in code/main.py. Please refer to code/modules.py and code/qa_model.py for details.
Training results are stored under experiments/<EXPERIMENT NAME>.
To see example output, run
python main.py --experiment_name=<EXPERIMENT NAME> --mode=show_examples
Ten random dev set examples are printed to the screen, comparing the true answer with the model's predicted answer and giving the F1 and EM scores for each example.
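For readers unfamiliar with the two per-example scores, the sketch below shows how SQuAD-style EM (exact match) and token-level F1 are typically computed. It is an illustration modeled on the normalization used by the official SQuAD v1.1 evaluation script, not a copy of this repository's code; the function names are chosen here for clarity and may differ from those in code/.

```python
import collections
import re
import string


def normalize_answer(text):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    punctuation = set(string.punctuation)
    text = text.lower()
    text = "".join(ch for ch in text if ch not in punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """Return 1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(truth))


def f1_score(prediction, truth):
    """Harmonic mean of token-level precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    true_tokens = normalize_answer(truth).split()
    # Multiset intersection counts each shared token at most as often
    # as it appears in both answers.
    common = collections.Counter(pred_tokens) & collections.Counter(true_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(true_tokens)
    return 2 * precision * recall / (precision + recall)
```

When a question has several reference answers, the usual convention is to take the maximum score over the references.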
To obtain the model's predictions on the dataset, run
python main.py --experiment_name=<EXPERIMENT NAME> --mode=official_eval \
--json_in_path=../data/dev-v1.1.json \
--ckpt_load_dir=../experiments/<EXPERIMENT NAME>/best_checkpoint
The model's predictions are stored in experiments/<EXPERIMENT NAME>/predictions.json.
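To sanity-check the output before scoring it, a small script along these lines can help. It assumes the file follows the format the official SQuAD v1.1 evaluation script expects, namely a single JSON object mapping each question id to its predicted answer string; load_predictions and summarize are hypothetical helpers written for this example, not functions from this repository.

```python
import json


def load_predictions(path):
    """Load a predictions file, assumed to be a JSON object of id -> answer."""
    with open(path, encoding="utf-8") as f:
        predictions = json.load(f)
    if not isinstance(predictions, dict):
        raise ValueError("expected a JSON object mapping question ids to answers")
    return predictions


def summarize(predictions, limit=3):
    """Return the number of answers and the first few (id, answer) pairs."""
    sample = list(predictions.items())[:limit]
    return len(predictions), sample
```

From the code/ directory, this could be used as summarize(load_predictions("../experiments/<EXPERIMENT NAME>/predictions.json")), with the placeholder replaced by the actual experiment name.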
To run the official SQuAD evaluation script on the output, run
python evaluate.py ../data/dev-v1.1.json ../experiments/<EXPERIMENT NAME>/predictions.json
To track progress in TensorBoard, run
tensorboard --logdir=. --port=8888