Skip to content

Instantly share code, notes, and snippets.

@lucienne999
Created December 15, 2020 05:58
Show Gist options
  • Save lucienne999/9f3e03f85d622207404c55dccb3789d7 to your computer and use it in GitHub Desktop.
Save lucienne999/9f3e03f85d622207404c55dccb3789d7 to your computer and use it in GitHub Desktop.
import os
import pickle
import torch
import torch.nn as nn
class EVALscore(nn.Module):
    """Classifier that scores how close two black/white rope images are.

    It is trained to predict a score between two images: the score is high
    if they are within a few steps apart, and low otherwise.

    ``forward(x1, x2)`` concatenates the two images along dim 1 (the first
    convolution expects 2 input channels, so each image is presumably a
    single-channel ``(batch, 1, H, W)`` tensor). Per the layer comments, the
    intended input is 64 x 64, in which case the output is a ``(batch, 1)``
    tensor of sigmoid scores in (0, 1).
    """

    def __init__(self):
        super(EVALscore, self).__init__()
        # NOTE: the attribute name and layer order define the state_dict keys
        # (e.g. 'LeNet.0.bias') that saved checkpoints are loaded into.
        self.LeNet = nn.Sequential(
            # input size 2 x 64 x 64. Take 2 black and white images.
            nn.Conv2d(2, 64, 4, 2, 1),
            nn.LeakyReLU(0.1, inplace=True),
            # 64 x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            # 128 x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, inplace=True),
            # Option 1: 256 x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1, inplace=True),
            # 512 x 4 x 4
            nn.Conv2d(512, 1, 4),
            # 1 x 1 x 1
        )

    def Flatten(self, x):
        """Flatten all non-batch dims of ``x`` and apply a sigmoid."""
        return torch.sigmoid(torch.flatten(x, 1))

    def forward(self, x1, x2):
        """Return the closeness score for the image pair ``(x1, x2)``."""
        stacked = torch.cat([x1, x2], dim=1)
        # Run the conv stack first; previously the raw input was squashed
        # directly, bypassing the network entirely.
        logits = self.LeNet(stacked)
        return self.Flatten(logits)
def covert_to_tensor(ndict):
    """Build a new dict with every value converted via ``torch.from_numpy``.

    Keys are carried over unchanged and the input dict is not modified.
    ``torch.from_numpy`` does not copy, so each resulting tensor shares
    memory with its source NumPy array.
    """
    return {name: torch.from_numpy(array) for name, array in ndict.items()}
def load_pkl(filename):
    """Open ``filename`` in binary mode and return the unpickled object.

    Security note: unpickling can execute arbitrary code, so only pass
    files that come from a trusted source.
    """
    with open(filename, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj
if __name__ == "__main__":
    # Sanity check: load NumPy-exported weights into EVALscore and confirm
    # every state_dict entry round-trips exactly.
    import sys

    # The checkpoint path may be given as the first CLI argument; the
    # default is the path this script originally hard-coded.
    if len(sys.argv) > 1:
        pkl_path = sys.argv[1]
    else:
        pkl_path = '/home/tusimple/Projects/startup/draft/classifier.pkl'

    model = EVALscore()
    # Per the original note, the pickle holds a tuple whose first element
    # is the dict of NumPy arrays keyed by state_dict name.
    numpy_dict = load_pkl(pkl_path)[0]
    model_dict = covert_to_tensor(numpy_dict)
    # strict loading (the default) guarantees every state_dict key exists
    # in model_dict, so the lookup below cannot raise KeyError.
    model.load_state_dict(model_dict)

    # Compare every entry, not just a single bias; `==` type-promotes like
    # the original NumPy comparison, so precision loss shows as a mismatch.
    mismatched = [
        name for name, tensor in model.state_dict().items()
        if not bool((tensor == model_dict[name]).all())
    ]
    if mismatched:
        print('mismatched entries:', mismatched)
        sys.exit(1)
    print('all', len(model_dict), 'entries match')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment