Skip to content

Instantly share code, notes, and snippets.

@lazuxd
Created July 20, 2021 13:22
Show Gist options
  • Save lazuxd/d4cd5f0f9bac69dc0dbae91a12830b39 to your computer and use it in GitHub Desktop.
Save lazuxd/d4cd5f0f9bac69dc0dbae91a12830b39 to your computer and use it in GitHub Desktop.
def load(self, name: str) -> None:
    """Restore the model state stored in the ``./{name}`` directory.

    Reads the hyper-parameters from ``vocabulary.txt``,
    ``inter_time_step_size.txt``, ``unit_type.txt`` and ``depth.txt``,
    re-initialises the instance through ``self._init`` and then replaces
    the weights with the arrays found in ``./{name}/weights``.

    Files named ``<attr>_<index>.npy`` are collected into a list stored on
    ``self.<attr>``, ordered by their numeric ``index``. ``wy.npy`` and
    ``by.npy`` are loaded into ``self.wy`` and ``self.by``. Finally
    ``self.weights`` is rebuilt as the per-attribute lists followed by
    ``self.wy`` and ``self.by``.

    Args:
        name: Name of the directory (relative to the current working
            directory) the model was saved in.

    Raises:
        OSError: If one of the expected files or the weights directory
            cannot be read.
        ValueError: If a numeric file does not contain an integer, or an
            indexed weight file name does not follow ``<attr>_<index>.npy``.
    """
    base_dir = f'./{name}'
    with open(f'{base_dir}/vocabulary.txt', 'r') as f:
        vocabulary = f.read().split('[separator]')
    with open(f'{base_dir}/inter_time_step_size.txt', 'r') as f:
        inter_time_step_size = int(f.read())
    with open(f'{base_dir}/unit_type.txt', 'r') as f:
        unit_type = f.read()
    with open(f'{base_dir}/depth.txt', 'r') as f:
        depth = int(f.read())
    self._init(vocabulary, inter_time_step_size, unit_type, depth)

    weights_dir = f'{base_dir}/weights'
    suffix = '.npy'

    # Group the indexed files by attribute name. Iterating the sorted listing
    # keeps the attribute order deterministic (os.listdir order is arbitrary).
    indexed_files = {}
    for filename in sorted(listdir(weights_dir)):
        if filename in ('wy.npy', 'by.npy'):
            continue  # output-layer parameters are loaded separately below
        if not filename.endswith(suffix):
            continue  # ignore stray files such as editor/OS metadata
        # rsplit so that attribute names containing '_' are still parsed.
        attr_name, index = filename[:-len(suffix)].rsplit('_', 1)
        indexed_files.setdefault(attr_name, []).append((int(index), filename))

    weights_names = []
    for attr_name, entries in indexed_files.items():
        # Sort numerically: a plain string sort would put '<attr>_10' before
        # '<attr>_2' and scramble the layers once there are more than ten.
        entries.sort()
        setattr(self, attr_name, [
            tf.Variable(np.load(f'{weights_dir}/{filename}'))
            for _, filename in entries
        ])
        weights_names.append(attr_name)

    self.wy = tf.Variable(np.load(f'{weights_dir}/wy.npy'))
    self.by = tf.Variable(np.load(f'{weights_dir}/by.npy'))
    self.weights = [getattr(self, weight_name)
                    for weight_name in weights_names]
    self.weights.extend([self.wy, self.by])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment