Skip to content

Instantly share code, notes, and snippets.

@titipata
Last active July 26, 2018 02:03
Show Gist options
  • Save titipata/2729b65e05089666d48252f13c1b0788 to your computer and use it in GitHub Desktop.
Save titipata/2729b65e05089666d48252f13c1b0788 to your computer and use it in GitHub Desktop.
import os
import json
import torch
import deepcut
import numpy as np

# Load the trained model (forced onto the CPU) and the vocabulary corpus once,
# at import time, so predict_next_lyrics can reuse them on every call.
device = torch.device("cpu")
with open('./thai-song-model.pt', 'rb') as f:
    # map_location returns each storage unchanged, i.e. keeps it on the CPU,
    # so a checkpoint saved from a GPU still loads on a CPU-only machine.
    # NOTE(review): torch.load unpickles the whole model object; only load
    # trusted checkpoints.
    model = torch.load(f, map_location=lambda storage, loc: storage)
model = model.to(device)
model.eval()  # inference mode (e.g. disables dropout)

with open('./corpus_lyrics.json', mode='rb') as f:
    corpus = json.load(f)
# JSON object keys are always strings; turn them back into integer token ids.
corpus['dictionary_reverse'] = {
    int(token_id): word
    for token_id, word in corpus['dictionary_reverse'].items()
}
ntokens = len(corpus['dictionary'])

def predict_next_lyrics(input_text, num_word=150, temperature=0.8, words_per_line=20):
    """Continue ``input_text`` with lyrics sampled from the loaded language model.

    ``input_text`` is tokenized with ``deepcut`` and each token is mapped to
    its id in ``corpus['dictionary']`` (unknown words become id 0). All seed
    tokens except the last are fed through ``model`` to build up the hidden
    state. The generation loop then repeats three steps:

    1. feed the current token;
    2. sample the next token from the model output;
    3. use the sampled token as the input for the next step.

    Args:
        input_text: Seed text to continue.
        num_word: Number of tokens to sample (passed through ``int()``).
        temperature: Sampling temperature. The model output is divided by it
            before the softmax, so values below 1 sharpen the distribution
            and values above 1 flatten it. Must be positive.
        words_per_line: A newline follows every ``words_per_line``-th sampled
            token; every other sampled token is followed by a space.

    Returns:
        ``input_text`` immediately followed by the generated text.

    Raises:
        ValueError: If ``temperature`` or ``words_per_line`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive, got %r" % (temperature,))
    if words_per_line <= 0:
        raise ValueError("words_per_line must be positive, got %r" % (words_per_line,))

    input_ids = [corpus['dictionary'].get(word, 0) for word in deepcut.tokenize(input_text)]
    if not input_ids:
        # NOTE(review): nothing to condition on; fall back to id 0 (the same
        # id used for unknown words) so there is a first token to feed.
        input_ids = [0]

    pieces = []
    with torch.no_grad():  # inference only: the seed pass must not build an autograd graph either
        hidden = model.init_hidden(1)
        # (1, 1) tensor holding the current token id, reused via fill_().
        # An explicit int64 dtype is used because index lookups (presumably
        # an embedding layer in the model) require long indices.
        token = torch.zeros((1, 1), dtype=torch.long, device=device)
        for input_id in input_ids[:-1]:
            token.fill_(input_id)
            _, hidden = model(token, hidden)
        # The last seed token is fed as the first step of the generation loop,
        # so no token is ever fed twice.
        token.fill_(input_ids[-1])
        for i in range(int(num_word)):
            output, hidden = model(token, hidden)
            # multinomial normalises its weights, so softmax(x / T) samples from
            # the same distribution as exp(x / T) but cannot overflow to inf.
            word_weights = torch.softmax(output.squeeze().div(temperature), dim=-1).cpu()
            word_idx = int(torch.multinomial(word_weights, 1)[0])
            token.fill_(word_idx)
            pieces.append(corpus['dictionary_reverse'].get(word_idx, ''))
            pieces.append('\n' if i % words_per_line == words_per_line - 1 else ' ')
    return input_text + ''.join(pieces)

if __name__ == '__main__':
    # Generate a continuation of a short Thai seed phrase and show it;
    # previously the generated lyrics were computed and discarded.
    print(predict_next_lyrics('ยังคง'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment