A character-level language model using a GRU- or LSTM-based RNN, implemented with PyTorch
No installation is required to use char-parrot itself. Simply clone this repository:
git clone https://github.com/cclaypool/char-parrot.git
Before using char-parrot, however, some dependencies must be installed.
If you are using Linux, Python 3 is most likely already installed; if not, install it using your distribution's package manager. For other platforms, download and install Python 3 from the official Python website. Ensure that the directory where Python 3 is installed is included in the PATH environment variable on your system.
Once Python is installed, head to the PyTorch official website for information on how to install the latest version of PyTorch.
tqdm is used to display progress bars during training. Install it using pip:
pip install tqdm
Note that pip for Python 3 may be called pip3 rather than pip, for example on Linux systems.
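To confirm that the dependencies listed above are visible to the interpreter you plan to use, a quick check along these lines may help. This is only a sketch: `missing_modules` is a hypothetical helper, not part of char-parrot, and it assumes the packages are importable under the names `torch` and `tqdm`.

```python
import importlib.util
import sys


def missing_modules(names):
    """Return the top-level module names that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # char-parrot targets Python 3, so report the running interpreter's version.
    print("Python version:", sys.version.split()[0])
    missing = missing_modules(["torch", "tqdm"])
    if missing:
        print("Missing dependencies:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

Run it with the same `python` or `python3` command you intend to use for training, since each interpreter has its own set of installed packages.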
Note that you may need to replace python with python3 in the following commands if, for example, you are using Linux and have both Python 2 and Python 3 installed.
From the char-parrot directory created by git clone, train a model with:
python train.py project_dir [options]
Generate text based on a previously trained model with:
python generate.py project_dir [options]
Run either script with the --help flag for detailed information on its usage.
project_dir must contain a model.ini model configuration file; see sample_project/model.ini for a commented example explaining each option.
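If you want to inspect a project's configuration before training, something like the following could be used. It is a minimal sketch that assumes model.ini follows standard INI syntax readable by Python's `configparser`; `load_model_config` is a hypothetical helper, and the way char-parrot itself parses the file may differ.

```python
import configparser
from pathlib import Path


def load_model_config(project_dir):
    """Read project_dir/model.ini and return {section: {option: value}}."""
    config_path = Path(project_dir) / "model.ini"
    if not config_path.is_file():
        raise FileNotFoundError(f"{config_path} not found")
    # Interpolation is disabled so literal '%' characters in values are kept.
    parser = configparser.ConfigParser(interpolation=None)
    parser.read(config_path, encoding="utf-8")
    return {section: dict(parser[section]) for section in parser.sections()}


if __name__ == "__main__":
    # Print every option found in the sample project, if it is present.
    try:
        for section, options in load_model_config("sample_project").items():
            print(f"[{section}]")
            for key, value in options.items():
                print(f"  {key} = {value}")
    except FileNotFoundError as err:
        print(err)
```

All values come back as strings; converting them to numbers or booleans is left to the caller.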
The model will run on the GPU if one is available, unless force_cpu is set to True in hw.py.
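The rule described above can be sketched as follows. This illustrates the stated behaviour only and is not the contents of hw.py: `select_device` is a hypothetical function, and only `torch.cuda.is_available()` is an actual PyTorch call.

```python
def select_device(force_cpu, cuda_available):
    """Return "cuda" when a GPU is available and not overridden, else "cpu"."""
    return "cuda" if cuda_available and not force_cpu else "cpu"


if __name__ == "__main__":
    try:
        import torch
        cuda_available = torch.cuda.is_available()
    except ImportError:
        # Without PyTorch installed there is nothing to query; assume no GPU.
        cuda_available = False
    print("Selected device:", select_device(force_cpu=False,
                                            cuda_available=cuda_available))
```

The returned string is the kind of value that can be passed to `torch.device(...)` when moving a model or tensor.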
Train a model based on the configuration stored in project/model.ini for 20 epochs, saving the model to project/save.pth after every epoch:
python train.py project -e 20 -s save.pth
Load the saved state project/save.pth and train for a further 10 epochs, saving the state to project/save.pth after every epoch:
python train.py project -e 10 -l save.pth -s save.pth
Generate 500 characters of text using the model whose state is saved in project/save.pth and whose configuration is stored in project/model.ini, using the seed phrase "once upon a time" and the sampling temperature 0.3:
python generate.py project -l save.pth -n 500 -s "once upon a time" -t 0.3
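For readers unfamiliar with sampling temperature, the sketch below shows the usual technique: the model's raw scores are divided by the temperature before a softmax, so low values such as 0.3 favour the most likely characters while higher values give more varied output. This is a generic pure-Python illustration with hypothetical function names, and it is not necessarily how generate.py implements sampling.

```python
import math
import random


def apply_temperature(logits, temperature):
    """Turn raw scores into probabilities using a temperature-scaled softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(logits, temperature, rng=random):
    """Draw one index at random, weighted by the temperature-scaled probabilities."""
    probs = apply_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    scores = [2.0, 1.0, 0.0]  # made-up scores for three candidate characters
    for t in (0.3, 1.0, 2.0):
        rounded = [round(p, 3) for p in apply_temperature(scores, t)]
        print(f"temperature {t}: {rounded}")
```

Running the script prints the probabilities for each temperature, which makes it easy to see the distribution sharpen as the temperature decreases.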