Skip to content

Instantly share code, notes, and snippets.

@johncorring
Created September 28, 2018 20:31
Show Gist options
  • Save johncorring/d735675e75add96fbdfbcc40fa00f3ba to your computer and use it in GitHub Desktop.
Save johncorring/d735675e75add96fbdfbcc40fa00f3ba to your computer and use it in GitHub Desktop.
import numpy as np
import skimage.io
import skimage.transform
def rescale(img, input_height, input_width):
    """Resize ``img`` so that it covers an ``input_height`` x ``input_width`` box.

    The image's aspect ratio is preserved. One side is scaled to match the
    target exactly and the other side is at least as large as the target, so
    a center crop of ``input_height`` x ``input_width`` can follow.

    Args:
        img: Image array whose first two axes are (rows, columns).
        input_height: Required number of rows of the target box.
        input_width: Required number of columns of the target box.

    Returns:
        The resized image as returned by ``skimage.transform.resize``.
    """
    height, width = img.shape[:2]
    aspect = width / float(height)
    target_aspect = input_width / float(input_height)

    # The output shape handed to skimage.transform.resize is (rows, cols),
    # i.e. (height, width). The original mixed the two up, which only worked
    # because every caller asked for a square target.
    if aspect > target_aspect:
        # Image is relatively wider than the target: match the heights and
        # let the width overflow.
        out_shape = (input_height, int(aspect * input_height))
    elif aspect < target_aspect:
        # Image is relatively taller than the target: match the widths and
        # let the height overflow.
        out_shape = (int(input_width / aspect), input_width)
    else:
        out_shape = (input_height, input_width)
    return skimage.transform.resize(img, out_shape)
def crop_center(img, cropx, cropy):
    """Return the centered ``cropy`` x ``cropx`` window of ``img``.

    Works for both 2-D (grayscale) and 3-D (height, width, channels) arrays;
    only the first two axes are cropped.

    Args:
        img: Array whose first two axes are (rows, columns).
        cropx: Width of the crop in columns.
        cropy: Height of the crop in rows.

    Returns:
        A slice (view) of ``img``. If the requested crop is larger than the
        image along an axis, the whole extent of that axis is returned.
    """
    height, width = img.shape[:2]
    # Clamp at 0: a negative start would be interpreted as an index counted
    # from the end of the axis and silently produce a mis-sized crop.
    startx = max(0, width // 2 - cropx // 2)
    starty = max(0, height // 2 - cropy // 2)
    return img[starty:starty + cropy, startx:startx + cropx]
def prepare_image(img_path, input_size=227, mean=128):
    """Load an image file and convert it to the network's input layout.

    Steps: read the file, convert to floating point, resize and center-crop
    to ``input_size`` x ``input_size``, reorder the axes from HWC to CHW,
    reverse the first three channels (RGB to BGR), scale by 255 and subtract
    ``mean``.

    Args:
        img_path: Path of the image file to read.
        input_size: Side length of the square output image (default 227,
            the value that was previously hard-coded).
        mean: Value subtracted from every pixel after scaling by 255
            (default 128, the value that was previously hard-coded).

    Returns:
        A ``float32`` array of shape ``(3, input_size, input_size)``.
    """
    img = skimage.img_as_float(skimage.io.imread(img_path))
    if img.ndim == 2:
        # Grayscale file: replicate the single plane into three channels so
        # the channel reordering below does not fail.
        img = np.stack([img] * 3, axis=-1)
    img = rescale(img, input_size, input_size)
    img = crop_center(img, input_size, input_size)
    img = img.transpose(2, 0, 1)  # HWC to CHW dimension
    img = img[(2, 1, 0), :, :]  # RGB to BGR color order
    img = img * 255 - mean  # Subtract mean
    return img.astype(np.float32)
import os, glob, random
def make_batch(iterable, batch_size=1):
    """Yield consecutive slices of ``iterable`` holding up to ``batch_size`` items.

    ``iterable`` must support ``len()`` and slicing. The final slice is
    shorter when the length is not a multiple of ``batch_size``.
    """
    total = len(iterable)
    for start in range(0, total, batch_size):
        stop = start + batch_size
        if stop > total:
            stop = total
        yield iterable[start:stop]
class DogsCatsDataset(object):
""" Dogs and cats dataset reader """
def __init__(self, split="train", data_dir="dogs-vs-cats/"):
self.categories = {"dog": 0, "cat": 1}
self.image_files = list(glob.glob(os.path.join(data_dir, split, "*.jpg")))
#print(self.image_files)
self.labels = [self.categories.get(os.path.basename(path).strip().split(".")[0], -1)
for path in self.image_files]
def __getitem__(self, index):
image = prepare_image(self.image_files[index])
label = self.labels[index]
return image, label
def __len__(self):
return len(self.labels)
def read(self, batch_size=50, shuffle=True):
"""Read (image, label) pairs in batch"""
order = list(range(len(self)))
if shuffle:
random.shuffle(order)
for batch in make_batch(order, batch_size):
images, labels = [], []
for index in batch:
image, label = self[index]
images.append(image)
labels.append(label)
yield np.stack(images).astype(np.float32), np.stack(labels).astype(np.int32).reshape((batch_size,))
from caffe2.python.modeling import initializers
from caffe2.python.modeling.parameter_info import ParameterTags
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace, model_helper, optimizer, brew
# Locations of the serialized SqueezeNet protobuf files. PREDICT_NET is passed
# to AddPredictNet and INIT_NET to AddInitNet further down in this script.
# NOTE(review): these are machine-specific absolute paths; adjust them before
# running on another machine.
PREDICT_NET = "/home/john/Code/models/squeezenet/predict_net.pb"
INIT_NET = "/home/john/Code/models/squeezenet/init_net.pb"
def AddPredictNet(model, predict_net_path):
    """Replace ``model.net`` with the network stored at ``predict_net_path``.

    The file is read in binary mode and parsed as a ``NetDef`` protobuf.
    A ``Squeeze`` operator over dims 2 and 3 is then added, reading the
    ``"softmaxout"`` blob and writing ``"softmax"``.

    Args:
        model: Model whose ``net`` attribute is replaced.
        predict_net_path: Path of the serialized ``NetDef`` file.
    """
    with open(predict_net_path, "rb") as proto_file:
        serialized = proto_file.read()
    net_def = caffe2_pb2.NetDef()
    net_def.ParseFromString(serialized)
    model.net = core.Net(net_def)
    # Fix dimension incompatibility
    model.Squeeze("softmaxout", "softmax", dims=[2, 3])
def AddInitNet(model, init_net_path, out_dim=2, params_to_learn=None):
    """Load the parameter-initialization net and re-initialize the conv10 layer.

    Every operator in the loaded net is treated as the initializer of the
    parameter named by its first output. Parameters selected by
    ``params_to_learn`` are registered on ``model`` as learnable. The stored
    initializers of ``conv10_w`` and ``conv10_b`` are dropped and replaced by
    fresh fills sized for ``out_dim`` outputs.

    Args:
        model: Model that receives the parameters and whose
            ``param_init_net`` attribute is replaced.
        init_net_path: Path of the serialized ``NetDef`` file.
        out_dim: Number of output channels of the new conv10 layer.
        params_to_learn: Collection of parameter names to register as
            learnable, or ``None`` to register every parameter.
    """
    init_net_proto = caffe2_pb2.NetDef()
    with open(init_net_path, "rb") as f:
        init_net_proto.ParseFromString(f.read())

    # Define params to learn in the model.
    for c, op in enumerate(init_net_proto.op):
        param_name = op.output[0]
        if params_to_learn is None or param_name in params_to_learn:
            print(c, param_name)
            tags = (ParameterTags.WEIGHT if param_name.endswith("_w")
                    else ParameterTags.BIAS)
            # NOTE(review): op.arg[0] is passed as the shape unchanged from
            # the original code -- confirm it is what create_param expects.
            model.create_param(
                param_name=param_name,
                shape=op.arg[0],
                initializer=initializers.ExternalInitializer(),
                tags=tags,
            )

    # Remove the stored conv10_w / conv10_b initializers. They are located by
    # output name rather than by fixed position (the original popped indices
    # 51 and 50), so a differently ordered init net is handled correctly.
    # Iterate backwards so deletions do not shift the remaining indices.
    replaced = ("conv10_w", "conv10_b")
    for position in reversed(range(len(init_net_proto.op))):
        if init_net_proto.op[position].output[0] in replaced:
            del init_net_proto.op[position]

    # Add new initializers for conv10_w, conv10_b
    model.param_init_net = core.Net(init_net_proto)
    model.param_init_net.XavierFill([], "conv10_w", shape=[out_dim, 512, 1, 1])
    model.param_init_net.ConstantFill([], "conv10_b", shape=[out_dim])
def AddTrainingOperators(model, softmax, label):
    """Add loss, accuracy, gradient and SGD update operators to ``model``.

    Args:
        model: Model to extend.
        softmax: Name of the blob holding the predictions.
        label: Name of the blob holding the labels.
    """
    # Loss: label cross-entropy averaged into the "loss" blob.
    cross_entropy = model.LabelCrossEntropy([softmax, label], "xent")
    averaged = model.AveragedLoss(cross_entropy, "loss")
    brew.accuracy(model, [softmax, label], "accuracy")
    model.AddGradientOperators([averaged])

    sgd_settings = dict(
        base_learning_rate=0.001,
        policy="fixed",
        momentum=0.9,
        weight_decay=0.0001,
    )
    sgd = optimizer.build_sgd(model, **sgd_settings)
    # NOTE(review): build_sgd may already attach update operators for every
    # optimization parameter; confirm that applying the optimizer again here
    # does not update each parameter twice per iteration.
    for param_info in model.GetOptimizationParamInfo():
        sgd(model.net, model.param_init_net, param_info)
# Reset the Caffe2 workspace, then create the model helper named "train_net"
# that the rest of the script fills with the prediction, initialization and
# training operators.
workspace.ResetWorkspace()
train_model = model_helper.ModelHelper("train_net")
def SetDeviceOption(model, device_option):
    """Set ``device_option`` as the net-level device of both nets of ``model``.

    For ``model.net`` and ``model.param_init_net`` the net-level device
    option is overwritten, the per-operator ``device_option`` and ``engine``
    fields are cleared, and the attribute is replaced by a new ``core.Net``
    built from the modified definition.

    Args:
        model: Model whose ``net`` and ``param_init_net`` are rewritten.
        device_option: Device option copied into each net definition.
    """
    # Some operators are CPU-only; those producing these outputs keep their
    # own device option and engine.
    cpu_only_outputs = ("optimizer_iteration", "iteration_mutex")
    for attr_name in ("net", "param_init_net"):
        proto = getattr(model, attr_name).Proto()
        proto.device_option.CopyFrom(device_option)
        for operator in proto.op:
            if operator.output[0] in cpu_only_outputs:
                continue
            operator.ClearField("device_option")
            operator.ClearField("engine")
        setattr(model, attr_name, core.Net(proto))
# Run everything on the first CUDA device.
device_option = caffe2_pb2.DeviceOption()
device_option.device_type = caffe2_pb2.CUDA
device_option.cuda_gpu_id = 0

# Initialization: feed one sample so the "data" and "label" blobs exist
# before the nets are run for the first time.
train_dataset = DogsCatsDataset(split="train", data_dir="/home/john/Data/scratch/all")
for image, label in train_dataset.read(batch_size=1):
    workspace.FeedBlob("data", image, device_option=device_option)
    workspace.FeedBlob("label", label, device_option=device_option)
    break

AddPredictNet(train_model, PREDICT_NET)
AddInitNet(train_model, INIT_NET, params_to_learn=["conv10_w", "conv10_b"])  # Use None to learn everything.
AddTrainingOperators(train_model, "softmax", "label")
# The device option must be applied after the nets are fully built:
# AddPredictNet and AddInitNet replace model.net / model.param_init_net, so
# applying it earlier (as the original did) was silently discarded, and the
# optimizer outputs that SetDeviceOption special-cases only exist once the
# training operators have been added.
SetDeviceOption(train_model, device_option)
print("initialized")

workspace.RunNetOnce(train_model.param_init_net)
workspace.CreateNet(train_model.net, overwrite=True)
# Kept for inspection while debugging; the result is not used below.
shtyp = workspace.InferShapesAndTypes([train_model.net])
print("created")

# Main loop.
batch_size = 50
print_freq = 50
losses = []
num_batches = len(train_dataset) // batch_size  # loop-invariant, shown in the log line
for epoch in range(5):
    for index, (image, label) in enumerate(train_dataset.read(batch_size)):
        print("reading")
        print(image.shape, label.shape)
        workspace.FeedBlob("data", image, device_option=device_option)
        workspace.FeedBlob("label", label, device_option=device_option)
        print(label)
        print("running")
        workspace.RunNet(train_model.net)
        print("accuracy")
        accuracy = float(workspace.FetchBlob("accuracy"))
        loss = workspace.FetchBlob("loss").mean()
        losses.append(loss)
        if index % print_freq == 0:
            print("[{}][{}/{}] loss={}, accuracy={}".format(
                epoch, index, num_batches,
                loss, accuracy))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment