Last active: February 13, 2019 22:07
-
-
Save alvations/f4a84062fb6464c9d111b6acea552889 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from argparse import Namespace | |
from collections import Counter | |
import json | |
import re | |
import string | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
import torch.optim as optim | |
from torch.utils.data import Dataset, DataLoader | |
from tqdm import tqdm_notebook | |
class Vocabulary(object):
    """Class to process text and extract vocabulary for mapping.

    Keeps a bidirectional mapping between tokens and integer indices.
    A token that is not yet present is assigned the index
    ``len(vocabulary)`` when it is added.
    """

    def __init__(self, token_to_idx=None):
        """
        Args:
            token_to_idx (dict): a pre-existing map of tokens to indices.
                The mapping is copied, so adding tokens later does not
                mutate the caller's dictionary.
        """
        # Copy instead of aliasing: otherwise add_token() would silently
        # modify a dict owned by the caller (e.g. deserialized contents).
        self._token_to_idx = dict(token_to_idx) if token_to_idx is not None else {}
        self._idx_to_token = {idx: token
                              for token, idx in self._token_to_idx.items()}

    def to_serializable(self):
        """Return a dictionary that can be serialized.

        The returned dict can be passed back to ``from_serializable``.
        """
        return {'token_to_idx': self._token_to_idx}

    @classmethod
    def from_serializable(cls, contents):
        """Instantiate the Vocabulary from a serialized dictionary.

        Args:
            contents (dict): keyword arguments for the constructor, as
                produced by ``to_serializable``
        """
        return cls(**contents)

    def add_token(self, token):
        """Update mapping dicts based on the token.

        Args:
            token (str): the item to add into the Vocabulary
        Returns:
            index (int): the integer corresponding to the token; the
                existing index if the token was already present
        """
        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
        return index

    def add_many(self, tokens):
        """Add a list of tokens into the Vocabulary.

        Args:
            tokens (list): a list of string tokens
        Returns:
            indices (list): a list of indices corresponding to the tokens
        """
        # add_token is a method and must be called, not subscripted.
        return [self.add_token(token) for token in tokens]

    def lookup_token(self, token):
        """Retrieve the index associated with the token.

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        Raises:
            KeyError: if the token is not in the Vocabulary
        """
        return self._token_to_idx[token]

    def lookup_index(self, index):
        """Return the token associated with the index.

        Args:
            index (int): the index to look up
        Returns:
            token (str): the token corresponding to the index
        Raises:
            KeyError: if the index is not in the Vocabulary
        """
        if index not in self._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._idx_to_token[index]

    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self._token_to_idx)
class SequenceVocabulary(Vocabulary):
    """Vocabulary that reserves special tokens for sequence modelling.

    The constructor registers four special tokens through ``add_token``,
    in this order: mask, UNK, begin-of-sequence, end-of-sequence.
    Their indices are stored as ``mask_index``, ``unk_index``,
    ``begin_seq_index`` and ``end_seq_index``. If ``token_to_idx``
    already contains a special token, its existing index is reused.
    """

    def __init__(self, token_to_idx=None, unk_token="<unk>",
                 mask_token="<pad>", begin_seq_token="<s>",
                 end_seq_token="</s>"):
        super().__init__(token_to_idx)

        self._mask_token = mask_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        # The registration order decides the special indices of a fresh
        # vocabulary, so it must stay mask -> unk -> begin -> end.
        self.mask_index = self.add_token(mask_token)
        self.unk_index = self.add_token(unk_token)
        self.begin_seq_index = self.add_token(begin_seq_token)
        self.end_seq_index = self.add_token(end_seq_token)

    def to_serializable(self):
        """Return constructor keyword arguments, including the special tokens.

        The result can be fed back through ``from_serializable``.
        """
        return {
            **super().to_serializable(),
            'unk_token': self._unk_token,
            'mask_token': self._mask_token,
            'begin_seq_token': self._begin_seq_token,
            'end_seq_token': self._end_seq_token,
        }

    def lookup_token(self, token):
        """Map a token to its index, falling back to the UNK index.

        Args:
            token (str): the token to look up
        Returns:
            index (int): the token's index, or ``unk_index`` when the token
                is unknown and ``unk_index`` is non-negative
        Raises:
            KeyError: if the token is unknown and ``unk_index`` is negative
                (i.e. the UNK fallback has been disabled)
        """
        if self.unk_index < 0:
            return self._token_to_idx[token]
        return self._token_to_idx.get(token, self.unk_index)
class SurnameVectorizer(object):
    """The Vectorizer which coordinates the Vocabularies and puts them to use."""

    def __init__(self, char_vocab, nationality_vocab):
        """
        Args:
            char_vocab (SequenceVocabulary): maps characters to integers
            nationality_vocab (Vocabulary): maps nationalities to integers
        """
        self.char_vocab = char_vocab
        self.nationality_vocab = nationality_vocab

    def vectorize(self, surname, vector_length=-1):
        """Vectorize a surname into a vector of observations and targets.

        The surname is wrapped in begin/end-of-sequence indices. The result
        is split into two shifted vectors, ``indices[:-1]`` and
        ``indices[1:]``. At each timestep the first vector is the
        observation and the second is the target. Positions past the
        sequence are filled with ``char_vocab.mask_index``.

        Args:
            surname (str): the surname to be vectorized
            vector_length (int): forces the length of the index vectors;
                a negative value (the default) means "exactly as long as
                needed", i.e. ``len(surname) + 1``
        Returns:
            a tuple: (from_vector, to_vector)
                from_vector (numpy.ndarray): the int64 observation vector
                to_vector (numpy.ndarray): the int64 target prediction vector
        Raises:
            ValueError: if a non-negative ``vector_length`` is shorter than
                ``len(surname) + 1``
        """
        indices = [self.char_vocab.begin_seq_index]
        indices.extend(self.char_vocab.lookup_token(token) for token in surname)
        indices.append(self.char_vocab.end_seq_index)

        # Both outputs are shifted copies of `indices`, one element shorter.
        seq_length = len(indices) - 1
        if vector_length < 0:
            vector_length = seq_length
        elif vector_length < seq_length:
            # Fail with a clear message instead of an opaque numpy
            # broadcasting error.
            raise ValueError(
                "vector_length=%d is too short for surname %r (needs at least %d)"
                % (vector_length, surname, seq_length))

        mask_index = self.char_vocab.mask_index

        from_vector = np.full(vector_length, mask_index, dtype=np.int64)
        from_vector[:seq_length] = indices[:-1]

        to_vector = np.full(vector_length, mask_index, dtype=np.int64)
        to_vector[:seq_length] = indices[1:]

        return from_vector, to_vector

    @classmethod
    def from_dataframe(cls, surname_df):
        """Instantiate the vectorizer from the dataset dataframe.

        Args:
            surname_df (pandas.DataFrame): the surname dataset; must have
                ``surname`` and ``nationality`` columns
        Returns:
            an instance of the SurnameVectorizer
        """
        char_vocab = SequenceVocabulary()
        nationality_vocab = Vocabulary()

        # Walk the two columns directly; iterrows() builds a Series per row,
        # which is slow and may coerce dtypes.
        for surname, nationality in zip(surname_df.surname,
                                        surname_df.nationality):
            for char in surname:
                char_vocab.add_token(char)
            nationality_vocab.add_token(nationality)

        return cls(char_vocab, nationality_vocab)

    @classmethod
    def from_serializable(cls, contents):
        """Instantiate the vectorizer from saved contents.

        Args:
            contents (dict): a dict holding two vocabularies for this vectorizer.
                This dictionary is created using `vectorizer.to_serializable()`
        Returns:
            an instance of SurnameVectorizer
        """
        char_vocab = SequenceVocabulary.from_serializable(contents['char_vocab'])
        nat_vocab = Vocabulary.from_serializable(contents['nationality_vocab'])
        return cls(char_vocab=char_vocab, nationality_vocab=nat_vocab)

    def to_serializable(self):
        """Return the serializable contents of both vocabularies."""
        return {'char_vocab': self.char_vocab.to_serializable(),
                'nationality_vocab': self.nationality_vocab.to_serializable()}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.