Skip to content

Instantly share code, notes, and snippets.

@gasagna
Last active August 29, 2015 14:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gasagna/ba8c39c8f8ad63fd1d6d to your computer and use it in GitHub Desktop.
Save gasagna/ba8c39c8f8ad63fd1d6d to your computer and use it in GitHub Desktop.
Julia wrapper for FANN library
# This type definition is generated automatically using Clang.jl
# It contains training data used by FANN
type fann_train_data
errno_f::fann_errno_enum
error_log::Ptr{FILE}
errstr::Ptr{Uint8}
num_data::Uint32
num_input::Uint32
num_output::Uint32
input::Ptr{Ptr{fann_type}}
output::Ptr{Ptr{fann_type}}
end
# took some inspiration from Julia NLopt package
typealias _DataSet Ptr{fann_train_data}
# Create DataSet from input data.
#
# Observations are organised in columns in the matrix X, to reflect
# the data format required by FANN, which uses an array of pointers
# to arrays to store the observations. Hence, we make use of Julia
# column-major order to save a copy of the data.
#
# Parameters
# ----------
# X : n_feat x n_obs matrix
# y : n_obs vector
#
type DataSet
data::_DataSet
function DataSet(d::_DataSet)
d = new(d)
finalizer(d, destroy)
d
end
function DataSet(X::Matrix, y::Vector)
if size(X, 2) != length(y)
error("sizes of X and y do not match")
end
num_input, num_data = size(X)
num_output = 1
d = ccall((:fann_create_train, libfann), _DataSet,
(Uint32, Uint32, Uint32),
num_data, num_input, num_output)
if d == C_NULL
error("Error in fann_create_train")
end
# ==== BEGIN TRICKY BIT =======
tmp = unsafe_load(d)
tmp.input = pointer([pointer(X, i) for i = 1:num_input:length(X)])
tmp.output = pointer([pointer(y, i) for i = 1:length(y)])
# ==== END TRICKY BIT =======
DataSet(d)
end
end
Base.convert(::Type{_DataSet}, d::DataSet) = d.data
Base.show(io::IO, d::DataSet) = print(io, "DataSet()")
destroy(d::DataSet) = ccall((:fann_destroy_train, libfann), Void, (_DataSet,), d)
# example inputs
X = [ 1.0 0.0 1.0 1.0;
1.0 1.0 0.0 1.0]
# desired ouputs
y = [1.0, 1.0, 1.0, 4.0]
# create Dataset
dset = DataSet(X, y)
# ~~~~~~~~~~~~~~~~~~~~~~
# The following lines demonstrate that the pointers are not stored
# correctly and no data is available to the c side for subsequent training
# >> indicates the output of running the program
dset |> println
>> DataSet()
dset.data |> println
>> Ptr{fann_train_data} @0x000000000469dd90
unsafe_load(dset.data).input |> println
>> Ptr{Ptr{Float64}} @0x00000000055e39b0
pointer_to_array(unsafe_load(dset.data).input, 4) |> println
>> [Ptr{Float64} @0x0000000002600e70,
Ptr{Float64} @0x0000000002600e78,
Ptr{Float64} @0x0000000002600e80,
Ptr{Float64} @0x0000000002600e88]
arr = pointer_to_array(unsafe_load(dset.data).input, 4)
pointer_to_array(arr[1], 2) |> println
>> [0.0,0.0]
# Crucially this last line gives [0.0, 0.0] and not the first column of X, which is
# [1.0, 1.0]. This means that the dataset contains only zeros. Hence, when a neural
# network is trained on this data it will have the wrong behaviour.
# A few reference web pages
~ definition of fann_create_train
http://leenissen.dk/fann/html/files/fann_train-h.html#fann_create_train
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment