This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load(self, name: str) -> None: | |
with open(f'./{name}/vocabulary.txt', 'r') as f: | |
vocabulary = f.read().split('[separator]') | |
with open(f'./{name}/inter_time_step_size.txt', 'r') as f: | |
inter_time_step_size = int(f.read()) | |
with open(f'./{name}/unit_type.txt', 'r') as f: | |
unit_type = f.read() | |
with open(f'./{name}/depth.txt', 'r') as f: | |
depth = int(f.read()) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save(self, name: str) -> None: | |
mkdir(f'./{name}') | |
mkdir(f'./{name}/weights') | |
with open(f'./{name}/vocabulary.txt', 'w') as f: | |
f.write('[separator]'.join(self.vocab)) | |
with open(f'./{name}/inter_time_step_size.txt', 'w') as f: | |
f.write(str(self.inter_time_step_size)) | |
with open(f'./{name}/unit_type.txt', 'w') as f: | |
f.write(self.unit_type) | |
with open(f'./{name}/depth.txt', 'w') as f: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict_next(self, sentence: str, | |
threshold: float = 0.9) -> str: | |
# predict the next part of the sentence given as parameter | |
self.reset_state(1) | |
for word in sentence.strip(): | |
x = words2onehot(self.vocab, [word]) | |
y_hat = self(x) | |
s = '' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sample(self, threshold: float = 0.9) -> str: | |
# sample a new sentence from the learned model | |
sentence = '' | |
self.reset_state(1) | |
x = np.zeros((1, self.vocab_size)) | |
while True: | |
y_hat = self(x) | |
word = sample_word(self.vocab, | |
tf.reshape(y_hat, (-1,)).numpy(), | |
threshold) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def fit(self, | |
sentences: list, | |
batch_size: int = 128, | |
epochs: int = 10) -> None: | |
n_sent = len(sentences) | |
num_batches = ceil(n_sent / batch_size) | |
for epoch in range(epochs): | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _call_lstm(self, | |
level: int, | |
x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor: | |
n = x.shape[0] | |
self.a[level] = self.a[level][0:n] | |
self.c[level] = self.c[level][0:n] | |
concat_matrix = tf.concat([self.a[level], x], axis=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _call_gru(self, | |
level: int, | |
x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor: | |
n = x.shape[0] | |
self.a[level] = self.a[level][0:n] | |
concat_matrix = tf.concat([self.a[level], x], axis=1) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _call_level(self,
                level: int,
                x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
    """Run the recurrent cell of one layer on ``x``.

    Dispatches to ``_call_gru`` when ``self.unit_type`` is ``'gru'`` and
    to ``_call_lstm`` for every other value of ``self.unit_type``.

    Args:
        level: Index of the layer; forwarded unchanged to the cell method.
        x: Input for this layer; forwarded unchanged to the cell method.

    Returns:
        Whatever the selected cell method returns.
    """
    if self.unit_type == 'gru':
        return self._call_gru(level, x)
    # Every non-'gru' unit type falls through to the LSTM cell.
    return self._call_lstm(level, x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __call__(self, | |
x: Union[np.ndarray, tf.Tensor], | |
y: Union[np.ndarray, tf.Tensor, None] = None) -> tf.Tensor: | |
for i in range(self.depth): | |
x = self._call_level(i, x) | |
y_logits = tf.linalg.matmul(x, self.wy)+self.by | |
if y is None: | |
# during prediction return softmax probabilities |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def reset_state(self, num_samples: int) -> None:
    """Reset the recurrent state to zeros for ``num_samples`` rows.

    Sets ``self.a`` to a list of ``self.depth`` zero tensors, each of shape
    ``(num_samples, self.inter_time_step_size)`` and dtype ``tf.double``.
    When ``self.unit_type`` is ``'lstm'``, ``self.c`` is set to a second,
    separately built list of the same form; for any other unit type
    ``self.c`` is left untouched.

    Args:
        num_samples: Size of the first dimension of every state tensor.
    """
    def get_init_values():
        # Built fresh on every call so that ``a`` and ``c`` never share
        # the same list object.
        return [tf.zeros((num_samples, self.inter_time_step_size),
                         dtype=tf.double) for _ in range(self.depth)]

    self.a = get_init_values()
    if self.unit_type == 'lstm':
        self.c = get_init_values()
NewerOlder