Skip to content

Instantly share code, notes, and snippets.

@ronekko
Last active June 22, 2018 07:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ronekko/5ff82504d2ffb770a1b36936cd732d9f to your computer and use it in GitHub Desktop.
Save ronekko/5ff82504d2ffb770a1b36936cd732d9f to your computer and use it in GitHub Desktop.
Rotation of images in chainer
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 22 15:40:40 2018
@author: sakurai
"""
import matplotlib.pyplot as plt
import numpy as np
import chainer
import chainer.functions as F
from chainer import cuda, Variable
def deg2rad(deg):
    """Convert angles from degrees to radians.

    Args:
        deg: Angle(s) in degrees. The value is handed to
            ``cuda.get_array_module`` and then multiplied by a scalar, so
            any object supporting both works (the demo script passes a
            ``Variable``).

    Returns:
        ``deg`` scaled by ``pi / 180``; the result type is whatever
        ``scalar * deg`` yields for the given input.
    """
    array_module = cuda.get_array_module(deg)
    radians_per_degree = array_module.pi / 180.0
    return radians_per_degree * deg
def rotate_image(x, angle_radian):
    """Rotate every image of a batch by its own angle.

    For each angle a (2, 3) affine matrix ``[[cos, -sin, 0], [sin, cos, 0]]``
    is built and passed to ``F.spatial_transformer_grid``; the resulting
    sampling grid is then used by ``F.spatial_transformer_sampler`` to
    resample ``x``.

    Args:
        x (Variable):
            A batch of images of shape (B, C, H, W).
        angle_radian (Variable):
            A batch of rotation angles of shape (B,), in radians.

    Returns:
        A (B, C, H, W) shaped variable of the rotated images.
    """
    xp = cuda.get_array_module(x)
    batch_size, _, height, width = x.shape

    # Build the batch of affine matrices (rotation only, no translation).
    cos = F.cos(angle_radian)
    sin = F.sin(angle_radian)
    # The translation column must have the same dtype as cos/sin, otherwise
    # stacking them together fails for non-float32 angles.
    zero = xp.zeros(batch_size, dtype=cos.dtype)
    theta0 = F.stack((cos, -sin, zero), 1)
    theta1 = F.stack((sin, cos, zero), 1)
    theta = F.stack((theta0, theta1), 1)  # (B, 2, 3), batch of (2, 3) matrices

    # Conceptually, the grid creation below is equivalent to (sketch):
    #   xs = np.linspace(-1, 1, width)
    #   ys = np.linspace(-1, 1, height)
    #   xs, ys = np.meshgrid(xs, ys)
    #   base = np.stack((xs, ys), 0)                 # (2, H, W)
    #   base = np.repeat(base[None], batch_size, 0)  # (B, 2, H, W)
    # followed by transforming the points of each batch element with its
    # own `theta`.
    # NOTE(review): both axes are normalized to [-1, 1] independently, so for
    # H != W the rotation presumably acts in that stretched space -- verify
    # before relying on this for non-square images.
    grid = F.spatial_transformer_grid(theta, (height, width))

    # Sample the input at the transformed grid points.
    rotated_image = F.spatial_transformer_sampler(x, grid)
    return rotated_image
if __name__ == '__main__':
    # Demo: rotate the first few MNIST training digits, one angle per image,
    # and show each original next to its rotated version.
    angle_degree = [0, 30, 45, 60, 120]
    use_gpu = False
    batch_size = len(angle_degree)
    xp = cuda.cupy if use_gpu else np
    device = 0 if use_gpu else -1  # -1 selects the CPU

    # ndim=3 yields images with an explicit channel axis.
    ds, _ = chainer.datasets.get_mnist(ndim=3)
    image, _ = chainer.dataset.concat_examples(ds[:batch_size], device)

    # rotate image
    angle_degree = Variable(xp.asarray(angle_degree, dtype=np.float32))
    angle_radian = deg2rad(angle_degree)
    rotated_image = rotate_image(image, angle_radian)

    for deg, img, img2 in zip(cuda.to_cpu(angle_degree.array),
                              cuda.to_cpu(image),
                              cuda.to_cpu(rotated_image.array)):
        # fignum=0 makes matshow draw into the current subplot instead of
        # opening a new figure.
        plt.subplot(1, 2, 1)
        plt.matshow(img[0], cmap=plt.cm.gray, fignum=0)
        plt.axis('off')
        plt.title('Original')

        plt.subplot(1, 2, 2)
        plt.matshow(img2[0], cmap=plt.cm.gray, fignum=0)
        plt.axis('off')
        plt.title('Rotated ({} [deg])'.format(deg))
        # One figure per (original, rotated) pair.
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment