Skip to content

Instantly share code, notes, and snippets.

@pbloem
Created September 2, 2021 14:33
Show Gist options
  • Save pbloem/2c3af77626d6c80f62487c35a28e3e8c to your computer and use it in GitHub Desktop.
Save pbloem/2c3af77626d6c80f62487c35a28e3e8c to your computer and use it in GitHub Desktop.
Data loaders for DLVU assignment 3B (recurrent neural nets)
import wget, os, gzip, pickle, random, re, sys
# Download location of the pickled IMDb archive; '{}' is filled in with
# 'char' or 'word' by load_imdb.
IMDB_URL = 'http://dlvu.github.io/data/imdb.{}.pkl.gz'

# Local file name of the archive, using the same '{}' placeholder.
IMDB_FILE = 'imdb.{}.pkl.gz'

# Special vocabulary tokens.
PAD = '.pad'
START = '.start'
END = '.end'
UNK = '.unk'
def load_imdb(final=False, val=5000, seed=0, voc=None, char=False):
    """
    Load the IMDb data, downloading the archive first if it is not in the
    current working directory.

    :param final: If True, return the stored train/test split. If False,
        move ``val`` training instances to a validation set and return
        train/validation instead.
    :param val: Number of training instances to hold out for validation
        (only used when ``final`` is False).
    :param seed: Seed applied to the global ``random`` module before the
        validation indices are sampled.
    :param voc: If not None and smaller than the stored vocabulary, keep only
        the first ``voc`` vocabulary entries and replace every other token
        index by the index of the unknown token.
    :param char: If True, load the 'char' variant of the data, otherwise the
        'word' variant.
    :return: ``(x_train, y_train), (x_eval, y_eval), (i2w, w2i), 2`` where the
        second pair is the test set if ``final`` is True and the validation
        set otherwise. The trailing 2 is returned as a constant (presumably
        the number of classes).
    :raises KeyError: If ``voc`` is so small that the unknown token itself is
        cut from the vocabulary.
    :raises ValueError: If ``val`` exceeds the number of training instances
        (raised by ``random.sample``).
    """
    cst = 'char' if char else 'word'
    imdb_url = IMDB_URL.format(cst)
    imdb_file = IMDB_FILE.format(cst)

    if not os.path.exists(imdb_file):
        # Pin the destination explicitly so the download always lands in the
        # file that is opened below, whatever name wget would pick by itself.
        wget.download(imdb_url, out=imdb_file)

    # NOTE(security): unpickling can execute arbitrary code. The archive is
    # fetched over plain HTTP, so only use this with a source you trust.
    with gzip.open(imdb_file) as file:
        sequences, labels, i2w, w2i = pickle.load(file)

    if voc is not None and voc < len(i2w):
        # Truncate the vocabulary and map everything outside it to UNK.
        i2w = i2w[:voc]
        w2i = {w: i for i, w in enumerate(i2w)}
        unk = w2i[UNK]

        sequences = {
            key: [[s if s < voc else unk for s in seq] for seq in seqs]
            for key, seqs in sequences.items()
        }

    if final:
        return (sequences['train'], labels['train']), \
               (sequences['test'], labels['test']), \
               (i2w, w2i), 2

    # Make a validation split
    random.seed(seed)
    val_ind = set(random.sample(range(len(sequences['train'])), k=val))

    x_train, y_train = [], []
    x_val, y_val = [], []

    for i, (s, l) in enumerate(zip(sequences['train'], labels['train'])):
        if i in val_ind:
            x_val.append(s)
            y_val.append(l)
        else:
            x_train.append(s)
            y_train.append(l)

    return (x_train, y_train), \
           (x_val, y_val), \
           (i2w, w2i), 2
def gen_sentence(sent, g):
    """
    Expand a template string by rewriting grammar symbols until none remain.

    A symbol is an underscore followed by zero or more lowercase letters.
    The leftmost symbol is replaced by a randomly chosen expansion from
    ``g``; this repeats until the string contains no symbols.

    :param sent: The starting string, e.g. a single start symbol.
    :param g: Dictionary mapping each symbol to a list of possible expansions.
    :return: The fully expanded string.
    """
    pattern = re.compile('_[a-z]*')

    found = pattern.search(sent)
    while found is not None:
        start, end = found.span()
        expansion = random.choice(g[found.group()])
        sent = ''.join((sent[:start], expansion, sent[end:]))
        found = pattern.search(sent)

    return sent
def gen_dyck(p):
    """
    Sample a balanced string of round brackets.

    The string starts with one opening bracket. At every step an opening
    bracket is appended with probability ``p`` and a closing bracket
    otherwise, until every opened bracket has been closed.

    :param p: Probability of appending an opening bracket at each step.
    :return: The generated bracket string.
    """
    depth = 1  # number of brackets that are currently unclosed
    chars = ['(']

    while depth > 0:
        opening = random.random() < p
        chars.append('(' if opening else ')')
        depth += 1 if opening else -1

    return ''.join(chars)
def gen_ndfa(p):
    """
    Sample a string of the form 's' + word * k + 's'.

    One of the words 'abc!', 'uvw!' or 'klm!' is chosen once; it is then
    repeated k >= 0 times, where generation stops at each step with
    probability ``p``.

    :param p: Probability of stopping at each step.
    :return: The generated string.
    """
    word = random.choice(['abc!', 'uvw!', 'klm!'])

    repeats = 0
    while not random.random() < p:
        repeats += 1

    return 's' + word * repeats + 's'
def load_brackets(n=50_000, seed=0):
    """
    Load the 'dyck' toy data set at character level.

    :param n: Number of strings to generate.
    :param seed: Seed value, forwarded to ``load_toy``.
    :return: Whatever ``load_toy`` returns for ``name='dyck'``.
    """
    return load_toy(n=n, char=True, seed=seed, name='dyck')
def load_ndfa(n=50_000, seed=0):
    """
    Load the 'ndfa' toy data set at character level.

    :param n: Number of strings to generate.
    :param seed: Seed value, forwarded to ``load_toy``.
    :return: Whatever ``load_toy`` returns for ``name='ndfa'``.
    """
    return load_toy(n=n, char=True, seed=seed, name='ndfa')
def load_toy(n=50_000, char=True, seed=0, name='lang'):
    """
    Generate a synthetic data set of token-index sequences.

    :param n: Number of sentences to generate.
    :param char: If True, every character is a token; otherwise sentences are
        split on whitespace.
    :param seed: Seed applied to the global ``random`` module before
        generation, so equal seeds give equal data.
    :param name: Which generator to use: 'lang' (a small toy grammar),
        'dyck' (bracket strings from ``gen_dyck``) or 'ndfa' (strings from
        ``gen_ndfa``).
    :return: ``sequences, (i2t, t2i)``. ``sequences`` is a list of lists of
        token indices, ordered by the character length of the underlying
        sentence. ``i2t`` is the list of tokens (the four special tokens
        first, then the observed tokens in sorted order) and ``t2i`` is the
        inverse dictionary.
    :raises ValueError: If ``name`` is not one of the supported data sets.
    """
    # Honour the caller's seed; the default of 0 matches the value that was
    # previously hard-coded here.
    random.seed(seed)

    if name == 'lang':
        sent = '_s'
        toy = {
            '_s': ['_s _adv', '_np _vp', '_np _vp _prep _np', '_np _vp ( _prep _np )', '_np _vp _con _s' , '_np _vp ( _con _s )'],
            '_adv': ['briefly', 'quickly', 'impatiently'],
            '_np': ['a _noun', 'the _noun', 'a _adj _noun', 'the _adj _noun'],
            '_prep': ['on', 'with', 'to'],
            '_con' : ['while', 'but'],
            '_noun': ['mouse', 'bunny', 'cat', 'dog', 'man', 'woman', 'person'],
            '_vp': ['walked', 'walks', 'ran', 'runs', 'goes', 'went'],
            '_adj': ['short', 'quick', 'busy', 'nice', 'gorgeous']
        }
        sentences = [gen_sentence(sent, toy) for _ in range(n)]
    elif name == 'dyck':
        sentences = [gen_dyck(7./16.) for _ in range(n)]
    elif name == 'ndfa':
        sentences = [gen_ndfa(1./4.) for _ in range(n)]
    else:
        raise ValueError(
            "Unknown toy data set {!r}; expected 'lang', 'dyck' or 'ndfa'.".format(name))

    sentences.sort(key=len)

    # Tokenize once and reuse the result for both the vocabulary and the
    # index sequences.
    tokenized = [list(s) if char else s.split() for s in sentences]

    tokens = set()
    for tok in tokenized:
        tokens.update(tok)

    # Sort the tokens: iteration order of a set of strings changes between
    # interpreter runs (hash randomization), which would make the indices
    # irreproducible even with a fixed seed.
    i2t = [PAD, START, END, UNK] + sorted(tokens)
    t2i = {t: i for i, t in enumerate(i2t)}

    sequences = [[t2i[t] for t in tok] for tok in tokenized]

    return sequences, (i2t, t2i)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment