Skip to content

Instantly share code, notes, and snippets.

@qfgaohao
Last active December 18, 2018 06:02
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save qfgaohao/597cb3ccc6ebd1b544a0a14f90fce1dc to your computer and use it in GitHub Desktop.
Save qfgaohao/597cb3ccc6ebd1b544a0a14f90fce1dc to your computer and use it in GitHub Desktop.
Demonstrates how to train a model, initialize its weights from another source (e.g. for transfer learning), and save the model to .pb and .pbtxt files.
import numpy as np
from caffe2.python import (
brew,
model_helper,
optimizer,
workspace,
utils,
)
from caffe2.proto import caffe2_pb2
from caffe2.python.predictor import mobile_exporter
def gen_data(batch_size=2):
    """Draw a random mini-batch for the toy linear-regression task.

    Each row of ``x`` holds two standard-normal features, and the matching
    target is ``1.0 * x[:, 0] + 2.0 * x[:, 1] + 0.5``.

    Args:
        batch_size: number of rows to generate.

    Returns:
        A ``(x, y)`` tuple with shapes ``(batch_size, 2)`` and
        ``(batch_size, 1)``. Callers cast to float32 before feeding.
    """
    features = np.random.randn(batch_size, 2)
    coefficients = np.array([[1.0], [2.0]])
    targets = features.dot(coefficients) + 0.5
    return features, targets
# Show one freshly generated batch so the reader can see the data layout.
x, y = gen_data()
print("-------------Training data sample--------------")
for label, value in (("x", x), ("y", y)):
    print(label, value)
print('\n\n')
def create_net(model, input_blob='X', output_blob='y_pred'):
    """Add a single fully-connected layer (2 inputs -> 1 output) to ``model``.

    Args:
        model: the model helper that receives the FC layer.
        input_blob: name of the blob the layer reads (default ``'X'``).
        output_blob: name of the layer's output blob (default ``'y_pred'``).
            This script feeds the parameters as ``y_pred_w`` / ``y_pred_b``,
            i.e. it assumes brew names them after the output blob.

    Returns:
        The reference to the layer's output blob, as returned by ``brew.fc``.
    """
    # Build on the ``model`` argument, never on a global. Otherwise
    # create_net(test_model) would add ops to the training net and leave the
    # test net empty.
    return brew.fc(model, input_blob, output_blob, dim_in=2, dim_out=1)
# Build the training model from a clean workspace. The order of these calls
# fixes the op order of train_model.net, so keep it as is.
workspace.ResetWorkspace()
train_model = model_helper.ModelHelper('regression model')
# Forward pass: one FC layer reading 'X' and producing 'y_pred'.
y_pred = create_net(train_model)
# Loss: squared L2 distance to the ground-truth blob 'Y_gt', reduced with
# AveragedLoss into the blob 'loss'.
dist = train_model.SquaredL2Distance(['Y_gt', y_pred], "dist")
loss = train_model.AveragedLoss(dist, "loss")
# Add the gradient operators and setup the SGD algorithm
# (gradients must be added after the loss op exists).
train_model.AddGradientOperators([loss])
optimizer.build_sgd(train_model, base_learning_rate=0.01)
# Put one batch into the input blobs (presumably so the net's inputs exist
# when it is created), initialise the parameters, then register the training
# net with the workspace.
x, y = gen_data()
for blob_name, batch in (("Y_gt", y), ("X", x)):
    workspace.FeedBlob(blob_name, batch.astype(np.float32))
workspace.RunNetOnce(train_model.param_init_net)
workspace.CreateNet(train_model.net)
# Overwrite the freshly initialised parameters with externally supplied
# values. This is where pretrained weights would be injected for transfer
# learning. The loop below then continues training from them.
print("you can just inject weights from somewhere without training.")
for param_blob, param_value in (
    ("y_pred_w", np.random.randn(1, 2).astype(np.float32)),
    ("y_pred_b", np.array([0.]).astype(np.float32)),
):
    workspace.FeedBlob(param_blob, param_value)
# Each step feeds a fresh mini-batch and runs train_model.net once. That net
# holds the forward ops plus the gradient and SGD ops added above.
for i in range(500):
    x, y = gen_data()
    for blob_name, batch in (('Y_gt', y), ('X', x)):
        workspace.FeedBlob(blob_name, batch.astype(np.float32))
    workspace.RunNet(train_model.net)
# Create the inference-only net. With init_params=False it presumably reuses
# the trained parameter blobs already in the workspace instead of
# re-initialising them (confirm against caffe2's ModelHelper docs).
test_model = model_helper.ModelHelper(name="test_net", init_params=False)
create_net(test_model)
workspace.RunNetOnce(test_model.param_init_net)
workspace.CreateNet(test_model.net, overwrite=True)
# The FC layer built by create_net reads blob 'X', so test inputs must go
# there. Feeding a differently named blob (e.g. 'data') would leave the net
# running on whatever batch was last written to 'X' during training.
data = np.zeros((1, 2)).astype('float32')
workspace.FeedBlob("X", data)
workspace.RunNet(test_model.net, 1)
# test, optional: predict on a few random rows and print the outputs.
workspace.FeedBlob('X', np.random.randn(5, 2).astype(np.float32))
workspace.RunNet(test_model.net, 1)
print("Testing results:\n")
print(workspace.FetchBlob('y_pred'))
# Export the inference net as an (init_net, predict_net) pair and write each
# one both as a binary protobuf (.pb) and as human-readable text (.pbtxt).
print("Save the model to init_net.pb and predict_net.pb")
init_net, predict_net = mobile_exporter.Export(workspace, test_model.net, test_model.params)
exported_nets = (("init_net", init_net), ("predict_net", predict_net))
for file_stem, net_def in exported_nets:
    with open(file_stem + ".pb", 'wb') as f:
        f.write(net_def.SerializeToString())
print("Save the model to init_net.pbtxt and predict_net.pbtxt")
for file_stem, net_def in exported_nets:
    # str() of a protobuf message is its text-format representation.
    with open(file_stem + ".pbtxt", 'w') as f:
        f.write(str(net_def))
@CarlosYeverino
Copy link

CarlosYeverino commented Aug 7, 2018

Hi qfgaohao,

What changes would I need to make to load the model from the .pbtxt files? I currently load the nets with the following method:

def load_net(init_net_path, predict_net_path, device_opts):
    """Load a binary-serialized Caffe2 model into the current workspace.

    Reads ``init_net_path + '.pb'`` and ``predict_net_path + '.pb'`` as
    ``NetDef`` protobufs and copies ``device_opts`` into each one's
    ``device_option``. It then runs the init net once and creates the predict
    net, overwriting any existing net with the same name.

    Args:
        init_net_path: path of the init net, without the ``.pb`` suffix.
        predict_net_path: path of the predict net, without the ``.pb`` suffix.
        device_opts: device option copied into both nets (presumably a
            ``caffe2_pb2.DeviceOption``).
    """
    def _read_net_def(path_stem):
        # Parse one binary NetDef and pin it to the requested device. The file
        # is closed before the caller hands the net to the workspace.
        net_def = caffe2_pb2.NetDef()
        with open(path_stem + '.pb', 'rb') as handle:
            net_def.ParseFromString(handle.read())
        net_def.device_option.CopyFrom(device_opts)
        return net_def

    workspace.RunNetOnce(_read_net_def(init_net_path).SerializeToString())
    workspace.CreateNet(_read_net_def(predict_net_path).SerializeToString(), overwrite=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment