@michael-iuzzolino
Created April 27, 2019 21:42
Hello, RNN Data Handler
```python
import numpy as np
import torch


class DataHandler:
    def __init__(self, string):
        self.string = string
        # Sorted vocabulary of the unique characters in the string.
        characters = np.sort(list(set(string)))
        self.num_characters = len(characters)
        self.char_to_idx = {ch: i for i, ch in enumerate(characters)}
        self.idx_to_char = {i: ch for ch, i in self.char_to_idx.items()}
        self._process()

    def _process(self):
        # One one-hot row per character: shape (len(string), num_characters).
        data_torch = torch.stack([self.make_onehot(ele) for ele in self.string])
        # Inputs: every character but the last. Targets: index of the next character.
        self.X = data_torch[:-1].float()
        self.y = torch.argmax(data_torch[1:], dim=1).long()

    def make_onehot(self, char):
        return torch.eye(self.num_characters)[self.char_to_idx[char]].float()
```
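As a sketch of what the handler produces, assume the string "hello": the inputs are one-hot rows for "hell" and the targets are the indices of the next characters, "ello". The version below builds the same tensors without a per-character loop by using `torch.nn.functional.one_hot`. This is a swapped-in alternative to indexing `torch.eye`, not the gist's own code. The `nn.RNN` plus `nn.Linear` model with a hidden size of 8 is a hypothetical illustration of how `X` and `y` might be consumed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

text = "hello"
characters = sorted(set(text))                       # ['e', 'h', 'l', 'o']
char_to_idx = {ch: i for i, ch in enumerate(characters)}

# Encode the whole string as integer indices, then one-hot encode in one call.
indices = torch.tensor([char_to_idx[ch] for ch in text])   # h, e, l, l, o -> 1, 0, 2, 2, 3
onehots = F.one_hot(indices, num_classes=len(characters)).float()

X = onehots[:-1]      # one-hot inputs for "hell", shape (4, 4)
y = indices[1:]       # next-character targets for "ello": 0, 2, 2, 3

# Hypothetical model (not from the gist): one-layer RNN plus a linear read-out, batch size 1.
rnn = nn.RNN(input_size=len(characters), hidden_size=8, batch_first=True)
head = nn.Linear(8, len(characters))
out, _ = rnn(X.unsqueeze(0))              # out: (1, 4, 8), one hidden state per time step
logits = head(out.squeeze(0))             # (4, 4): a score for each character at each step
loss = F.cross_entropy(logits, y)         # scalar next-character training loss
```

Taking `indices[1:]` directly gives the same targets as the `argmax` over one-hot rows, because each row's single 1 sits at that character's index.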