Merge pull request #56 from minimaxir/0.5

0.5

minimaxir authored May 20, 2019
2 parents be551ab + de2b362 commit 7dc8210
Showing 4 changed files with 120 additions and 44 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -19,7 +19,7 @@ You can use gpt-2-simple to retrain a model using a GPU **for free** in [this Co
gpt-2-simple can be installed [via PyPI](https://pypi.org/project/gpt_2_simple/):

```shell
pip3 install gpt_2_simple
pip3 install gpt-2-simple
```

You will also need to install the corresponding TensorFlow for your system (e.g. `tensorflow` or `tensorflow-gpu`).
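
For example, on a CPU-only machine (a sketch, not pinned to a specific version; swap in `tensorflow-gpu` for a CUDA-capable GPU, and note that this release targets the TensorFlow 1.x line):

```shell
pip3 install tensorflow
```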
@@ -100,6 +100,8 @@ The method GPT-2 uses to generate text is slightly different than those like oth
* If you pass a single-column `.csv` file to `finetune()`, it will automatically parse the CSV into a format ideal for training with GPT-2 (including prepending `<|startoftext|>` and suffixing `<|endoftext|>` to every text document, so the `truncate` tricks above are helpful when generating output). This is necessary to handle both quotes and newlines in each text document correctly.
* GPT-2 allows you to generate texts in parallel by setting a `batch_size` that divides evenly into `nsamples`, resulting in much faster generation. This works very well with a GPU (you can set `batch_size` up to 20 on Colaboratory's K80)!
* Due to GPT-2's architecture, it scales up nicely with more powerful GPUs. For the 117M model, if you want to train for longer periods of time, GCP's P100 GPU is about 3x faster than a K80/T4 for only 3x the price, making it price-comparable (the V100 is about 1.5x faster than the P100 but about 2x the price). The P100 uses 100% of the GPU even with `batch_size=1`, and about 88% of the V100 GPU.
* If you have a partially trained GPT-2 model and want to continue finetuning it, you can pass `overwrite=True` to `finetune()`, which continues training from the existing checkpoint and replaces the previous iteration of the model instead of creating a duplicate copy. This can be especially useful for transfer learning (e.g. heavily finetune GPT-2 on one dataset, then finetune on another dataset to get a "merging" of both datasets).
* If your input text dataset is massive (>100 MB), you may want to preencode and compress the dataset using `gpt2.encode_dataset(file_path)`. The output is a compressed `.npz` file that loads much faster into the GPU for finetuning. (A short sketch of these workflows follows this list.)
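
A minimal sketch tying the options above together (the corpus file name and step count are illustrative, not from the repository):

```python
import gpt_2_simple as gpt2

gpt2.download_gpt2(model_name='117M')   # only needed once

# Optionally preencode a large corpus into a compressed .npz first
gpt2.encode_dataset('huge_corpus.txt', out_path='huge_corpus.npz')

sess = gpt2.start_tf_sess()

# overwrite=True continues training run1 in place instead of duplicating it
gpt2.finetune(sess,
              dataset='huge_corpus.npz',
              model_name='117M',
              run_name='run1',
              steps=1000,
              overwrite=True)

# Generate 20 samples, 5 at a time in parallel on the GPU
gpt2.generate(sess, run_name='run1', nsamples=20, batch_size=5)
```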

## Planned Work

113 changes: 81 additions & 32 deletions gpt_2_simple/gpt_2.py
@@ -86,7 +86,7 @@ def finetune(sess,
max_checkpoints=1,
use_memory_saving_gradients=False,
only_train_transformer_layers=False,
model_load=False):
overwrite=False):
"""Finetunes the model on the given dataset.
Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
@@ -105,10 +105,15 @@ def maketree(path):
pass

maketree(checkpoint_path)
if not model_load:
for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
shutil.copyfile(os.path.join('models', model_name, file),
os.path.join(checkpoint_path, file))
files = [f for f in os.listdir(checkpoint_path)]
for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
if file not in files:
try:
shutil.copyfile(os.path.join('models', model_name, file),
os.path.join(checkpoint_path, file))
except FileNotFoundError as fnf_error:
print("You need to download the GPT-2 model first via download_gpt2()")
raise(fnf_error)

enc = encoder.get_encoder(checkpoint_path)
hparams = model.default_hparams()
@@ -181,9 +186,6 @@ def maketree(path):
print('Loading checkpoint', ckpt)
saver.restore(sess, ckpt)

if model_load:
return

print('Loading dataset...')
chunks = load_dataset(enc, dataset, combine)
data_sampler = Sampler(chunks)
@@ -236,6 +238,12 @@ def generate_samples():
def sample_batch():
return [data_sampler.sample(1024) for _ in range(batch_size)]

if overwrite and restore_from == 'latest':
for file in files:
if file.startswith('model') or file.startswith('events'):
os.remove(os.path.join(checkpoint_path, file))
save()

avg_loss = (0.0, 0.0)
start_time = time.time()

@@ -306,19 +314,19 @@ def load_gpt2(sess,


def generate(sess,
run_name='run1',
return_as_list=False,
truncate=None,
destination_path=None,
sample_delim='=' * 20 + '\n',
prefix=None,
model_name='117M',
seed=None,
nsamples=1,
batch_size=1,
length=1023,
temperature=0.7,
top_k=0,
run_name='run1',
top_p=0.0,
include_prefix=True):
"""Generates text from a model loaded into memory.
@@ -353,7 +361,7 @@ def generate(sess,
start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
context=context if prefix else None,
batch_size=batch_size,
temperature=temperature, top_k=top_k
temperature=temperature, top_k=top_k, top_p=top_p
)[:, 1:]

if destination_path:
@@ -400,18 +408,18 @@ def generate(sess,


def generate_to_file(sess,
run_name='run1',
truncate=None,
destination_path='gpt_2_gen_texts.txt',
sample_delim='=' * 20 + '\n',
prefix=None,
model_name='117M',
seed=None,
nsamples=1,
batch_size=1,
length=1023,
temperature=0.7,
top_k=0,
run_name='run1',
top_p=0.0,
include_prefix=True):
"""Generates the texts to a file.
@@ -421,19 +429,19 @@ def generate_to_file(sess,
"""

generate(sess,
run_name,
False,
truncate,
destination_path,
sample_delim,
prefix,
model_name,
seed,
nsamples,
batch_size,
length,
temperature,
top_k,
run_name,
top_p,
include_prefix)


@@ -456,29 +464,39 @@ def get_tarfile_name(checkpoint_folder):
return tarfile_name


def copy_checkpoint_to_gdrive(checkpoint_folder=os.path.join('checkpoint', 'run1')):
def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
"""Copies the checkpoint folder to a mounted Google Drive."""
is_mounted()

file_path = get_tarfile_name(checkpoint_folder)
checkpoint_folder = os.path.join('checkpoint', run_name)

# Reference: https://stackoverflow.com/a/17081026
with tarfile.open(file_path, 'w') as tar:
tar.add(checkpoint_folder)
if copy_folder:
shutil.copytree(checkpoint_folder, "/content/drive/My Drive/" + checkpoint_folder)
else:
file_path = get_tarfile_name(checkpoint_folder)

shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)
# Reference: https://stackoverflow.com/a/17081026
with tarfile.open(file_path, 'w') as tar:
tar.add(checkpoint_folder)

shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)

def copy_checkpoint_from_gdrive(checkpoint_folder=os.path.join('checkpoint', 'run1')):

def copy_checkpoint_from_gdrive(run_name='run1', copy_folder=False):
"""Copies the checkpoint folder from a mounted Google Drive."""
is_mounted()

file_path = get_tarfile_name(checkpoint_folder)
checkpoint_folder = os.path.join('checkpoint', run_name)

shutil.copyfile("/content/drive/My Drive/" + file_path, file_path)
if copy_folder:
shutil.copytree("/content/drive/My Drive/" + checkpoint_folder, checkpoint_folder)
else:
file_path = get_tarfile_name(checkpoint_folder)

with tarfile.open(file_path, 'r') as tar:
tar.extractall()
shutil.copyfile("/content/drive/My Drive/" + file_path, file_path)

with tarfile.open(file_path, 'r') as tar:
tar.extractall()


def copy_file_to_gdrive(file_path):
@@ -522,6 +540,23 @@ def encode_csv(csv_path, out_path='csv_encoded.txt', header=True,
w.write(start_token + row[0] + end_token + "\n")


def encode_dataset(file_path, out_path='text_encoded.npz',
model_name="117M",
combine=50000):
"""Preencodes a text document into chunks and compresses it,
saving time when the dataset is loaded for finetuning.
Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/encode.py
"""

model_path = os.path.join('models', model_name)
enc = encoder.get_encoder(model_path)
print('Reading files')
chunks = load_dataset(enc, file_path, combine)
print('Writing', out_path)
np.savez_compressed(out_path, *chunks)


def cmd():
"""Function called when invoking from the terminal."""

@@ -557,6 +592,9 @@ def cmd():
parser.add_argument(
'--print_every', help="[finetune] After how many steps to print progress",
nargs='?', default=10, type=int)
parser.add_argument(
'--overwrite', help="[finetune] Overwrite existing model when continuing training",
nargs='?', default=False, type=lambda x: (str(x).lower() == 'true'))
parser.add_argument(
'--nfiles', help="[generate] How many files to generate.",
nargs='?', default=1, type=int)
@@ -572,6 +610,12 @@ def cmd():
parser.add_argument(
'--temperature', help="[generate] Temperature of the generated texts",
nargs='?', default=0.7, type=float)
parser.add_argument(
'--top_k', help="[generate] Sample only from top k tokens",
nargs='?', default=0, type=int)
parser.add_argument(
'--top_p', help="[generate] Sample from top p prob (overrides top_k if nonzero)",
nargs='?', default=0.0, type=float)
parser.add_argument(
'--batch_size', help="[generate] Batch size for generation (increase for GPUs)",
nargs='?', default=1, type=int)
@@ -604,19 +648,21 @@ def cmd():
steps=args.steps, restore_from=args.restore_from,
sample_every=args.sample_every,
save_every=args.save_every,
print_every=args.print_every)
print_every=args.print_every,
overwrite=args.overwrite)
if args.mode == "generate":
cmd_generate(nfiles=args.nfiles, nsamples=args.nsamples,
folder=args.folder, length=args.length,
temperature=args.temperature, batch_size=args.batch_size,
prefix=args.prefix, truncate=args.truncate,
include_prefix=args.include_prefix,
sample_delim=args.sample_delim, run_name=args.run_name)
sample_delim=args.sample_delim, run_name=args.run_name,
top_k=args.top_k, top_p=args.top_p)


def cmd_finetune(dataset, run_name, model_name, steps,
restore_from, sample_every,
save_every, print_every):
save_every, print_every, overwrite):
"""Wrapper script for finetuning the model via the CLI."""

if not is_gpt2_downloaded(model_name=model_name):
@@ -627,13 +673,15 @@ def cmd_finetune(dataset, run_name, model_name, steps,
model_name=model_name,
steps=steps, restore_from=restore_from,
sample_every=sample_every, save_every=save_every,
print_every=print_every)
print_every=print_every,
overwrite=overwrite)


def cmd_generate(nfiles, nsamples, folder,
length, temperature, batch_size,
prefix, truncate, include_prefix,
sample_delim, run_name):
sample_delim, run_name,
top_k, top_p):
"""Wrapper script for generating text via the CLI.
The files are generated into a folder, which can be downloaded
recursively by downloading the entire folder.
@@ -662,5 +710,6 @@ def cmd_generate(nfiles, nsamples, folder,
truncate=truncate,
include_prefix=include_prefix,
sample_delim=sample_delim,
run_name=run_name
top_k=top_k,
top_p=top_p
)
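
Assuming the `gpt_2_simple` console command registered by the package and the positional `finetune <dataset>` / `generate` modes described in the project README, the new flags could be exercised roughly like this (the corpus file name is illustrative):

```shell
# Continue finetuning an existing run in place rather than duplicating it
gpt_2_simple finetune corpus.txt --overwrite True

# Generate with nucleus (top-p) sampling; a nonzero top_p overrides top_k
gpt_2_simple generate --top_p 0.9 --nsamples 5
```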
45 changes: 35 additions & 10 deletions gpt_2_simple/src/sample.py
@@ -2,6 +2,7 @@

from gpt_2_simple.src import model


def top_k_logits(logits, k):
if k == 0:
# no truncation
@@ -16,25 +17,44 @@ def _top_k():
logits,
)
return tf.cond(
tf.equal(k, 0),
lambda: logits,
lambda: _top_k(),
tf.equal(k, 0),
lambda: logits,
lambda: _top_k(),
)


def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
def top_p_logits(logits, p):
with tf.variable_scope('top_p_logits'):
logits_sort = tf.sort(logits, direction='DESCENDING')
probs_sort = tf.nn.softmax(logits_sort)
probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True)
logits_masked = tf.where(probs_sums < p, logits_sort, tf.ones_like(
logits_sort)*1000) # [batchsize, vocab]
min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True) # [batchsize, 1]
return tf.where(
logits < min_logits,
tf.ones_like(logits, dtype=logits.dtype) * -1e10,
logits,
)


def sample_sequence(*, hparams, length, start_token=None,
batch_size=None, context=None, temperature=1,
top_k=0, top_p=0.0):
if start_token is None:
assert context is not None, 'Specify exactly one of start_token and context!'
else:
assert context is None, 'Specify exactly one of start_token and context!'
context = tf.fill([batch_size, 1], start_token)

def step(hparams, tokens, past=None):
lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
lm_output = model.model(hparams=hparams, X=tokens,
past=past, reuse=tf.AUTO_REUSE)

logits = lm_output['logits'][:, :, :hparams.n_vocab]
presents = lm_output['present']
presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
presents.set_shape(model.past_shape(
hparams=hparams, batch_size=batch_size))
return {
'logits': logits,
'presents': presents,
Expand All @@ -48,9 +68,13 @@ def step(hparams, tokens, past=None):

def body(past, prev, output):
next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
logits = top_k_logits(logits, k=top_k)
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
if top_p > 0.0:
logits = top_p_logits(logits, p=top_p)
else:
logits = top_k_logits(logits, k=top_k)
samples = tf.multinomial(
logits, num_samples=1, output_dtype=tf.int32)
return [
tf.concat([past, next_outputs['presents']], axis=-2),
tf.squeeze(samples, axis=[1]),
@@ -69,7 +93,8 @@ def cond(*args):
context,
],
shape_invariants=[
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
tf.TensorShape(model.past_shape(
hparams=hparams, batch_size=batch_size)),
tf.TensorShape([batch_size]),
tf.TensorShape([batch_size, None]),
],
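
For reference, a minimal NumPy restatement of the nucleus (top-p) filtering added above, written for a single row of logits; this is an illustrative port, not part of the package:

```python
import numpy as np

def top_p_filter(logits, p):
    """Mask logits outside the top-p (nucleus) probability mass."""
    sorted_logits = np.sort(logits)[::-1]          # sort descending
    probs = np.exp(sorted_logits - sorted_logits.max())
    probs /= probs.sum()                           # softmax
    cum = np.cumsum(probs) - probs                 # exclusive cumulative sum
    kept = sorted_logits[cum < p]                  # tokens inside the nucleus
    threshold = kept.min()                         # smallest kept logit
    return np.where(logits < threshold, -1e10, logits)

# With p=0.9, only the most probable tokens whose cumulative probability
# stays below 0.9 keep their original logits; the rest are masked out.
```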
2 changes: 1 addition & 1 deletion setup.py
@@ -47,7 +47,7 @@
setup(
name='gpt_2_simple',
packages=['gpt_2_simple'], # this must be the same as the name above
version='0.4.2',
version='0.5',
description="Python package to easily retrain OpenAI's GPT-2 " \
"text-generating model on new texts.",
long_description=long_description,
