Skip to content

Instantly share code, notes, and snippets.

@yoshoku
Created Feb 12, 2021
Embed
What would you like to do?
Content-image retrieval with torch.rb and annoy.rb (ja)
require 'magro'
require 'torch'
require 'torchvision'
require 'annoy'
# @param src [Numo::NArray] (shape: [height, width, channel]) 入力画像
def preprocessing(src)
# 画像の中心を正方形に切り出す.
height, width, = src.shape
img_size = [height, width].min
y_offset = (height - img_size) / 2
x_offset = (width - img_size) / 2
img = src[y_offset...(y_offset + img_size), x_offset...(x_offset + img_size), true].dup
# 画像を224x224の大きさにする.
img = Magro::Transform.resize(img, height: 224, width: 224)
# 画素値を[0, 1]の範囲に正規化する.
img = Numo::SFloat.cast(img) / 255.0
# 平均と標準偏差を正規化する.
img -= Numo::SFloat[0.485, 0.456, 0.406]
img /= Numo::SFloat[0.229, 0.224, 0.225]
# [チャンネル, 高さ, 幅]の順に入れ替える.
img.transpose(2, 0, 1).dup
end
# 画像データセットを読み込み前処理を施す.
filelist = Dir.glob('icpr2004.imgset/groundtruth/*/*.jpg')
images = Numo::SFloat.cast filelist.map { |filename| preprocessing(Magro::IO.imread(filename)) }
images_tensor = Torch.from_numo(images)
# 学習済みネットワークを読み込み, 最終層を外す.
vgg = TorchVision::Models::VGG16.new
vgg.load_state_dict(Torch.load('vgg16_.pth'))
vgg.classifier.instance_variable_get(:@modules).delete("6")
vgg.eval
# 特徴ベクトルを得る.
features = vgg.forward(images_tensor).to_a
# 検索インデックスに特徴ベクトルを登録する.
index = Annoy::AnnoyIndex.new(n_features: 4096, metric: 'angular')
features.each_with_index { |vec, idx| index.add_item(idx, vec) }
# 検索インデックスを作成する. 木の数は10本.
index.build(10)
# 添字55の特徴ベクトルをクエリに見立てて、5-近傍を検索する.
query = features[55]
retrieved_ids = index.get_nns_by_vector(query, 5)
# 検索結果を出力する.
pp retrieved_ids.map { |i| filelist[i] }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment