Skip to content

Instantly share code, notes, and snippets.

Avatar
🎯
Focusing

Prateek Joshi prateekjoshi565

🎯
Focusing
View GitHub Profile
View nlg_text_gen.py
# predict next token
# NOTE(review): snippet is truncated by the gist preview — the forward pass,
# sampling, and return logic after the .cuda() call are not visible here;
# the body's indentation was also lost in extraction.
def predict(net, tkn, h=None):
# tensor inputs
# look up the token's integer id and wrap it as a (1, 1) batch for the net
x = np.array([[token2int[tkn]]])
inputs = torch.from_numpy(x)
# push to GPU
# NOTE(review): hard-coded .cuda() — this will raise on a CPU-only machine
inputs = inputs.cuda()
View nlg_train.py
# NOTE(review): snippet is truncated by the gist preview — the epoch/batch
# training loop is not visible here; body indentation was lost in extraction.
# Params: net — the WordLSTM model; epochs/batch_size/lr — usual hyperparams;
# clip — presumably a gradient-clipping norm used in the unseen loop (confirm);
# print_every — presumably a logging interval in steps (confirm).
def train(net, epochs=10, batch_size=32, lr=0.001, clip=1, print_every=32):
# optimizer
opt = torch.optim.Adam(net.parameters(), lr=lr)
# loss
criterion = nn.CrossEntropyLoss()
# push model to GPU
# NOTE(review): hard-coded .cuda() — fails without a CUDA device
net.cuda()
View nlg_instantiate_model.py
# instantiate the model
net = WordLSTM()
# push the model to GPU (avoid it if you are not using the GPU)
net.cuda()
# print the layer summary for a quick sanity check
print(net)
View nlg_model_arch.py
# NOTE(review): class is truncated by the gist preview — the embedding, LSTM,
# dropout, and fully-connected layer definitions plus forward()/init_hidden()
# are not visible here; body indentation was lost in extraction.
class WordLSTM(nn.Module):
def __init__(self, n_hidden=256, n_layers=4, drop_prob=0.3, lr=0.001):
super().__init__()
# dropout probability (presumably applied between LSTM layers — confirm below)
self.drop_prob = drop_prob
# number of stacked LSTM layers
self.n_layers = n_layers
# hidden-state size per layer
self.n_hidden = n_hidden
# stored learning rate (not used in the visible lines)
self.lr = lr
View nlg_get_batch.py
def get_batches(arr_x, arr_y, batch_size):
    """Yield successive (x, y) mini-batches of `batch_size` rows.

    Any trailing partial batch (fewer than `batch_size` rows) is dropped,
    which is the usual convention for LSTM training batches.

    Bug fixed: the original used `range(batch_size, arr_x.shape[0], batch_size)`,
    whose exclusive stop silently dropped the FINAL batch — even a full one
    when shape[0] is an exact multiple of batch_size (e.g. 64 rows with
    batch_size 32 yielded only one batch).
    """
    prv = 0
    # stop at shape[0] + 1 so the last full batch boundary is included
    for n in range(batch_size, arr_x.shape[0] + 1, batch_size):
        x = arr_x[prv:n, :]
        y = arr_y[prv:n, :]
        prv = n
        yield x, y
View nlg_text_to_int.py
def get_integer_seq(seq):
    """Map a whitespace-delimited token string to its list of integer ids.

    Raises KeyError for any token absent from the module-level `token2int`
    vocabulary mapping.
    """
    tokens = seq.split()
    return list(map(token2int.__getitem__, tokens))
# Convert the input and target text sequences into parallel numpy arrays
# of integer token ids (one row per sequence).
x_int = np.array([get_integer_seq(seq) for seq in x])
y_int = np.array([get_integer_seq(seq) for seq in y])
View nlg_tok_int_mapping.py
# Build the integer<->token vocabulary mappings over the unique tokens of all
# movie plots. Set iteration order is arbitrary, but the two mappings are
# derived from the same pass so they stay mutually consistent.
vocab = set(" ".join(movie_plots).split())
# create integer-to-token mapping — enumerate replaces the original
# hand-rolled `cnt` counter loop
int2token = dict(enumerate(vocab))
# create token-to-integer mapping (inverse of int2token)
token2int = {t: i for i, t in int2token.items()}
View nlg_ip_op.py
# Create inputs and targets (x and y): for each token sequence, the input is
# every token except the last, and the target is every token except the first
# (i.e. the input shifted left by one — next-token prediction pairs).
x = []
y = []
for s in seqs:
    # split once per sequence instead of twice, as the original did
    tokens = s.split()
    x.append(" ".join(tokens[:-1]))
    y.append(" ".join(tokens[1:]))
View nlg_get_seqs.py
# one list of fixed-length token sequences per movie plot
seqs = [create_seq(i) for i in movie_plots]
# merge list-of-lists into a single list — a flat comprehension is linear,
# whereas the original `sum(seqs, [])` rebuilds the accumulator list on every
# addition and is quadratic in the total number of sequences
seqs = [s for plot_seqs in seqs for s in plot_seqs]
# count of sequences (bare expression: only displayed in a REPL/notebook)
len(seqs)
View nlg_seq_prep_func.py
# create sequences of length 5 tokens
# NOTE(review): snippet is truncated by the gist preview — the append of `seq`
# into `sequences` and the return statement are not visible here; the body's
# indentation was also lost in extraction.
def create_seq(text, seq_len = 5):
sequences = []
# if the number of tokens in 'text' is greater than 5
if len(text.split()) > seq_len:
for i in range(seq_len, len(text.split())):
# select sequence of tokens
# each window spans seq_len + 1 tokens: seq_len inputs plus the next token,
# so downstream code can split it into an input/target pair
seq = text.split()[i-seq_len:i+1]
You can’t perform that action at this time.