Skip to content

Instantly share code, notes, and snippets.

@jvns
Created November 19, 2020 21:48
Show Gist options
  • Select an option

  • Save jvns/ce24d85c9d645c32b0d552b69359979f to your computer and use it in GitHub Desktop.

Select an option

Save jvns/ce24d85c9d645c32b0d552b69359979f to your computer and use it in GitHub Desktop.
import tsensor
from fastai.text import *
import unidecode
import string
# Character inventory: every character in string.printable.
all_characters = string.printable
n_characters = len(all_characters)

# Read the corpus inside a context manager so the file handle is closed
# as soon as the text is in memory, then pass the text through unidecode
# (which transliterates non-ASCII characters to ASCII).
with open('t8.shakespeare.txt') as corpus_file:
    file = unidecode.unidecode(corpus_file.read())

# Drop the first 20000 characters of the text.
# NOTE(review): presumably this skips the e-text header/licence preamble -- confirm.
file = file[20000:]
class CharTokenizer(BaseTokenizer):
    """Tokenizer that emits one token per character of the input text."""

    def __init__(self, lang: str = 'no_lang'):
        """Needed to initialize BaseTokenizer correctly."""
        super().__init__(lang=lang)

    def tokenizer(self, t: str) -> List[str]:
        """Turn each character of `t` into a token, replacing spaces with '_'."""
        return list(t.replace(' ', '_'))
# Processor that tokenizes text with CharTokenizer (one token per character)
# and does not prepend a beginning-of-stream token.
char_tokenize_processor = TokenizeProcessor(tokenizer=Tokenizer(tok_func=CharTokenizer), include_bos=False)

n_train = 3000000
n_test = 100000

# Training text: the first n_train characters with every run of whitespace
# (including newlines) collapsed to a single space.
spaces = ' '.join(file[:n_train].split())

# Each item handed to TextList is a single character.
# NOTE(review): the validation slice below is NOT whitespace-normalised the way
# the training text is, so it can contain newlines and repeated spaces that the
# training text never has -- confirm this asymmetry is intended.
train = TextList((x for x in spaces), processor=[char_tokenize_processor, NumericalizeProcessor(max_vocab=30000)])
test = TextList((x for x in file[n_train:n_train + n_test]), processor=[char_tokenize_processor, NumericalizeProcessor(max_vocab=30000)])

src = ItemLists(train=train, valid=test, path='test').label_for_lm()
data = src.databunch(bs=77)

v = data.valid_ds.vocab
nv = len(v.itos)  # vocabulary size
print(f"nv: {nv}")
nh = 87  # hidden-state width

# NOTE(review): `bs` is not used by any visible line and differs from the
# bs=77 passed to databunch() above -- confirm which value is intended.
bs = 78

# Grab only the first training batch. The previous list comprehension iterated
# the entire training loader (keeping every batch alive) just to index [0].
x, y = next(iter(data.train_dl))
class RNN(nn.Module):
    """Single-layer tanh network that carries a detached hidden state between calls.

    Each call computes ``hidden = tanh(h2h(previous_hidden) + i2h(one_hot(input)))``
    followed by ``h2o(hidden)``. There is no loop over sequence positions
    inside ``forward``: the state carried over from the previous call is
    combined with all positions of the current batch in one broadcasted
    addition.

    Args:
        n_vocab: Number of token classes (one-hot width and output width).
            Defaults to the module-level ``nv`` when omitted.
        n_hidden: Width of the hidden state. Defaults to the module-level
            ``nh`` when omitted.
    """

    def __init__(self, n_vocab=None, n_hidden=None):
        super().__init__()
        # Fall back to the script's globals so a bare ``RNN()`` keeps working.
        self.n_vocab = nv if n_vocab is None else n_vocab
        self.n_hidden = nh if n_hidden is None else n_hidden
        self.i2h = nn.Linear(self.n_vocab, self.n_hidden)   # Wxh
        self.h2h = nn.Linear(self.n_hidden, self.n_hidden)  # Whh
        self.h2o = nn.Linear(self.n_hidden, self.n_vocab)   # Why
        # Plain tensor attribute (not a parameter or buffer), so Module.to()/
        # .cuda() will not move it; forward() moves it to the weights' device
        # instead of hard-coding CUDA here.
        self.hidden = torch.zeros(1, self.n_hidden)

    def forward(self, input):
        """Return unnormalised scores of shape ``input.shape + (n_vocab,)``.

        Args:
            input: Integer tensor of token ids in ``[0, n_vocab)``.

        Side effect: ``self.hidden`` is replaced by the detached activation
        of this call, so it takes on that activation's shape. A later call
        whose input shape is not broadcast-compatible with it will fail at
        the addition below.
        """
        weight = self.i2h.weight
        one_hot = torch.nn.functional.one_hot(input, num_classes=self.n_vocab)
        # Cast straight to the layer's device/dtype (no CPU round trip, no
        # hard-coded CUDA) so the model runs wherever its weights live.
        x = self.i2h(one_hot.to(device=weight.device, dtype=weight.dtype))
        y = self.h2h(self.hidden.to(device=weight.device, dtype=weight.dtype))
        hidden = torch.tanh(y + x)
        # Detach so gradients never flow back into earlier batches.
        self.hidden = hidden.detach()
        z = self.h2o(hidden)
        return z
def my_loss(input, target):
    """Return the mean cross-entropy between per-position scores and targets.

    Args:
        input: Float tensor of unnormalised scores whose last dimension is
            the number of classes, e.g. ``(batch, seq, n_classes)``.
        target: Integer tensor of class indices with one entry per position
            of ``input`` (any shape; it is flattened).

    Returns:
        Scalar tensor: cross-entropy averaged over every position.
    """
    target = target.flatten()
    # reshape (rather than view) also accepts non-contiguous scores, and
    # size(-1) reads the class dimension regardless of how many leading
    # dimensions there are.
    input = input.reshape(-1, input.size(-1))
    return F.cross_entropy(input, target)
# Train the model; my_loss is passed as a reported metric.
learn = Learner(data, RNN(), metrics=my_loss)
learn.fit_one_cycle(60, .1)

# Sample one character per position from the model's predictions for the
# first sequence of batch `x`.
temperature = 1
# Sampling needs no autograd graph, so run the forward pass without one.
with torch.no_grad():
    logits = learn.model(x)[0]
# Normalise over the vocabulary (last) dimension explicitly; calling softmax
# without `dim` is deprecated and emits a warning.
prediction_vector = F.softmax(logits / temperature, dim=-1)
# Bare expression: its value is only shown in an interactive/notebook session.
v.textify(torch.multinomial(prediction_vector, 1).flatten(), sep='')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment