Skip to content

Instantly share code, notes, and snippets.

@hccho2
Last active January 25, 2021 13:41
Show Gist options
  • Save hccho2/ebf352d843897e351110bdeaa3bb00a5 to your computer and use it in GitHub Desktop.
Save hccho2/ebf352d843897e351110bdeaa3bb00a5 to your computer and use it in GitHub Desktop.
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
# Input size and transposed-convolution hyper-parameters.
H, W = 4, 3
pad = 1
stride = 2
kernel_size = 3

# Deterministic (N, C, H, W) input holding 0 .. H*W-1. Despite the printed
# label, only the kernel below is random.
image = torch.arange(H * W, dtype=torch.float32).reshape(1, 1, H, W)
print("random image", image)

weight = torch.randn(1, 1, kernel_size, kernel_size)

# Reference result computed directly by PyTorch.
out = F.conv_transpose2d(image, weight, stride=stride, padding=pad)

# Expected spatial size for dilation 1 and no output padding.
out_h = (H - 1) * stride - 2 * pad + kernel_size
out_w = (W - 1) * stride - 2 * pad + kernel_size
print(f'output size: {out_h} x {out_w}')
print(f'weight: {weight}')
print(f'output: {out}')

#####################################################
# The same transposed convolution rebuilt from an ordinary convolution:
#   1. insert (stride - 1) zeros between neighbouring input pixels,
#   2. zero-pad every border by (kernel_size - 1 - pad),
#   3. run a stride-1 convolution with the kernel flipped along both
#      spatial axes.
#####################################################
dilated_h = (H - 1) * stride + 1
dilated_w = (W - 1) * stride + 1
image_stride = torch.zeros(1, 1, dilated_h, dilated_w, dtype=torch.float32)
image_stride[:, :, ::stride, ::stride] = image
print('new image', image_stride)

border = kernel_size - 1 - pad
image_padded = F.pad(image_stride, (border,) * 4)
print("padded image", image_padded)

flipped_weight = weight.flip([2, 3])
out2 = F.conv2d(image_padded, flipped_weight, stride=1, padding=0)
print(f'manual output: {out2}')
@hccho2
Copy link
Author

hccho2 commented Jan 25, 2021

def unfold_kernel_stride(kernel, input_size, stride=1):
    """Unfold a 2-D kernel into the matrix form of a strided, unpadded convolution.

    The returned matrix ``M`` has shape ``(i_h * i_w, o_h * o_w)``.  For an
    image ``x`` of shape ``input_size``, ``x.reshape(-1) @ M`` is the
    row-major flattened result of sliding ``kernel`` over ``x`` (no kernel
    flip, no padding) with the given stride.  Multiplying ``M`` by a
    flattened ``(o_h, o_w)`` map instead scatters it back onto the input
    grid, i.e. performs the matching transposed convolution.

    Args:
        kernel: 2-D array-like of shape ``(k_h, k_w)``.
        input_size: ``(i_h, i_w)`` spatial size of the input image.
        stride: Step between neighbouring kernel placements, applied to both
            axes.  Must be at least 1.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(i_h * i_w, o_h * o_w)`` with
        ``o_h = (i_h - k_h) // stride + 1`` and
        ``o_w = (i_w - k_w) // stride + 1``.

    Raises:
        ValueError: If ``stride`` is smaller than 1 or the kernel does not
            fit inside the input.
    """
    kernel = np.asarray(kernel)
    k_h, k_w = kernel.shape
    i_h, i_w = input_size

    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if k_h > i_h or k_w > i_w:
        raise ValueError(
            f"kernel of shape {(k_h, k_w)} does not fit in input of size {(i_h, i_w)}"
        )

    o_h = (i_h - k_h) // stride + 1
    o_w = (i_w - k_w) // stride + 1

    # W_conv[oy, ox] is the input-sized mask of weights that produces output
    # pixel (oy, ox): the kernel placed with its top-left corner at
    # (oy * stride, ox * stride) and zeros everywhere else.
    W_conv = np.zeros((o_h, o_w, i_h, i_w))
    for oy in range(o_h):
        top = oy * stride
        for ox in range(o_w):
            left = ox * stride
            W_conv[oy, ox, top:top + k_h, left:left + k_w] = kernel

    # Flatten to (outputs, inputs), then transpose so that each column holds
    # the weights of one output pixel.
    return W_conv.reshape(o_h * o_w, i_h * i_w).T

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment