Predict functions for CLTG (character-level text generation)
import numpy as np
import torch
import torch.nn as nn


def sample_from_probs(probs, top_n=10):
    """
    Truncated weighted random choice: keep only the top_n most likely
    characters and sample the next index from them.
    """
    _, indices = torch.sort(probs)
    # Set the probabilities of everything outside the top_n to 0
    probs[indices.data[:-top_n]] = 0
    # Sample the index of the predicted next char
    sampled_index = torch.multinomial(probs, 1)
    return sampled_index


def predict_probs(model, hidden, character, vocab, device):
    # One-hot encode the input character so it fits the model
    character = np.array([[vocab[c] for c in character]])
    character = one_hot_encode(character, len(vocab))
    character = torch.from_numpy(character)
    character = character.to(device)

    with torch.no_grad():
        # Forward pass through the model
        out, hidden = model(character, hidden)

    # Softmax the last output to get the probability distribution over the vocabulary
    prob = nn.functional.softmax(out[-1], dim=0).data

    return prob, hidden


def predict_fn(input_data, model):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model.char2int is None:
        raise Exception('Model has not been loaded properly, no char2int dictionary.')

    # Extract the desired output length and the seed text
    out_len, start = input_data
    out_len = int(out_len)

    model.eval()  # eval mode

    start = start.lower()
    # Clean the seed text the same way the training text was cleaned
    start = clean_text(start, True)

    # First off, run through the starting characters
    chars = [ch for ch in start]
    size = out_len - len(chars)

    # Init the hidden state
    state = model.init_state(device, 1)

    # Warm up the hidden state by predicting on the seed string
    for ch in chars:
        probs, state = predict_probs(model, state, ch, model.char2int, device)
        next_index = sample_from_probs(probs, 5)

    # Append the last predicted char to the output
    chars.append(model.int2char[next_index.item()])

    # Now pass in the previous character and get a new one
    for ii in range(size - 1):
        probs, state = predict_probs(model, state, chars[-1], model.char2int, device)
        next_index = sample_from_probs(probs, 5)
        # Append to the generated sequence
        chars.append(model.int2char[next_index.item()])

    # Join all the chars into the generated text
    return ''.join(chars)
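The code above relies on two helpers defined elsewhere in the training code, one_hot_encode and clean_text, plus a trained model exposing char2int, int2char and init_state. As a rough sketch only (the real helpers may differ), a one-hot encoder compatible with the call in predict_probs, and an example call to predict_fn, could look like this:

import numpy as np


def one_hot_encode(arr, n_labels):
    # Hypothetical helper: turn an integer index array of shape (batch, seq_len)
    # into a float32 one-hot array of shape (batch, seq_len, n_labels),
    # which is the layout predict_probs feeds to the model.
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    one_hot[np.arange(arr.size), arr.flatten()] = 1.0
    return one_hot.reshape((*arr.shape, n_labels))


# Example usage, assuming `model` is the trained character-level RNN
# (with char2int, int2char and init_state) and clean_text is available:
#   generated = predict_fn((200, "the meaning of"), model)
#   print(generated)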