Skip to content

Instantly share code, notes, and snippets.

@edumunozsala
Last active September 25, 2020 18:18
Show Gist options
  • Save edumunozsala/3d9e9e55455cafea3a6580e6106d6cde to your computer and use it in GitHub Desktop.
Predict functions for CLTG
def sample_from_probs(probs, top_n=10):
    """Truncated weighted random choice over a probability vector.

    Args:
        probs: 1-D tensor of (unnormalized) probabilities over the vocabulary.
        top_n: number of highest-probability entries kept for sampling;
            everything else is zeroed out before drawing.

    Returns:
        A 1-element LongTensor holding the sampled vocabulary index.
    """
    # Fix: work on a copy — the original zeroed entries of the caller's
    # tensor in place, silently corrupting it for any later use.
    probs = probs.clone()
    # Ascending sort: all indices except the last `top_n` are the low ones.
    _, indices = torch.sort(probs)
    # Zero the probability mass outside the top_n candidates.
    probs[indices[:-top_n]] = 0
    # Draw one index; torch.multinomial renormalizes the weights itself.
    sampled_index = torch.multinomial(probs, 1)
    return sampled_index
def predict_probs(model, hidden, character, vocab, device):
    """Feed character(s) through the model; return next-char probabilities.

    Args:
        model: the trained character-level network.
        hidden: current hidden state, threaded through and returned updated.
        character: iterable of chars to encode (typically a single char).
        vocab: mapping char -> integer index used for encoding.
        device: torch device to move the encoded input onto.

    Returns:
        (probs, hidden): softmax distribution over the vocabulary for the
        last time step, plus the updated hidden state.
    """
    # Index-encode the chars, then one-hot them to the model's input shape.
    encoded = np.array([[vocab[ch] for ch in character]])
    inputs = torch.from_numpy(one_hot_encode(encoded, len(vocab))).to(device)

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        output, hidden = model(inputs, hidden)

    # Softmax over the vocabulary for the final time step's output.
    probs = nn.functional.softmax(output[-1], dim=0).data
    return probs, hidden
def predict_fn(input_data, model):
    """Generate text of a requested total length from a seed string.

    Args:
        input_data: (out_len, start) pair — total output length and seed text.
        model: trained char-RNN exposing char2int/int2char, init_state and
            the usual nn.Module interface.

    Returns:
        The generated string of length `out_len`, seed included.

    Raises:
        Exception: if the model's vocabulary (char2int) was never loaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model.char2int is None:
        raise Exception('Model has not been loaded properly, no word_dict.')

    # Extract the input data and the desired length
    out_len, start = input_data
    out_len = int(out_len)

    model.eval()  # eval mode
    # Clean the seed exactly as the training text was cleaned.
    start = start.lower()
    start = clean_text(start, True)

    # NOTE(review): assumes a non-empty seed after cleaning — an empty seed
    # would leave `probs` unbound below, same as the original code.
    chars = [ch for ch in start]
    size = out_len - len(chars)

    # Init the hidden state for a batch of one.
    state = model.init_state(device, 1)

    # Warm up the state on the seed; only the distribution after the last
    # seed char matters, so sampling is hoisted out of the loop.
    for ch in chars:
        probs, state = predict_probs(model, state, ch, model.char2int, device)
    next_index = sample_from_probs(probs, 5)
    # Fix: convert the sampled 1-element tensor to a plain int — a tensor is
    # not a valid key/index into the int2char mapping.
    chars.append(model.int2char[int(next_index.item())])

    # Generate the remaining chars, feeding each prediction back in.
    for _ in range(size - 1):
        probs, state = predict_probs(model, state, chars[-1], model.char2int, device)
        next_index = sample_from_probs(probs, 5)
        chars.append(model.int2char[int(next_index.item())])

    return ''.join(chars)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment