Toy example of KFAC in PyTorch
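KFAC (Kronecker-Factored Approximate Curvature) preconditions each layer's gradient on both sides with damped inverse covariance factors. For a layer computing y = W @ a, stacking the n examples as columns of A (layer inputs) and B (backprops), the update implemented below is roughly

    W <- W - lr * (B2 B2^T/n + lambda*I)^-1 @ (B A^T/n) @ (A A^T/n + lambda*I)^-1

where B2 comes from a second backward pass against targets sampled around the model's output. The script captures A and B with a custom autograd Function and applies this update by hand to a two-layer sigmoid autoencoder on MNIST.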
# Times: min: 440.60, median: 452.39, mean: 453.87
import util as u
u.check_mkl()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import scipy.linalg
from torch.autograd.function import Function

import common_gd
args = common_gd.args
args.cuda = not args.no_cuda and torch.cuda.is_available()

dtype = np.float32
lambda_ = 3e-3        # damping added to covariance factors before inversion
lr = 0.2              # learning rate for the manual KFAC update
dsize = 10000         # number of MNIST examples (the batch is the whole dataset)
nonlin = torch.sigmoid
DO_PRINT = False
def _get_output(ctx, arg, inplace=False):
    if inplace:
        ctx.mark_dirty(arg)
        return arg
    else:
        return arg.new().resize_as_(arg)

# Per-layer statistics captured during the forward/backward passes:
# forward_list collects layer inputs (activations), backward_list collects
# backpropagated gradients rescaled by dsize.
forward_list = []
backward_list = []
class Addmm(Function):
    """Custom matmul that records the tensors needed for KFAC statistics."""

    @staticmethod
    def forward(ctx, add_matrix, matrix1, matrix2, beta=1, alpha=1, inplace=False):
        ctx.save_for_backward(matrix1, matrix2)
        output = _get_output(ctx, add_matrix, inplace=inplace)
        forward_list.append(matrix2)  # record layer input (activations)
        return torch.addmm(beta, add_matrix, alpha,
                           matrix1, matrix2, out=output)

    @staticmethod
    def backward(ctx, grad_output):
        matrix1, matrix2 = ctx.saved_variables
        grad_matrix1 = grad_matrix2 = None
        if ctx.needs_input_grad[1]:
            grad_matrix1 = torch.mm(grad_output, matrix2.t())
        if ctx.needs_input_grad[2]:
            grad_matrix2 = torch.mm(matrix1.t(), grad_output)
        if DO_PRINT:
            print("backward got")
            print("grad_output", grad_output)
            print("matrix1", matrix1)
            print("matrix2", matrix2)
            print("grad_matrix1", grad_matrix1)
        # insert dsize correction to put activations/backprops on the same scale
        backward_list.append(grad_output * dsize)
        return None, grad_matrix1, grad_matrix2, None, None, None
def my_matmul(mat1, mat2):
    """Matrix multiply routed through Addmm so the statistics get recorded."""
    output = Variable(mat1.data.new(mat1.data.size(0), mat2.data.size(1)))
    return Addmm.apply(output, mat1, mat2, 0, 1, True)
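# KFAC bookkeeping, in the notation of the loop in main() below: for a layer
# y = W @ a, the per-example gradient w.r.t. W is g @ a^T (g = backprop at y).
# Stacking examples as columns of A and B, KFAC whitens the gradient on both
# sides:
#   (B2 B2^T/n + lambda*I)^-1 @ (B A^T/n) @ (A A^T/n + lambda*I)^-1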
def regularized_inverse(mat):
    """Damped inverse (mat + lambda_*I)^-1, computed on CPU via scipy."""
    assert mat.shape[0] == mat.shape[1]
    ii = torch.eye(mat.shape[0])
    if args.cuda:
        ii = ii.cuda()
    regmat = mat + lambda_ * ii
    result = torch.from_numpy(scipy.linalg.inv(regmat.cpu().numpy()))
    if args.cuda:
        result = result.cuda()
    return result
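# Note: lambda_ plays the role of the damping term standard in KFAC
# implementations; it keeps the factor inverses well-conditioned and bounds
# the step taken along low-curvature directions.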
def t(mat): return torch.transpose(mat, 0, 1)

def copy_list(l):
    new_list = []
    for item in l:
        new_list.append(item.clone())
    return new_list
def main():
    global forward_list, backward_list, DO_PRINT

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    class Net(nn.Module):
        """Two-layer sigmoid autoencoder: 784 -> 196 -> 784."""

        def __init__(self):
            super(Net, self).__init__()
            W0 = u.ng_init(196, 784)
            W1 = u.ng_init(784, 196)  # fix non-contiguous input
            self.W0 = nn.Parameter(torch.from_numpy(W0))
            self.W1 = nn.Parameter(torch.from_numpy(W1))

        def forward(self, input):
            x = input.view(784, -1)  # columns are examples
            x = nonlin(my_matmul(self.W0, x))
            x = nonlin(my_matmul(self.W1, x))
            return x.view_as(input)
    model = Net()
    if args.cuda:
        model.cuda()

    data0 = u.get_mnist_images()
    data0 = data0[:, :dsize].astype(dtype)  # first dsize images, as columns
    data = Variable(torch.from_numpy(np.copy(data0)).contiguous())
    if args.cuda:
        data = data.cuda()

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    losses = []
    for step in range(10):
        # First pass: reconstruction loss on the real data; its backward pass
        # fills forward_list/backward_list with the per-layer A's and B's.
        optimizer.zero_grad()
        forward_list = []
        backward_list = []
        output = model(data)
        err = output - data
        loss = torch.sum(err * err) / 2 / dsize
        loss.backward(retain_graph=True)
        loss0 = loss.data[0]

        A = forward_list[:]      # layer inputs, in forward order
        B = backward_list[::-1]  # backprops, reversed into forward order
        forward_list = []
        backward_list = []

        # Second pass: same network, but against noisy synthetic targets.
        noise = torch.from_numpy(np.random.randn(*data.data.shape).astype(dtype))
        if args.cuda:
            noise = noise.cuda()
        synthetic_data = Variable(output.data + noise)
        if args.cuda:
            synthetic_data = synthetic_data.cuda()
        err2 = output - synthetic_data
        loss2 = torch.sum(err2 * err2) / 2 / dsize
        optimizer.zero_grad()
        backward_list = []
        loss2.backward()
        B2 = backward_list[::-1]
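        # B2 comes from targets sampled around the model's own output
        # (output + unit Gaussian noise). This is the usual KFAC recipe:
        # the covariance of these sampled-target backprops estimates the
        # Fisher's backward factor, whereas B from the real targets would
        # give the empirical Fisher instead.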
        # Compute the whitened (KFAC-preconditioned) gradient for each layer:
        # pre_dW = (covB2 + lambda*I)^-1 @ (B A^T/dsize) @ (covA + lambda*I)^-1
        pre_dW = []
        n = len(A)
        assert len(B) == n
        assert len(B2) == n
        for i in range(n):
            covA = A[i] @ t(A[i]) / dsize     # activation covariance
            covB2 = B2[i] @ t(B2[i]) / dsize  # sampled-target backprop covariance
            covB = B[i] @ t(B[i]) / dsize     # real-target covariance (unused)
            covA_inv = regularized_inverse(covA)
            whitened_A = covA_inv @ A[i]
            whitened_B = regularized_inverse(covB2.data) @ B[i].data
            pre_dW.append(whitened_B @ t(whitened_A) / dsize)

        # Manual parameter update with the preconditioned gradients.
        params = list(model.parameters())
        assert len(params) == len(pre_dW)
        for i in range(len(params)):
            params[i].data -= lr * pre_dW[i]

        print("Step %3d loss %10.9f" % (step, loss0))
        u.record_time()

    target = 2.360062122
    u.summarize_time()
    assert abs(loss0 - target) < 1e-9, abs(loss0 - target)
if __name__ == '__main__':
    main()
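For reference, here is a minimal self-contained sketch of the per-layer preconditioning step above, written in plain NumPy so it runs without the gist's util/common_gd helpers. The shapes, the seed, and the helper name kfac_precondition are illustrative choices, not part of the original gist.

# Self-contained illustration of one KFAC preconditioning step (NumPy).
import numpy as np

lambda_ = 3e-3  # same damping value as the gist

def regularized_inverse(mat):
    # (mat + lambda*I)^-1, mirroring the gist's regularized_inverse
    return np.linalg.inv(mat + lambda_ * np.eye(mat.shape[0]))

def kfac_precondition(A, B):
    """Given layer inputs A (n_in x batch) and backprops B (n_out x batch),
    return the KFAC-preconditioned gradient for W (n_out x n_in)."""
    batch = A.shape[1]
    covA = A @ A.T / batch   # forward (activation) factor
    covB = B @ B.T / batch   # backward (gradient) factor
    grad = B @ A.T / batch   # plain gradient of the loss w.r.t. W
    return regularized_inverse(covB) @ grad @ regularized_inverse(covA)

np.random.seed(0)
A = np.random.randn(4, 100)  # hypothetical activations: 4 units, 100 examples
B = np.random.randn(3, 100)  # hypothetical backprops: 3 units, 100 examples
pre_dW = kfac_precondition(A, B)
print(pre_dW.shape)          # (3, 4), same shape as W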