import os
from argparse import Namespace
from collections import Counter
import json
import re
import string
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook


class Vocabulary(object):
    """Class to process text and extract vocabulary for mapping"""

    def __init__(self, token_to_idx=None):
        """
        Args:
            token_to_idx (dict): a pre-existing map of tokens to indices
        """
        if token_to_idx is None:
            token_to_idx = {}
        self._token_to_idx = token_to_idx
        self._idx_to_token = {idx: token
                              for token, idx in self._token_to_idx.items()}

    def to_serializable(self):
        """ returns a dictionary that can be serialized """
        return {'token_to_idx': self._token_to_idx}

    @classmethod
    def from_serializable(cls, contents):
        """ instantiates the Vocabulary from a serialized dictionary """
        return cls(**contents)

    def add_token(self, token):
        """Update mapping dicts based on the token.

        Args:
            token (str): the item to add into the Vocabulary
        Returns:
            index (int): the integer corresponding to the token
        """
        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
        return index

    def add_many(self, tokens):
        """Add a list of tokens into the Vocabulary

        Args:
            tokens (list): a list of string tokens
        Returns:
            indices (list): a list of indices corresponding to the tokens
        """
        return [self.add_token(token) for token in tokens]

    def lookup_token(self, token):
        """Retrieve the index associated with the token

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        """
        return self._token_to_idx[token]

    def lookup_index(self, index):
        """Return the token associated with the index

        Args:
            index (int): the index to look up
        Returns:
            token (str): the token corresponding to the index
        Raises:
            KeyError: if the index is not in the Vocabulary
        """
        if index not in self._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._idx_to_token[index]

    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self._token_to_idx)
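
# Usage sketch (added for illustration; not part of the original gist):
# the token<->index round trip and serialization behave as follows.
#
#     vocab = Vocabulary()
#     idx = vocab.add_token("smith")            # -> 0 (first token added)
#     assert vocab.lookup_token("smith") == idx
#     assert vocab.lookup_index(idx) == "smith"
#     contents = vocab.to_serializable()        # {'token_to_idx': {'smith': 0}}
#     restored = Vocabulary.from_serializable(contents)
#     assert restored.lookup_token("smith") == 0 and len(restored) == 1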


class SequenceVocabulary(Vocabulary):
    def __init__(self, token_to_idx=None, unk_token="<unk>",
                 mask_token="<pad>", begin_seq_token="<s>",
                 end_seq_token="</s>"):
        super(SequenceVocabulary, self).__init__(token_to_idx)

        self._mask_token = mask_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        self.mask_index = self.add_token(self._mask_token)
        self.unk_index = self.add_token(self._unk_token)
        self.begin_seq_index = self.add_token(self._begin_seq_token)
        self.end_seq_index = self.add_token(self._end_seq_token)

    def to_serializable(self):
        contents = super(SequenceVocabulary, self).to_serializable()
        contents.update({'unk_token': self._unk_token,
                         'mask_token': self._mask_token,
                         'begin_seq_token': self._begin_seq_token,
                         'end_seq_token': self._end_seq_token})
        return contents

    def lookup_token(self, token):
        """Retrieve the index associated with the token
        or the UNK index if the token isn't present.

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        Notes:
            `unk_index` needs to be >=0 (having been added into the Vocabulary)
            for the UNK functionality
        """
        if self.unk_index >= 0:
            return self._token_to_idx.get(token, self.unk_index)
        else:
            return self._token_to_idx[token]
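
# Usage sketch (added for illustration; not part of the original gist):
# a fresh SequenceVocabulary reserves the first four indices for its
# special tokens, and lookup_token falls back to <unk> for unseen tokens.
#
#     seq_vocab = SequenceVocabulary()
#     assert seq_vocab.mask_index == 0 and seq_vocab.unk_index == 1
#     assert seq_vocab.begin_seq_index == 2 and seq_vocab.end_seq_index == 3
#     seq_vocab.add_token("a")                             # -> 4
#     assert seq_vocab.lookup_token("a") == 4
#     assert seq_vocab.lookup_token("z") == seq_vocab.unk_index  # OOV -> <unk>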


class SurnameVectorizer(object):
    """The Vectorizer which coordinates the Vocabularies and puts them to use"""

    def __init__(self, char_vocab, nationality_vocab):
        """
        Args:
            char_vocab (SequenceVocabulary): maps characters to integers
            nationality_vocab (Vocabulary): maps nationalities to integers
        """
        self.char_vocab = char_vocab
        self.nationality_vocab = nationality_vocab
    def vectorize(self, surname, vector_length=-1):
        """Vectorize a surname into a vector of observations and targets

        The outputs are the vectorized surname split into two vectors:
            surname[:-1] and surname[1:]
        At each timestep, the first vector is the observation and the second
        vector is the target.

        Args:
            surname (str): the surname to be vectorized
            vector_length (int): an argument for forcing the length of the index vector
        Returns:
            a tuple: (from_vector, to_vector)
                from_vector (numpy.ndarray): the observation vector
                to_vector (numpy.ndarray): the target prediction vector
        """
        indices = [self.char_vocab.begin_seq_index]
        indices.extend(self.char_vocab.lookup_token(token) for token in surname)
        indices.append(self.char_vocab.end_seq_index)

        if vector_length < 0:
            vector_length = len(indices) - 1

        from_vector = np.empty(vector_length, dtype=np.int64)
        from_indices = indices[:-1]
        from_vector[:len(from_indices)] = from_indices
        from_vector[len(from_indices):] = self.char_vocab.mask_index

        to_vector = np.empty(vector_length, dtype=np.int64)
        to_indices = indices[1:]
        to_vector[:len(to_indices)] = to_indices
        to_vector[len(to_indices):] = self.char_vocab.mask_index

        return from_vector, to_vector
    @classmethod
    def from_dataframe(cls, surname_df):
        """Instantiate the vectorizer from the dataset dataframe

        Args:
            surname_df (pandas.DataFrame): the surname dataset
        Returns:
            an instance of the SurnameVectorizer
        """
        char_vocab = SequenceVocabulary()
        nationality_vocab = Vocabulary()

        for index, row in surname_df.iterrows():
            for char in row.surname:
                char_vocab.add_token(char)
            nationality_vocab.add_token(row.nationality)

        return cls(char_vocab, nationality_vocab)

    @classmethod
    def from_serializable(cls, contents):
        """Instantiate the vectorizer from saved contents

        Args:
            contents (dict): a dict holding two vocabularies for this vectorizer
                This dictionary is created using `vectorizer.to_serializable()`
        Returns:
            an instance of SurnameVectorizer
        """
        char_vocab = SequenceVocabulary.from_serializable(contents['char_vocab'])
        nat_vocab = Vocabulary.from_serializable(contents['nationality_vocab'])
        return cls(char_vocab=char_vocab, nationality_vocab=nat_vocab)

    def to_serializable(self):
        """ Returns the serializable contents """
        return {'char_vocab': self.char_vocab.to_serializable(),
                'nationality_vocab': self.nationality_vocab.to_serializable()}
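
# Usage sketch (added for illustration; not part of the original gist):
# `vectorize` wraps the surname in <s> ... </s> and returns the sequence
# offset by one step, so the target at each timestep is the next character.
# The two-row dataframe below is hypothetical.
#
#     df = pd.DataFrame({'surname': ['Ng', 'Oh'],
#                        'nationality': ['Chinese', 'Korean']})
#     vectorizer = SurnameVectorizer.from_dataframe(df)
#     from_vector, to_vector = vectorizer.vectorize('Ng', vector_length=5)
#     # from_vector holds [<s>, N, g] padded to length 5 with mask_index;
#     # to_vector holds   [N, g, </s>] padded the same way.
#     assert from_vector[0] == vectorizer.char_vocab.begin_seq_index
#     assert to_vector[2] == vectorizer.char_vocab.end_seq_index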