Skip to content

Instantly share code, notes, and snippets.

@lazuxd
Created July 20, 2021 13:11
Show Gist options
  • Save lazuxd/790952f831bbedf956a5f7c90f321d09 to your computer and use it in GitHub Desktop.
Save lazuxd/790952f831bbedf956a5f7c90f321d09 to your computer and use it in GitHub Desktop.
def save(self, name: str) -> None:
    """Write the model's settings and parameter arrays under ``./<name>``.

    Files produced, relative to the current working directory::

        ./<name>/vocabulary.txt              vocab entries joined by '[separator]'
        ./<name>/inter_time_step_size.txt    str(self.inter_time_step_size)
        ./<name>/unit_type.txt               self.unit_type
        ./<name>/depth.txt                   str(self.depth)
        ./<name>/weights/<param>_<i>.npy     one array per parameter per layer i
        ./<name>/weights/wy.npy, by.npy      the ``wy`` / ``by`` parameters

    When ``self.unit_type == 'gru'`` the per-layer parameters are
    wr, br, wu, bu, wa, ba; for any other value they are
    wu, bu, wf, bf, wo, bo, wc, bc.

    Args:
        name: Name of the directory to create. Both ``./<name>`` and
            ``./<name>/weights`` are created with ``mkdir`` (presumably
            ``os.mkdir``), so saving into an existing directory raises.

    NOTE(review): every stored parameter must provide ``.numpy()``
    (presumably framework tensors/variables -- confirm). Text files are
    written with the platform default encoding, so a non-ASCII vocabulary
    only round-trips if the loader runs with the same default.
    """
    root = f'./{name}'
    weights_dir = f'{root}/weights'
    mkdir(root)
    mkdir(weights_dir)

    # Values are rendered lazily, i.e. only after the target file is opened.
    text_files = (
        ('vocabulary', lambda: '[separator]'.join(self.vocab)),
        ('inter_time_step_size', lambda: str(self.inter_time_step_size)),
        ('unit_type', lambda: self.unit_type),
        ('depth', lambda: str(self.depth)),
    )
    for stem, render in text_files:
        with open(f'{root}/{stem}.txt', 'w') as handle:
            handle.write(render())

    if self.unit_type == 'gru':
        layer_params = ('wr', 'br', 'wu', 'bu', 'wa', 'ba')
    else:
        layer_params = ('wu', 'bu', 'wf', 'bf', 'wo', 'bo', 'wc', 'bc')

    # Pair each target path with its per-layer tensor, parameter-major
    # (every layer of one parameter before moving to the next parameter).
    per_layer = (
        (f'{weights_dir}/{param}_{layer}.npy', getattr(self, param)[layer])
        for param in layer_params
        for layer in range(self.depth)
    )
    for path, tensor in per_layer:
        np.save(path, tensor.numpy())

    for head in ('wy', 'by'):
        np.save(f'{weights_dir}/{head}.npy', getattr(self, head).numpy())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment