Created
June 28, 2012 21:20
-
-
Save npinto/3013981 to your computer and use it in GitHub Desktop.
Gradients of Rolling View
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| from skimage.util.shape import view_as_windows, view_as_blocks | |
| from sthor.util import filter_pad2d | |
def f_g(x, window_shape=(2, 2)):
    """Return the rolling-window squared loss of ``x`` and its gradient.

    The loss is ``0.5 * sum(w ** 2)`` taken over every entry ``w`` of every
    dense (stride-1) sliding window of shape ``window_shape`` over ``x``.
    An element of ``x`` that is covered by several overlapping windows
    contributes once per window, so its gradient is the element value
    multiplied by the number of windows covering it.

    The previous implementation folded the per-window gradient back with a
    hard-coded ``reshape(4, 4)`` and a fixed pad, which only worked for a
    3x3 input with 2x2 windows.  This version uses an overlap-add that works
    for any 2-D input at least as large as the window.

    Parameters
    ----------
    x : array_like, 2-D
        Input array.
    window_shape : tuple of two ints, optional
        Height and width of the sliding window.  Defaults to ``(2, 2)``.

    Returns
    -------
    loss : scalar
        ``0.5`` times the sum of squares over all window entries.
    g : ndarray
        Gradient of ``loss`` with respect to ``x``; same shape as ``x``.

    Raises
    ------
    ValueError
        If ``x`` is not 2-D, or the window does not fit inside ``x``.
    """
    x = np.asarray(x)
    if x.ndim != 2:
        raise ValueError("x must be 2-D, got %d dimension(s)" % x.ndim)

    wh, ww = window_shape
    # Number of window positions along each axis (stride 1, no padding).
    nh = x.shape[0] - wh + 1
    nw = x.shape[1] - ww + 1
    if wh < 1 or ww < 1 or nh < 1 or nw < 1:
        raise ValueError(
            "window_shape %r does not fit input of shape %r"
            % (tuple(window_shape), x.shape)
        )

    # For a fixed in-window offset (a, b), the slice x[a:a+nh, b:b+nw] holds
    # entry (a, b) of every window.  Stacking these slices gives all window
    # entries, grouped by offset.
    offsets = [(a, b) for a in range(wh) for b in range(ww)]
    windows = np.array([x[a:a + nh, b:b + nw] for a, b in offsets])

    loss = 0.5 * (windows ** 2.).sum()

    # d(loss)/d(window entry) is the entry itself; overlap-add each offset's
    # slab back onto the positions of x it was read from.
    g = np.zeros_like(x)
    for (a, b), slab in zip(offsets, windows):
        g[a:a + nh, b:b + nw] += slab
    return loss, g
# Test input: a float32 ramp 0, 1, 2, ... laid out row-major in x_shape.
x_shape = (3, 3)
x = np.arange(np.prod(x_shape), dtype='f').reshape(x_shape)
def v(i):
    """Return a one-hot array shaped like the module-level ``x``.

    The result is all zeros except for a 1 at flat (row-major) index ``i``.
    """
    one_hot = np.zeros(x.shape)
    one_hot.flat[i] = 1
    return one_hot
# Check the analytic gradient against a forward finite-difference estimate.
e = 1e-9
# The unperturbed loss does not depend on i, so evaluate it once.
base_loss = f_g(x)[0]
gt = np.array([(f_g(x + e * v(i))[0] - base_loss) / e for i in range(x.size)])
gt = gt.reshape(x.shape)
gv = f_g(x)[1]
# Single-argument print(...) behaves the same on Python 2 and Python 3.
print(gt)
print(gv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Currently wrong: only works for the 3x3 => 2x2 case.