@yoshoku
Created February 11, 2021 00:26
Image Recognition with VGG-16 Network in Ruby (ja)
require 'magro'
require 'json'
require 'torch'
require 'torchvision'
# Load the pretrained VGG-16 network.
vgg = TorchVision::Models::VGG16.new
vgg.load_state_dict(Torch.load('vgg16_.pth'))
# Load the image.
img = Magro::IO.imread('A_Golden_Retriever-9_(Barras).JPG')
# Crop the largest centered square out of the image.
height, width, = img.shape
img_size = [height, width].min
y_offset = (height - img_size) / 2
x_offset = (width - img_size) / 2
img = img[y_offset...(y_offset + img_size), x_offset...(x_offset + img_size), true]
# Resize the image to 224x224 pixels.
img = Magro::Transform.resize(img, height: 224, width: 224)
# Scale the pixel values into the range [0, 1].
img = Numo::SFloat.cast(img) / 255.0
# Convert the image to a torch.rb tensor and reorder its axes to [channel, height, width].
img_torch = Torch.from_numo(img).permute(2, 0, 1)
# Normalize each channel with the ImageNet mean and standard deviation.
mean = Torch.tensor([0.485, 0.456, 0.406])
std = Torch.tensor([0.229, 0.224, 0.225])
normalize = TorchVision::Transforms::Normalize.new(mean, std)
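# Normalize is not in-place by default, so keep the returned tensor.
img_torch = normalize.call(img_torch)
# Add a batch dimension so the tensor has shape [1, 3, 224, 224].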
img_torch = img_torch.expand(1, -1, -1, -1)
# Switch the pretrained model to evaluation mode and feed it the preprocessed image.
vgg.eval
out = vgg.forward(img_torch)
# Get the index of the largest value in the final layer's output.
class_idx = out.numo[0, true].max_index
# Print the ImageNet class that corresponds to that index.
imagenet_classes = JSON.load(File.read('imagenet_class_index.json'))
puts "class: #{imagenet_classes[class_idx.to_s].last}"
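The center-crop step in the script above boils down to a small amount of offset arithmetic. The sketch below reproduces that arithmetic in plain Ruby on nested arrays, without Numo or Magro, so it can be checked in isolation. The helper names `center_square_window` and `center_square_crop` are hypothetical and not part of the original gist.

```ruby
# Compute the largest centered square window, mirroring the
# y_offset/x_offset arithmetic in the script (hypothetical helper).
# Returns [row_range, column_range] as exclusive Ranges.
def center_square_window(height, width)
  size = [height, width].min
  y_offset = (height - size) / 2 # integer division floors odd margins
  x_offset = (width - size) / 2
  [y_offset...(y_offset + size), x_offset...(x_offset + size)]
end

# Apply the same window to an image stored as rows of pixels
# (hypothetical helper; the script slices a Numo array instead).
def center_square_crop(image)
  rows, cols = center_square_window(image.size, image.first.size)
  image[rows].map { |row| row[cols] }
end

rows, cols = center_square_window(480, 640)
p rows # => 0...480
p cols # => 80...560
```

For a landscape 640x480 image, 80 columns are trimmed from each side and every row is kept.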
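`permute(2, 0, 1)` turns the image from [height, width, channel] order, as Magro loads it, into the [channel, height, width] order that the network expects. A plain-Ruby sketch of the same axis reordering on nested arrays may make the index shuffle easier to follow. `hwc_to_chw` is a hypothetical helper and only illustrates what happens to the axes; it is not how torch.rb implements `permute`.

```ruby
# Reorder a nested array from [height][width][channel] to
# [channel][height][width], the same axis order permute(2, 0, 1) produces.
def hwc_to_chw(image)
  channels = image.first.first.size
  (0...channels).map do |c|
    # For channel c, keep the full height x width grid of that channel's values.
    image.map { |row| row.map { |pixel| pixel[c] } }
  end
end

hwc = [
  [[1, 10, 100], [2, 20, 200]], # row 0: two RGB pixels
  [[3, 30, 300], [4, 40, 400]]  # row 1: two RGB pixels
]
p hwc_to_chw(hwc)
# => [[[1, 2], [3, 4]], [[10, 20], [30, 40]], [[100, 200], [300, 400]]]
```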
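Taken together, the division by 255.0 and the `Normalize` transform apply `(value / 255.0 - mean[c]) / std[c]` to each channel `c`, using the ImageNet statistics given in the script. The following plain-Ruby sketch computes this for a single RGB pixel. `normalize_pixel` is a hypothetical helper for illustration only.

```ruby
# ImageNet channel statistics, the same values the script passes to Normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406].freeze
IMAGENET_STD = [0.229, 0.224, 0.225].freeze

# Normalize one RGB pixel given as 0..255 integers: scale it to [0, 1]
# first, then standardize each channel (hypothetical helper).
def normalize_pixel(rgb)
  rgb.each_with_index.map do |value, c|
    (value / 255.0 - IMAGENET_MEAN[c]) / IMAGENET_STD[c]
  end
end

p normalize_pixel([255, 255, 255]).map { |v| v.round(4) }
# => [2.2489, 2.4286, 2.64]
```

A pure white pixel therefore lands a little over two standard deviations above the ImageNet mean in every channel.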
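The last two steps of the script take the index of the largest output and look it up in `imagenet_class_index.json`. In that file, the keys are string indices and the values are `[WordNet ID, label]` pairs, which is why the script calls `to_s` and `.last`. The plain-Ruby sketch below does the same thing with an inline three-entry excerpt in that layout, so it does not need the real file. `argmax` and `label_for` are hypothetical helpers, and `argmax` plays the role of Numo's `max_index` here.

```ruby
require 'json'

# Illustrative excerpt in the imagenet_class_index.json layout:
# string index => [WordNet ID, human-readable label].
CLASS_INDEX_JSON = <<~TEXT
  {"0": ["n01440764", "tench"],
   "1": ["n01443537", "goldfish"],
   "2": ["n01484850", "great_white_shark"]}
TEXT

# Index of the largest score; on ties, max_by keeps the first maximum.
def argmax(scores)
  scores.each_with_index.max_by { |score, _i| score }.last
end

# JSON object keys are strings, hence to_s; .last picks the label.
def label_for(scores, classes)
  classes[argmax(scores).to_s].last
end

classes = JSON.parse(CLASS_INDEX_JSON)
puts label_for([0.3, 2.1, -1.0], classes) # prints "goldfish"
```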