--[[ Copyright 2004-present Facebook. All Rights Reserved.

Loads the dynamic library involved in SWIG calls for Faiss, plus some
code to convert input/output arguments.

The standard way of importing this module is

    swigfaiss = require 'faiss_swig'

If you require 'swigfaiss' directly, you get the indexing structures
without the additions that make them easy to use from Lua.
--]]
require 'torch'
local ffi = require 'ffi'
local swigfaiss
g_force_faiss_Makefile = g_force_faiss_Makefile or
    os.getenv ('FORCE_FAISS_MAKEFILE')
--[[
Two global variables tell us what we should load:

- g_swigfaiss_use_gpu: if true, load swigfaiss_gpu. The GPU version
  is a superset of the normal one, but of course has additional
  dependencies.
- g_force_faiss_Makefile: if true, load the version compiled with
  the Makefile rather than the version compiled by the Facebook
  build system.
--]]
local module_name
if g_swigfaiss_use_gpu then
    module_name = 'swigfaiss_gpu'
else
    module_name = 'swigfaiss'
end

if g_force_faiss_Makefile then
    local my_path = '/fbsource/fbcode/deeplearning/projects/faiss'
    local soname = os.getenv ("HOME") .. my_path .. '/lua/' .. module_name .. '.so'
    local luaopen_swigfaiss = package.loadlib (
        soname, 'luaopen_' .. module_name)
    assert (luaopen_swigfaiss, 'could not load .so, check link flags')
    swigfaiss = luaopen_swigfaiss ()
else
    -- works when compiled in fbcode
    if g_swigfaiss_use_gpu then
        swigfaiss = require 'swigfaiss_gpu'
    else
        swigfaiss = require 'swigfaiss'
    end
end
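--[[ Usage sketch (not part of the original file): to get the GPU
build, set the flag before the first require of this module, e.g.

    g_swigfaiss_use_gpu = true
    local swigfaiss = require 'faiss_swig'
--]]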
--[[
The train, add and search methods should accept Lua tensors as
arguments. To do this we:

- rename the swig-wrapped version of each function with a _c suffix
- allocate and convert the arguments to C
- call the _c version

The code below is relatively primitive. It assumes that all classes
whose name starts with Index or GpuIndex are the ones to be processed.
--]]
local function replace_c_method (class, name, f, allow_fail)
    local mt_index = getmetatable(class)['.instance']['.fn']
    if allow_fail and not mt_index[name] then
        return
    end
    -- move the C function aside
    mt_index[name .. '_c'] = mt_index[name]
    -- replace it with a Lua-friendly version
    mt_index[name] = f
end
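--[[ Illustration (class and function names hypothetical): after

    replace_c_method (SomeIndex, 'search', lua_search)

instances of SomeIndex expose both `search` (the Lua-friendly wrapper)
and `search_c` (the raw SWIG call that takes counts and C pointers).
--]]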
local function replace_search_method (class)
    local f = function (index, xq, k, zero_based)
        assert (xq:size (2) == index.d,
                'vectors have incorrect dimension for search')
        local nq = xq:size (1)
        local D = torch.FloatTensor (nq, k)
        local I = torch.LongTensor (nq, k)
        index:search_c (nq, swigfaiss.float_ptr (xq), k,
                        swigfaiss.float_ptr (D),
                        swigfaiss.long_ptr (I))
        -- shift ids to Lua's 1-based indexing unless 0-based was requested
        if zero_based == nil then
            I:add (1)
        end
        return D, I
    end
    replace_c_method (class, 'search', f)
end
local function replace_train_method (class)
    local f = function (index, x)
        assert (x:size (2) == index.d,
                'vectors have incorrect dimension for train')
        index:train_c (x:size (1), swigfaiss.float_ptr (x))
    end
    replace_c_method (class, 'train', f)
end

local function replace_add_method (class)
    local f = function (index, x)
        assert (x:size (2) == index.d,
                'vectors have incorrect dimension for add')
        index:add_c (x:size (1), swigfaiss.float_ptr (x))
    end
    replace_c_method (class, 'add', f)
end

local function replace_add_with_ids_method (class, allow_fail)
    local f = function (index, x, ids)
        assert (x:size (2) == index.d)
        local n = x:size(1)
        if ids then
            assert (ids:size(1) == n)
            index:add_with_ids_c (n, swigfaiss.float_ptr (x),
                                  swigfaiss.long_ptr (ids))
        else
            index:add_with_ids_c (n, swigfaiss.float_ptr (x), nil)
        end
    end
    replace_c_method (class, 'add_with_ids', f, allow_fail)
end
-- Make a few setters/getters that access Index variables directly.
-- We may move to more setters/getters in C++ as well.
local function add_set_get (class)
    local mt_index = getmetatable(class)['.instance']['.fn']
    mt_index['get_n'] = function (index) return index.ntotal end
    mt_index['set_verbose'] = function (index, verbose)
        index.verbose = verbose
    end
end

local function add_set_nprobe (class)
    local mt_index = getmetatable(class)['.instance']['.fn']
    mt_index['set_nprobe'] = function (index, nprobe)
        index.nprobe = nprobe
    end
end
-- go over all fields, pick out the Index classes
for name, class in pairs (swigfaiss) do
    -- require('fb.debugger').enter()
    if ((name:match ('Index.*') or name:match ('GpuIndex.*')) and
        getmetatable (class)) then
        replace_train_method (class)
        replace_add_method (class)
        replace_search_method (class)
        replace_add_with_ids_method (class, true)
        add_set_get (class)
        if name:match ('IndexIVF.*') then
            add_set_nprobe (class)
        end
    end
end
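--[[ Usage sketch (not in the original file; assumes a 64-d
IndexFlatL2, a standard Faiss index, and random data):

    local index = swigfaiss.IndexFlatL2 (64)
    local xb = torch.FloatTensor (1000, 64):uniform ()
    index:add (xb)          -- plain Lua tensor, no pointer conversions
    local D, I = index:search (xb:narrow (1, 1, 5), 10)
    -- D: 5x10 FloatTensor of distances
    -- I: 5x10 LongTensor of 1-based neighbor ids
--]]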
-- A few additional method replacements
local function replace_range_search_method (class)
    local f = function (index, xq, threshold)
        local nq = xq:size (1)
        assert (xq:size (2) == index.d)
        -- TODO use an object that allocates tensors directly
        local rc = swigfaiss.RangeSearchResult (nq)
        -- the actual range search
        index:range_search_c (nq, swigfaiss.float_ptr (xq), threshold, rc)
        -- copy the results to Tensors
        local lims = torch.LongTensor (nq + 1)
        swigfaiss.memcpy (swigfaiss.long_ptr (lims), rc.lims, (nq + 1) * 8)
        local nres = lims[nq + 1]
        local D = torch.FloatTensor (nres)
        local I = torch.LongTensor (nres)
        if nres > 0 then
            swigfaiss.memcpy (swigfaiss.float_ptr (D), rc.distances, nres * 4)
            swigfaiss.memcpy (swigfaiss.uint64_t_ptr (I), rc.labels, nres * 8)
            I:add (1) -- Lua indexing...
            lims:add (1)
        end
        return lims, D, I
    end
    replace_c_method (class, 'range_search', f)
end
replace_range_search_method (swigfaiss.IndexIVFFlat)
replace_range_search_method (swigfaiss.IndexFlat)
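--[[ Usage sketch (hypothetical data): after the +1 shift above, the
results for query i (1-based) are the entries lims[i] .. lims[i+1]-1
of D and I, e.g.

    local lims, D, I = index:range_search (xq, 0.5)
    local n1 = lims[2] - lims[1]   -- number of results for query 1
--]]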
local function replace_vector_transform_apply (class)
    local f = function (vt, x)
        assert (x:size (2) == vt.d_in)
        local y = torch.FloatTensor (x:size(1), vt.d_out)
        vt:apply_noalloc (x:size(1), swigfaiss.float_ptr (x),
                          swigfaiss.float_ptr (y))
        return y
    end
    local f2 = function (vt, x)
        assert (x:size (2) == vt.d_in, 'incorrect train data dim')
        vt:train_c (x:size(1), swigfaiss.float_ptr (x))
    end
    local f3 = function (vt, y)
        assert (y:size (2) == vt.d_out)
        local x = torch.FloatTensor (y:size(1), vt.d_in)
        vt:reverse_transform_c (y:size(1), swigfaiss.float_ptr (y),
                                swigfaiss.float_ptr (x))
        return x
    end
    replace_c_method (class, 'apply', f, true)
    replace_c_method (class, 'train', f2, true)
    replace_c_method (class, 'reverse_transform', f3, true)
end
replace_vector_transform_apply (swigfaiss.VectorTransform)
replace_vector_transform_apply (swigfaiss.LinearTransform)
replace_vector_transform_apply (swigfaiss.ExternalTransform)
replace_vector_transform_apply (swigfaiss.RemapDimensionsTransform)
replace_vector_transform_apply (swigfaiss.OPQMatrix)
replace_vector_transform_apply (swigfaiss.PCAMatrix)
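--[[ Usage sketch (PCAMatrix (d_in, d_out) is the standard Faiss
constructor; the training data xt is hypothetical):

    local pca = swigfaiss.PCAMatrix (128, 32)
    pca:train (xt)                        -- xt: n x 128 FloatTensor
    local y = pca:apply (xt)              -- n x 32 FloatTensor
    local xr = pca:reverse_transform (y)  -- back-projection, n x 128
--]]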
local function replace_encode (class)
    local f = function (codec, x)
        assert (x:size (2) == codec.d)
        local y = torch.ByteTensor (x:size(1), codec.code_size)
        codec:encode_c (x:size(1), swigfaiss.float_ptr (x),
                        swigfaiss.uint8_t_ptr (y))
        return y
    end
    local f2 = function (vt, x)
        assert (x:size (2) == vt.d)
        vt:train_c (x:size(1), swigfaiss.float_ptr (x))
    end
    replace_c_method (class, 'encode', f, true)
    replace_c_method (class, 'train', f2, true)
end
replace_encode (swigfaiss.BinaryCode)
local function replace_codec_functions (class)
    local f = function (codec, x)
        assert (x:size (2) == codec.d)
        codec:train_c (x:size(1), swigfaiss.float_ptr (x))
    end
    local f2 = function (codec, x)
        assert (x:size (2) == codec.d)
        local y = torch.ByteTensor (x:size(1), codec.code_size)
        codec:compute_codes_c (swigfaiss.float_ptr (x),
                               swigfaiss.uint8_t_ptr (y),
                               x:size(1))
        return y
    end
    local f3 = function (codec, x)
        assert (x:size (2) == codec.code_size)
        local y = torch.FloatTensor (x:size(1), codec.d)
        codec:decode_c (swigfaiss.uint8_t_ptr (x),
                        swigfaiss.float_ptr (y),
                        x:size(1))
        return y
    end
    replace_c_method (class, 'train', f, true)
    replace_c_method (class, 'compute_codes', f2, true)
    replace_c_method (class, 'decode', f3, true)
end
replace_codec_functions (swigfaiss.ProductQuantizer)
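--[[ Usage sketch (ProductQuantizer (d, M, nbits) is the standard
Faiss signature; data names hypothetical):

    local pq = swigfaiss.ProductQuantizer (64, 8, 8)
    pq:train (xt)                          -- xt: n x 64 FloatTensor
    local codes = pq:compute_codes (xb)    -- n x 8 ByteTensor
    local xrec = pq:decode (codes)         -- n x 64 FloatTensor, lossy
--]]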
-- Runs a search asynchronously; call join () to wait for the results.
local AsyncIndexSearch = torch.class (module_name .. '.AsyncIndexSearch')

function AsyncIndexSearch:__init (index, xq, k, zero_based)
    assert (index.d == xq:size(2))
    local nq = xq:size (1)
    self.D = torch.FloatTensor (nq, k)
    self.I = torch.LongTensor (nq, k)
    self.zero_based = zero_based
    self.c = swigfaiss.AsyncIndexSearchC (
        index, nq, swigfaiss.float_ptr (xq), k,
        swigfaiss.float_ptr (self.D), swigfaiss.long_ptr (self.I))
end

function AsyncIndexSearch:join ()
    self.c:join ()
    if self.zero_based == nil then
        self.I:add (1)
    end
    return self.D, self.I
end
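--[[ Usage sketch (hypothetical index and queries; assumes the dotted
class name above resolves through the torch class registry):

    local job = swigfaiss.AsyncIndexSearch (index, xq, 10)
    -- ... do other work while the search runs ...
    local D, I = job:join ()
--]]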
-- Lua-friendly wrapper around kmeans_clustering: takes an n x d
-- FloatTensor and the number of clusters k, returns the centroids.
swigfaiss.kmeans_clustering_c = swigfaiss.kmeans_clustering
swigfaiss.kmeans_clustering = function (x, k)
    local d = x:size(2)
    local centroids = torch.FloatTensor (k, d)
    local res = swigfaiss.kmeans_clustering_c (
        d, x:size(1), k, swigfaiss.float_ptr (x),
        swigfaiss.float_ptr (centroids))
    return centroids, res
end
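--[[ Usage sketch (hypothetical data):

    local x = torch.FloatTensor (10000, 32):uniform ()
    local centroids, err = swigfaiss.kmeans_clustering (x, 256)
    -- centroids: 256 x 32 FloatTensor; err: final quantization error
--]]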
swigfaiss.kmeans_clustering_gpu_c = swigfaiss.kmeans_clustering_gpu
swigfaiss.kmeans_clustering_gpu = function (ngpu, x, k, useFloat16)
    local d = x:size(2)
    if useFloat16 == nil then
        useFloat16 = false
    end
    local centroids = torch.FloatTensor (k, d)
    local res = swigfaiss.kmeans_clustering_gpu_c (
        ngpu, d, x:size(1), k, swigfaiss.float_ptr (x),
        swigfaiss.float_ptr (centroids),
        useFloat16, false)
    return centroids, res
end
--[[
A few low-level utility functions to manipulate swig objects
--]]

-- this is how SWIG stores the reference to any pointer
ffi.cdef [[
typedef struct {
    void *type;
    int own;   /* 1 if owned & must be destroyed */
    void *ptr;
} swig_lua_userdata;
]]
-- tell swig not to deallocate the pointer when Lua garbage collects
-- the object
function swigfaiss.disown_pointer (index)
    getmetatable(index)['.fn'].__disown (index)
end

-- check whether the pointer is owned by Lua
function swigfaiss.owns_pointer (index)
    return ffi.cast ('swig_lua_userdata *', index).own == 1
end
-- make sure that the object ofrom, which index refers to internally,
-- does not get deallocated before index itself is deleted
function swigfaiss.transfer_ownership (ofrom, index)
    swigfaiss.disown_pointer (ofrom)
    index.own_fields = true
end
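--[[ Usage sketch (typical case: an IndexIVF built around a Lua-owned
quantizer; IndexIVFFlat (quantizer, d, nlist) is the standard Faiss
constructor):

    local quantizer = swigfaiss.IndexFlatL2 (64)
    local index = swigfaiss.IndexIVFFlat (quantizer, 64, 256)
    swigfaiss.transfer_ownership (quantizer, index)
    -- Lua no longer frees the quantizer; the index deletes it when
    -- it is destroyed (own_fields = true)
--]]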
function swigfaiss.vector_float_to_tensor (v)
    local t = torch.FloatTensor (v:size())
    swigfaiss.memcpy (swigfaiss.float_ptr (t),
                      v:data(), v:size() * 4)
    return t
end

function swigfaiss.tensor_to_vector_float (t, vf)
    vf:resize (t:nElement ())
    swigfaiss.memcpy (vf:data(), swigfaiss.float_ptr (t),
                      vf:size() * 4)
end
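--[[ Usage sketch (FloatVector is the usual SWIG wrapping of
std::vector<float>, assumed to be exported by this module):

    local v = swigfaiss.FloatVector ()
    swigfaiss.tensor_to_vector_float (torch.FloatTensor (8):fill (1), v)
    local t = swigfaiss.vector_float_to_tensor (v)  -- 8-element copy
--]]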
return swigfaiss