Created
February 12, 2021 16:21
Content-based image retrieval with torch.rb and annoy.rb (ja)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'magro' | |
require 'torch' | |
require 'torchvision' | |
require 'annoy' | |
# Preprocess an image for an ImageNet-pretrained torchvision model (VGG16):
# center-crop to a square, resize, scale pixel values to [0, 1], standardize
# each channel with the ImageNet mean/std, and reorder axes to channel-first.
#
# Grayscale ([height, width] or 1-channel) inputs are replicated to three
# channels, and a fourth (presumably alpha) channel is dropped, so that the
# 3-element mean/std below always broadcasts correctly.
#
# @param src [Numo::NArray] (shape: [height, width, channel] or [height, width]) input image;
#   assumes 8-bit pixel values as returned by Magro::IO.imread (hence the division by 255) — TODO confirm for other sources.
# @param size [Integer] side length in pixels of the square output (default 224, the VGG16 input size).
# @return [Numo::SFloat] (shape: [3, size, size]) preprocessed image.
# @raise [ArgumentError] if src is not 2-D/3-D or has a channel count other than 1, 3 or 4.
def preprocessing(src, size: 224)
  unless [2, 3].include?(src.ndim)
    raise ArgumentError, "expected a 2-D or 3-D image array, got shape #{src.shape.inspect}"
  end

  # Give grayscale images an explicit channel axis so they index like color images.
  src = src.expand_dims(2) if src.ndim == 2
  height, width, n_channels = src.shape

  # Cut out the largest centered square.
  side = [height, width].min
  y_offset = (height - side) / 2
  x_offset = (width - side) / 2
  img = src[y_offset...(y_offset + side), x_offset...(x_offset + side), true].dup

  # Normalize to exactly three channels.
  case n_channels
  when 1 then img = Numo::NArray.dstack([img, img, img])
  when 3 then nil
  when 4 then img = img[true, true, 0...3].dup
  else raise ArgumentError, "unsupported number of channels: #{n_channels}"
  end

  # Resize to size x size.
  img = Magro::Transform.resize(img, height: size, width: size)
  # Scale pixel values into [0, 1].
  img = Numo::SFloat.cast(img) / 255.0
  # Standardize each channel with the ImageNet mean and standard deviation.
  img -= Numo::SFloat[0.485, 0.456, 0.406]
  img /= Numo::SFloat[0.229, 0.224, 0.225]
  # Reorder to [channel, height, width] as expected by torch.rb models.
  img.transpose(2, 0, 1).dup
end
# Collect the image dataset. Sorting makes the mapping between Annoy item ids
# and file paths deterministic regardless of filesystem ordering.
filelist = Dir.glob('icpr2004.imgset/groundtruth/*/*.jpg').sort
abort('no images found under icpr2004.imgset/groundtruth/') if filelist.empty?

# Load the pretrained network and remove the last classifier layer (index "6"),
# so the forward pass returns the preceding layer's activations as features.
vgg = TorchVision::Models::VGG16.new
vgg.load_state_dict(Torch.load('vgg16_.pth'))
vgg.classifier.instance_variable_get(:@modules).delete("6")
vgg.eval

# Extract feature vectors. Images are preprocessed and fed in fixed-size
# batches to bound peak memory, and no_grad avoids recording an autograd
# graph that inference never uses.
batch_size = 32
features = []
Torch.no_grad do
  filelist.each_slice(batch_size) do |batch_files|
    batch = Numo::SFloat.cast(batch_files.map { |filename| preprocessing(Magro::IO.imread(filename)) })
    features.concat(vgg.forward(Torch.from_numo(batch)).to_a)
  end
end

# Register the feature vectors in the search index; the dimensionality is
# taken from the extracted features rather than hard-coded.
index = Annoy::AnnoyIndex.new(n_features: features.first.size, metric: 'angular')
features.each_with_index { |vec, idx| index.add_item(idx, vec) }

# Build the search index with 10 trees.
n_trees = 10
index.build(n_trees)

# Use the feature vector at index 55 as the query and search for its 5 nearest neighbors.
# fetch raises a clear IndexError if the dataset has fewer images than expected.
query_id = 55
n_neighbors = 5
query = features.fetch(query_id)
retrieved_ids = index.get_nns_by_vector(query, n_neighbors)

# Print the file paths of the retrieved images.
pp retrieved_ids.map { |i| filelist[i] }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment