diff --git a/.gitignore b/.gitignore index a3c82a25b7..545dce4a0e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,9 @@ __pycache__/ *.pt *.pyc input.txt +*.venv +*.code-workspace +out/ +out-*/ +wandb/* +data/*/samples/* diff --git a/data/shakespeare_char/prepare.py b/data/shakespeare_char/prepare.py index 9fd1621d55..8eb5f8362d 100644 --- a/data/shakespeare_char/prepare.py +++ b/data/shakespeare_char/prepare.py @@ -5,7 +5,7 @@ encoder and decoder and some other related info. """ import os -import pickle +import dill import requests import numpy as np @@ -54,11 +54,11 @@ def decode(l): # save the meta information as well, to help us encode/decode later meta = { 'vocab_size': vocab_size, - 'itos': itos, - 'stoi': stoi, + 'encode': encode, + 'decode': decode, } with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f: - pickle.dump(meta, f) + dill.dump(meta, f, recurse=True) # length of dataset in characters: 1115394 # all the unique characters: diff --git a/sample.py b/sample.py index d25d6e0861..50b49de43c 100644 --- a/sample.py +++ b/sample.py @@ -2,7 +2,7 @@ Sample from a trained model """ import os -import pickle +import dill from contextlib import nullcontext import torch import tiktoken @@ -61,11 +61,8 @@ if load_meta: print(f"Loading meta from {meta_path}...") with open(meta_path, 'rb') as f: - meta = pickle.load(f) - # TODO want to make this more general to arbitrary encoder/decoder schemes - stoi, itos = meta['stoi'], meta['itos'] - encode = lambda s: [stoi[c] for c in s] - decode = lambda l: ''.join([itos[i] for i in l]) + meta = dill.load(f) + encode, decode = meta['encode'], meta['decode'] else: # ok let's assume gpt-2 encodings by default print("No meta.pkl found, assuming GPT-2 encodings...") diff --git a/train.py b/train.py index a482ab7f4e..cf3454b575 100644 --- a/train.py +++ b/train.py @@ -19,7 +19,7 @@ import os import time import math -import pickle +import dill from contextlib import nullcontext import numpy as np @@ -136,7 +136,7 @@ def get_batch(split): meta_vocab_size = None if os.path.exists(meta_path): with open(meta_path, 'rb') as f: - meta = pickle.load(f) + meta = dill.load(f) meta_vocab_size = meta['vocab_size'] print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")