Skip to content

Instantly share code, notes, and snippets.

@kmaehashi
Created January 18, 2017 11:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kmaehashi/d85662929c3e2698daa9845fc516d1c1 to your computer and use it in GitHub Desktop.
Save kmaehashi/d85662929c3e2698daa9845fc516d1c1 to your computer and use it in GitHub Desktop.
Convert an NN-based classifier Jubatus model into an NN (nearest neighbor) Jubatus model.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import json
from cStringIO import StringIO
import msgpack
from jubakit.model import JubaModel
def convert_config(cfg):
    """Convert an NN-based classifier config into an NN config.

    The values stored under the classifier config's top-level
    ``'parameter'`` entry (its ``'method'`` and ``'parameter'``) are lifted
    to the top level of the result, and the ``'converter'`` section is
    carried over as-is (same object, not copied).

    Args:
        cfg: Parsed classifier config (a dict) with ``'parameter'`` and
            ``'converter'`` keys; ``cfg['parameter']`` must itself contain
            ``'method'`` and ``'parameter'``.

    Returns:
        A new dict with exactly the keys ``'method'``, ``'parameter'`` and
        ``'converter'``.

    Raises:
        KeyError: If any of the required keys is missing.
    """
    nn_settings = cfg['parameter']
    return {
        'method': nn_settings['method'],
        'parameter': nn_settings['parameter'],
        'converter': cfg['converter'],
    }
# Load NN-based classifier model file.
if len(sys.argv) < 2:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit('usage: pass the path of an NN-based classifier model file '
             'as the first argument')
modelpath = sys.argv[1]
with open(modelpath, 'rb') as f:
    m1 = JubaModel.load_binary(f)


def _require_array_len(actual, expected, what):
    """Raise ValueError unless a msgpack array header has the expected length."""
    if actual != expected:
        raise ValueError(
            'unexpected model layout: {0} has {1} element(s), expected {2}'
            .format(what, actual, expected))


# Extract NN parts from NN-based classifier model.
# Items read with `w` as the write callback have their raw bytes copied into
# `user_raw`; items read without it (or skipped) are consumed from the input
# but left out of the output.
#
# NOTE: the header reads must NOT live inside `assert` statements -- under
# `python -O` asserts are stripped, which would also drop the reads and
# silently produce a corrupt model.
user_raw = StringIO()
w = user_raw.write
unp = msgpack.Unpacker(StringIO(m1._user_raw))
_require_array_len(unp.read_array_header(w), 2, 'user container')  # (user container)
unp.unpack(w)  # <version>
# This array header is consumed without being copied to the output.
_require_array_len(unp.read_array_header(), 2, 'classifier user data')
_require_array_len(unp.read_array_header(w), 2, 'user data')  # <user_data>
unp.unpack(w)  # nearest_neighbor_engine_->pack(pk);
unp.skip()  # discard: labels_.pack(pk);
unp.unpack(w)  # wm_.get_model()->pack(pk);

# Create new NN model.
# The header and system objects are shared with m1 (not copied), so the
# assignments to m2.system below also modify m1.system.
m2 = JubaModel()
m2.header = m1.header
m2.system = m1.system
m2.system.type = 'nearest_neighbor'
m2.system.config = json.dumps(convert_config(json.loads(m1.system.config)))
m2._user_raw = user_raw.getvalue()

# Recompute CRC32 checksum and field lengths.
m2.fix_header()

# Save NN model file.
with open('extracted_nn_model.jubatus', 'wb') as f:
    m2.dump_binary(f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment