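Get the weights from here (the link is missing from this copy; the command below expects the checkpoint file `cornet_z_epoch25.pth.tar`).
To run the test:
`python circle_vs_ellipse.py circle_vs_ellipse --n_train 100 --n_test 100 --restore_path cornet_z_epoch25.pth.tar - -j 4 --batch_size 100`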
import sys | |
import os | |
import argparse | |
import time | |
import glob | |
import pickle | |
import subprocess | |
import shlex | |
import io | |
from collections import OrderedDict | |
import numpy as np | |
import pandas | |
import tqdm | |
import fire | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
import torchvision | |
import sklearn.linear_model | |
import skimage.draw | |
from PIL import Image | |
# Install an 'ignore' filter through the ``warnings`` attribute of PIL's
# ``Image`` module (presumably the stdlib ``warnings`` module, in which case
# this silences warnings for the whole process -- TODO confirm).
Image.warnings.simplefilter('ignore')

# Let cuDNN benchmark its algorithms.
torch.backends.cudnn.benchmark = True

# Seed the global NumPy and PyTorch generators with a fixed value.
np.random.seed(0)
torch.manual_seed(0)

# Per-channel normalisation applied to every generated image.
normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# Command-line options. Within the visible part of this file only
# ``FLAGS.batch_size`` and ``FLAGS.workers`` are read; the remaining options
# are still accepted. Anything argparse does not recognise is collected in
# FIRE_FLAGS and handed to Fire by the entry point.
parser = argparse.ArgumentParser(description='ImageNet Training')

# --- paths -----------------------------------------------------------------
parser.add_argument(
    '--data_path',
    default='./',
    help='path to ImageNet folder that contains train and val folders',
)
parser.add_argument(
    '-o', '--output_path',
    default=None,
    help='path for storing ',
)

# --- model -----------------------------------------------------------------
parser.add_argument(
    '--model',
    choices=['Z', 'R', 'S'],
    default='Z',
    help='which model to train',
)
parser.add_argument(
    '--times',
    type=int,
    default=5,
    help='number of time steps to run the model (only R and S models)',
)

# --- hardware / data loading -------------------------------------------------
parser.add_argument(
    '--ngpus',
    type=int,
    default=1,
    help='number of GPUs to use',
)
parser.add_argument(
    '-j', '--workers',
    type=int,
    default=4,
    help='number of data loading workers',
)

# --- optimisation ------------------------------------------------------------
parser.add_argument(
    '--epochs',
    type=int,
    default=20,
    help='number of total epochs to run',
)
parser.add_argument(
    '--batch_size',
    type=int,
    default=256,
    help='mini-batch size',
)
parser.add_argument(
    '--lr', '--learning_rate',
    type=float,
    default=.1,
    help='initial learning rate',
)
parser.add_argument(
    '--step_size',
    type=int,
    default=10,
    help='after how many epochs learning rate should be decreased 10x',
)
parser.add_argument(
    '--momentum',
    type=float,
    default=.9,
    help='momentum',
)
parser.add_argument(
    '--weight_decay',
    type=float,
    default=1e-4,
    help='weight decay ',
)

# Split argv: recognised options -> FLAGS, everything else -> FIRE_FLAGS.
FLAGS, FIRE_FLAGS = parser.parse_known_args()
class Flatten(nn.Module):
    """
    Helper module that collapses every dimension after the first, turning an
    input of shape (N, ...) into a 2-D tensor (N, features) for use in Linear
    modules
    """

    def forward(self, x):
        # ``reshape`` returns a view whenever ``view`` would have succeeded and
        # only copies for non-contiguous inputs, where ``view`` raises.
        return x.reshape(x.size(0), -1)
class Identity(nn.Module):
    """
    Pass-through module: ``forward`` returns its input unchanged and keeps no
    state. Placing it in a model gives the tensor at that point a named module,
    which is useful for accessing it by name
    """
    def forward(self, x):
        return x
class CORblock_Z(nn.Module):
    """
    A single convolutional stage: Conv2d -> ReLU -> 3x3 max pooling with
    stride 2, followed by an ``Identity`` tap named ``output``.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: side of the square convolution kernel; the convolution is
            padded by ``kernel_size // 2`` on each side.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.nonlin = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Named tap so the block's result can be reached as ``<block>.output``.
        self.output = Identity()

    def forward(self, inp):
        pooled = self.pool(self.nonlin(self.conv(inp)))
        # Routed through the tap so hooks registered on ``output`` fire.
        return self.output(pooled)
def CORnet_Z():
    """
    Build the network: four ``CORblock_Z`` stages named V1, V2, V4 and IT,
    followed by a ``decoder`` (global average pooling, flatten, a 512 -> 1000
    linear layer and an ``Identity`` output tap).

    Convolution and linear weights are initialised with Xavier-uniform and
    their biases with zeros.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    # Modules are created in the same order in which they appear in the model.
    stages = [
        ('V1', CORblock_Z(3, 64, kernel_size=7, stride=2)),
        ('V2', CORblock_Z(64, 128)),
        ('V4', CORblock_Z(128, 256)),
        ('IT', CORblock_Z(256, 512)),
    ]
    decoder = nn.Sequential(OrderedDict([
        ('avgpool', nn.AdaptiveAvgPool2d(1)),
        ('flatten', Flatten()),
        ('linear', nn.Linear(512, 1000)),
        ('output', Identity()),
    ]))
    model = nn.Sequential(OrderedDict(stages + [('decoder', decoder)]))

    # weight initialization
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.weight.data.fill_(1)
            module.bias.data.zero_()
            continue
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    return model
class GenEllipse(torch.utils.data.Dataset):
    """
    Endless dataset of randomly placed, randomly rotated filled ellipses.

    Every item is drawn afresh from NumPy's global random generator; the
    ``index`` passed to ``__getitem__`` is ignored.

    Args:
        imsize: side of the square image in pixels.
        transform: optional callable applied to the (imsize, imsize, 3)
            float32 image before it is returned.
        min_aspect_ratio: upper bound on ``r_radius / c_radius`` of the
            generated ellipses (despite its name, it is the largest aspect
            ratio that can be produced).
    """

    def __init__(self, imsize=224, transform=None, min_aspect_ratio=.8):
        self.imsize = imsize
        self.transform = transform
        self.min_aspect_ratio = min_aspect_ratio

    def __len__(self):
        # Items are generated on demand, so report the largest valid length.
        return sys.maxsize

    def _sample_radii(self):
        """Draw ``(r_radius, c_radius)`` with r_radius / c_radius capped."""
        c_radius = np.random.uniform(1, self.imsize / 2)
        max_r_radius = self.min_aspect_ratio * c_radius
        # For small c_radius the cap drops below 1. Sampling from (1, cap)
        # would then return values ABOVE the cap and produce near-circles
        # labelled as ellipses, so the lower bound is clamped to the cap.
        r_radius = np.random.uniform(min(1, max_r_radius), max_r_radius)
        return r_radius, c_radius

    def __getitem__(self, index):
        """Return ``(image, aspect_ratio)`` for one freshly drawn ellipse."""
        r_radius, c_radius = self._sample_radii()
        # c_radius is the larger radius, so keeping the centre at least
        # c_radius away from every border keeps the rotated ellipse inside
        # the image.
        rr, cc = skimage.draw.ellipse(
            r=np.random.uniform(c_radius, self.imsize - c_radius),
            c=np.random.uniform(c_radius, self.imsize - c_radius),
            r_radius=r_radius,
            c_radius=c_radius,
            rotation=np.random.uniform(-np.pi, np.pi)
        )
        im = np.zeros((self.imsize, self.imsize, 3)).astype('float32')
        im[rr, cc] = 1
        if self.transform is not None:
            im = self.transform(im)
        return im, r_radius / c_radius
class GenCircle(torch.utils.data.Dataset):
    """
    Endless dataset of randomly placed filled circles.

    Every item is drawn afresh from NumPy's global random generator; the
    ``index`` passed to ``__getitem__`` is ignored. The second element of
    every item is the constant ``1``.

    Args:
        imsize: side of the square image in pixels.
        transform: optional callable applied to the (imsize, imsize, 3)
            float32 image before it is returned.
    """

    def __init__(self, imsize=224, transform=None):
        self.transform = transform
        self.imsize = imsize

    def __len__(self):
        # Items are generated on demand, so report the largest valid length.
        return sys.maxsize

    def __getitem__(self, index):
        size = self.imsize
        radius = np.random.uniform(1, size / 2)
        # Keep the centre at least one radius away from every border.
        center_r = np.random.uniform(radius, size - radius)
        center_c = np.random.uniform(radius, size - radius)
        # A rotation angle is still drawn even though both radii are equal.
        angle = np.random.uniform(-np.pi, np.pi)
        rows, cols = skimage.draw.ellipse(
            r=center_r,
            c=center_c,
            r_radius=radius,
            c_radius=radius,
            rotation=angle,
        )
        canvas = np.zeros((size, size, 3), dtype='float32')
        canvas[rows, cols] = 1
        if self.transform is None:
            return canvas, 1
        return self.transform(canvas), 1
def circle_vs_ellipse(n_train=100, n_test=10, restore_path=None, imsize=224,
                      use_gpu=False):
    """
    Measure how well a linear SVM on CORnet-Z features tells circles from
    ellipses.

    Features are read from the model's ``decoder.flatten`` layer for generated
    circle and ellipse images, a ``LinearSVC`` is fitted on the training
    features (label 0 = circle, 1 = ellipse), and the overall test accuracy
    plus the accuracy per aspect-ratio bin are printed.

    The batch size and the number of loader workers come from the command-line
    ``FLAGS`` (``--batch_size`` and ``--workers``).

    Args:
        n_train: number of training images generated PER class.
        n_test: number of test images generated PER class.
        restore_path: optional checkpoint file; its ``'state_dict'`` entry is
            loaded into the model wrapped in ``DataParallel``.
        imsize: side of the generated square images in pixels.
        use_gpu: run the model on CUDA instead of the CPU.

    Raises:
        ValueError: if ``n_train`` or ``n_test`` is smaller than 1.
    """
    # Only ``sklearn.linear_model`` is imported at file level, so the ``svm``
    # submodule is imported explicitly instead of relying on it being loaded
    # as a side effect.
    from sklearn.svm import LinearSVC

    device = torch.device('cuda' if use_gpu else 'cpu')

    # The DataParallel wrapper is kept for checkpoint loading only, so the
    # state-dict keys are matched against the same module names as before.
    model = torch.nn.DataParallel(CORnet_Z())
    if restore_path is not None:
        ckpt_data = torch.load(restore_path, map_location='cpu')
        model.load_state_dict(ckpt_data['state_dict'])

    # Inference runs on the wrapped network itself, on one explicit device.
    # Going through DataParallel.forward is fragile here: on a CUDA host it
    # requires the parameters to live on the first GPU even when
    # use_gpu=False, and with several GPUs the hook below would fire once per
    # replica in no guaranteed order.
    net = model.module.to(device)
    net.eval()
    model_layer = net.decoder.flatten

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        normalize,
    ])

    def _get_features(n, kind):
        """Return (features, aspect_ratios) for ``n`` images from ``kind``."""
        if n < 1:
            raise ValueError('need at least one sample, got n=%r' % (n,))

        captured = []  # filled by the hook, emptied before every batch

        def _store_feats(layer, inp, output):
            """An ugly but effective way of accessing intermediate model features
            """
            # Move to host memory first so this also works for CUDA outputs.
            feats = output.detach().cpu().reshape(len(output), -1)
            captured.append(feats.numpy())

        dataset = kind(imsize, transform)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=FLAGS.batch_size,
                                                  shuffle=False,
                                                  num_workers=FLAGS.workers,
                                                  pin_memory=use_gpu)
        # The datasets are endless, so request exactly as many batches as are
        # needed to cover n samples (ceiling division).
        n_batches = -(-n // FLAGS.batch_size)

        model_feats = []
        aspect_ratios = []
        handle = model_layer.register_forward_hook(_store_feats)
        try:
            with torch.no_grad():
                # ``range`` comes first in the zip so no extra batch is
                # fetched once enough batches have been processed.
                for _, (ims, ratios) in zip(range(n_batches), data_loader):
                    captured.clear()
                    net(ims.to(device))
                    model_feats.append(captured[0])
                    aspect_ratios.append(np.asarray(ratios))
        finally:
            # Always detach the hook, even if a forward pass fails.
            handle.remove()

        model_feats = np.concatenate(model_feats)[:n]
        aspect_ratios = np.concatenate(aspect_ratios)[:n]
        return model_feats, aspect_ratios

    train_circles, _ = _get_features(n_train, kind=GenCircle)
    train_ellipses, _ = _get_features(n_train, kind=GenEllipse)
    test_circles, test_car = _get_features(n_test, kind=GenCircle)
    test_ellipses, test_ear = _get_features(n_test, kind=GenEllipse)

    clf = LinearSVC()
    train_feats = np.concatenate([train_circles, train_ellipses], axis=0)
    train_labels = np.concatenate([np.zeros(len(train_circles)),
                                   np.ones(len(train_ellipses))])
    test_feats = np.concatenate([test_circles, test_ellipses], axis=0)
    test_labels = np.concatenate([np.zeros(len(test_circles)),
                                  np.ones(len(test_ellipses))])
    test_ars = np.concatenate([test_car, test_ear])

    clf.fit(train_feats, train_labels)
    preds = clf.predict(test_feats)

    df = pandas.DataFrame(np.stack([preds, test_labels, test_ars]).T,
                          columns=['prediction', 'actual', 'aspect_ratio'])
    df['acc'] = df.prediction == df.actual
    print('accuracy:', df.acc.mean())

    # Accuracy per 0.1-wide aspect-ratio bin; bins are closed on the left, so
    # circles (aspect ratio 1) land in the [1.0, 1.1) bin.
    agg = df.groupby(pandas.cut(df.aspect_ratio, np.arange(0, 1.2, .1),
                                include_lowest=True, right=False)).acc.mean()
    print(agg)
# Script entry point. Fire is given only FIRE_FLAGS, i.e. the command-line
# arguments that the argparse parser above did not recognise; the recognised
# ones are available as FLAGS.
# NOTE(review): no component is passed to Fire, so the command name is
# presumably resolved against this module's top-level names -- confirm against
# the Fire documentation.
if __name__ == '__main__':
    fire.Fire(command=FIRE_FLAGS)