Skip to content

Instantly share code, notes, and snippets.

@jongwook
Last active April 20, 2018 00:07
Show Gist options
  • Save jongwook/a51bb55e1bce849a7d6fc502339bc388 to your computer and use it in GitHub Desktop.
Save jongwook/a51bb55e1bce849a7d6fc502339bc388 to your computer and use it in GitHub Desktop.
Truncating the last 7 bits of the weights in a saved Keras model
import argparse
import h5py
import numpy as np
# Number of low-order float32 mantissa bits that are rounded away by default.
ROUND_BITS = 7


def round_float32_mantissa(values, bits=ROUND_BITS):
    """Return a float32 copy of *values* with the low mantissa bits cleared.

    Each finite element is rounded to the nearest multiple of ``2**bits`` in
    its raw IEEE-754 bit pattern (half is added before masking, so ties move
    away from zero in magnitude), which leaves the lowest ``bits`` bits zero.
    NaN and infinity are copied unchanged: adding the rounding constant to
    their bit patterns could alter a NaN payload or wrap around to zero.

    Args:
        values: array-like convertible to float32; it is never modified.
        bits: how many low-order bits to clear, between 0 and 23 inclusive.

    Returns:
        A new ``numpy.ndarray`` of dtype float32 with the shape of *values*.

    Raises:
        ValueError: if *bits* is outside the 0..23 range.
    """
    if not 0 <= bits <= 23:
        raise ValueError('bits must be between 0 and 23, got %r' % (bits,))

    rounded = np.array(values, dtype=np.float32)  # always a fresh copy
    if bits == 0:
        return rounded

    half = np.uint32(1 << (bits - 1))
    mask = np.uint32(0xffffffff & ~((1 << bits) - 1))

    # Same-width view: the integer operations below edit `rounded` in place.
    raw = rounded.view(np.uint32)
    finite = np.isfinite(rounded)
    np.add(raw, half, out=raw, where=finite)
    np.bitwise_and(raw, mask, out=raw, where=finite)
    return rounded


def _copy_attrs(source, target, *suffix):
    """Copy every HDF5 attribute from *source* to *target*, logging each one."""
    for key, value in source.attrs.items():
        print('copying attribute', key, '=>', value, *suffix)
        target.attrs[key] = value


def truncate_weights(input_path, output_path, bits=ROUND_BITS):
    """Write a copy of the HDF5 file at *input_path* with rounded weights.

    Groups, datasets and all attributes are mirrored into *output_path*;
    every dataset must be float32 and is stored after being passed through
    ``round_float32_mantissa``.

    Raises:
        TypeError: if a dataset in the input file is not float32.
    """
    # Open the input first (read-only) so a bad input path does not leave an
    # empty output file behind.
    with h5py.File(input_path, 'r') as src, h5py.File(output_path, 'w') as out:
        def visit(name, obj):
            if isinstance(obj, h5py.Dataset):
                if obj.dtype != np.float32:
                    raise TypeError('dataset %r has dtype %s, expected float32'
                                    % (name, obj.dtype))

                # round the 7 last bits (or `bits`, when overridden)
                truncated = round_float32_mantissa(obj[()], bits)
                print('truncated', name, ', shape =', truncated.shape)

                dataset = out.create_dataset(name, data=truncated)
                _copy_attrs(obj, dataset, 'of dataset', name)
            else:
                group = out.create_group(name)
                _copy_attrs(obj, group, 'of group', name)

        for key, value in src.attrs.items():
            print('copying global attribute', key, '=>', value)
            out.attrs[key] = value

        src.visititems(visit)


def main(argv=None):
    """Parse the ``input`` and ``output`` paths and run the conversion."""
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('output')
    args = parser.parse_args(argv)

    # The output is opened with mode 'w', which would wipe the input first.
    if os.path.realpath(args.input) == os.path.realpath(args.output):
        parser.error('input and output must be different files')

    truncate_weights(args.input, args.output)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment