Skip to content

Instantly share code, notes, and snippets.

@nishnik
Created November 14, 2017 17:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nishnik/7cade05a568ab3ec738eded9a9112971 to your computer and use it in GitHub Desktop.
Save nishnik/7cade05a568ab3ec738eded9a9112971 to your computer and use it in GitHub Desktop.
Gets the two largest elements from each max-pooling window
import numpy as np
import torch
def top2_max_pool(x, stride=2):
    """Return the largest and second-largest value of every pooling window.

    Each non-overlapping ``stride x stride`` window of every channel is
    reduced to two numbers: its maximum, and the largest value strictly
    smaller than that maximum (the runner-up). If a window has no strictly
    smaller value (all cells equal), the maximum is reported for both.
    Trailing rows/columns that do not fill a whole window are dropped,
    i.e. the kernel size equals the stride and rounding is floor-mode.

    Args:
        x: 3-D tensor indexed as ``(channel, row, column)``.
        stride: window side length and step between windows.

    Returns:
        ``(maxima, second_maxima)``, each of shape
        ``(C, H // stride, W // stride)`` with the same dtype as ``x``.
    """
    channels, height, width = x.shape
    out_h, out_w = height // stride, width // stride
    if out_h == 0 or out_w == 0:
        # Input smaller than one window: nothing to pool (unfold would raise).
        return (x.new_zeros((channels, out_h, out_w)),
                x.new_zeros((channels, out_h, out_w)))

    # (C, out_h, out_w, stride, stride) -> flatten the cells of each window.
    windows = x.unfold(1, stride, stride).unfold(2, stride, stride)
    windows = windows.reshape(channels, out_h, out_w, stride * stride)

    maxima = windows.max(dim=-1, keepdim=True).values
    smaller = windows < maxima
    # Replace the maximum (and any ties) with the window minimum. If any cell
    # is strictly smaller, the minimum is one of them, so the max of what
    # remains is the true runner-up. This works for any dtype (no -inf needed).
    floor = windows.min(dim=-1, keepdim=True).values
    runner_up = torch.where(smaller, windows, floor).max(dim=-1).values

    maxima = maxima.squeeze(-1)
    # Windows whose cells are all equal have no runner-up: report the max.
    second = torch.where(smaller.any(dim=-1), runner_up, maxima)
    return maxima, second


# Demo: a single-channel 3x3 input pooled with 2x2 windows (stride 2).
y = np.array([[1.0, 7.0, 1.0], [3.0, 4.0, 2], [1, 7, 1]])
x = torch.from_numpy(y).reshape(1, 3, 3)
stride = 2
new_var, new_var2 = top2_max_pool(x, stride)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment