Skip to content

Instantly share code, notes, and snippets.

@lazuxd
Created July 1, 2021 13:04
Show Gist options
  • Save lazuxd/59e874343e8fa1878e1251ef5fa48f77 to your computer and use it in GitHub Desktop.
def save(self, name: str) -> None:
    """Write the model's vocabulary, hidden size and weights to directory ``name``.

    ``name`` is resolved relative to the current working directory unless it
    is an absolute path. The directory and any missing parent directories are
    created; saving again into an existing directory overwrites its files.

    Files written:
        vocabulary.json: ``self.vocab`` as a JSON list. Round-trips every
            token exactly, including tokens that contain ``,`` or newlines.
        vocabulary.txt: ``self.vocab`` joined with ``,``. Legacy format kept
            so readers that only know this file still work; it is ambiguous
            when a token itself contains ``,``.
        a_size.txt: ``str(self.a_size)``.
        wa.npy, ba.npy, wy.npy, by.npy: the arrays returned by ``.numpy()``
            on ``self.wa``, ``self.ba``, ``self.wy`` and ``self.by``.

    Args:
        name: Target directory.
    """
    import json
    import os

    # os.path.join yields './<name>' for relative names, but keeps absolute
    # paths absolute (f'./{name}' would turn '/tmp/m' into './/tmp/m').
    directory = os.path.join('.', name)
    # exist_ok so re-saving to the same directory does not raise
    # FileExistsError; makedirs also creates missing parent directories.
    os.makedirs(directory, exist_ok=True)

    with open(os.path.join(directory, 'vocabulary.txt'), 'w') as f:
        f.write(','.join(self.vocab))
    # Unambiguous copy of the vocabulary: a comma-joined string cannot be
    # split back correctly when a token contains ','.
    with open(os.path.join(directory, 'vocabulary.json'), 'w', encoding='utf-8') as f:
        json.dump(list(self.vocab), f)
    with open(os.path.join(directory, 'a_size.txt'), 'w') as f:
        f.write(str(self.a_size))

    for weight_name in ('wa', 'ba', 'wy', 'by'):
        np.save(
            os.path.join(directory, f'{weight_name}.npy'),
            getattr(self, weight_name).numpy(),
        )
def load(self, name: str) -> None:
    """Restore vocabulary, hidden size and weights from directory ``name``.

    ``name`` is resolved relative to the current working directory unless it
    is an absolute path. The vocabulary is read from ``vocabulary.json`` when
    that file exists; otherwise it falls back to the comma-separated
    ``vocabulary.txt`` (a format that cannot represent tokens containing
    ``,``), so directories without the JSON file still load.

    All files are read, and the ``tf.Variable`` objects created, before any
    attribute is assigned, so a missing or unreadable file leaves the object
    unchanged.

    Sets ``vocab``, ``a_size``, ``vocab_size``, ``combined_size``, ``wa``,
    ``ba``, ``wy``, ``by`` (each weight wrapped in ``tf.Variable``) and
    ``weights`` (``[wa, ba, wy, by]``).

    Args:
        name: Directory containing the saved files.

    Raises:
        FileNotFoundError: if a required file is missing.
        ValueError: if ``a_size.txt`` does not contain an integer.
    """
    import json
    import os

    # Same result as f'./{name}' for relative names; absolute paths are kept.
    directory = os.path.join('.', name)

    json_path = os.path.join(directory, 'vocabulary.json')
    if os.path.isfile(json_path):
        with open(json_path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)
    else:
        # Fallback: tokens joined with ','.
        with open(os.path.join(directory, 'vocabulary.txt'), 'r') as f:
            vocab = f.read().split(',')
    with open(os.path.join(directory, 'a_size.txt'), 'r') as f:
        a_size = int(f.read())
    wa, ba, wy, by = (
        tf.Variable(np.load(os.path.join(directory, f'{weight_name}.npy')))
        for weight_name in ('wa', 'ba', 'wy', 'by')
    )

    # Everything loaded successfully; now commit the new state.
    self.vocab = vocab
    self.a_size = a_size
    self.vocab_size = len(vocab)
    self.combined_size = self.vocab_size + a_size
    self.wa, self.ba, self.wy, self.by = wa, ba, wy, by
    self.weights = [wa, ba, wy, by]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment