Last active
August 29, 2015 14:07
-
-
Save gasagna/ba8c39c8f8ad63fd1d6d to your computer and use it in GitHub Desktop.
Julia wrapper for FANN library
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This type definition is generated automatically using Clang.jl | |
# It contains training data used by FANN | |
type fann_train_data | |
errno_f::fann_errno_enum | |
error_log::Ptr{FILE} | |
errstr::Ptr{Uint8} | |
num_data::Uint32 | |
num_input::Uint32 | |
num_output::Uint32 | |
input::Ptr{Ptr{fann_type}} | |
output::Ptr{Ptr{fann_type}} | |
end | |
# took some inspiration from Julia NLopt package | |
typealias _DataSet Ptr{fann_train_data} | |
# Create DataSet from input data. | |
# | |
# Observations are organised in columns in the matrix X, to reflect | |
# the data format required by FANN, which uses an array of pointers | |
# to arrays to store the observations. Hence, we make use of Julia | |
# column-major order to save a copy of the data. | |
# | |
# Parameters | |
# ---------- | |
# X : n_feat x n_obs matrix | |
# y : n_obs vector | |
# | |
type DataSet | |
data::_DataSet | |
function DataSet(d::_DataSet) | |
d = new(d) | |
finalizer(d, destroy) | |
d | |
end | |
function DataSet(X::Matrix, y::Vector) | |
if size(X, 2) != length(y) | |
error("sizes of X and y do not match") | |
end | |
num_input, num_data = size(X) | |
num_output = 1 | |
d = ccall((:fann_create_train, libfann), _DataSet, | |
(Uint32, Uint32, Uint32), | |
num_data, num_input, num_output) | |
if d == C_NULL | |
error("Error in fann_create_train") | |
end | |
# ==== BEGIN TRICKY BIT ======= | |
tmp = unsafe_load(d) | |
tmp.input = pointer([pointer(X, i) for i = 1:num_input:length(X)]) | |
tmp.output = pointer([pointer(y, i) for i = 1:length(y)]) | |
# ==== END TRICKY BIT ======= | |
DataSet(d) | |
end | |
end | |
Base.convert(::Type{_DataSet}, d::DataSet) = d.data | |
Base.show(io::IO, d::DataSet) = print(io, "DataSet()") | |
destroy(d::DataSet) = ccall((:fann_destroy_train, libfann), Void, (_DataSet,), d) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# example inputs | |
X = [ 1.0 0.0 1.0 1.0; | |
1.0 1.0 0.0 1.0] | |
# desired ouputs | |
y = [1.0, 1.0, 1.0, 4.0] | |
# create Dataset | |
dset = DataSet(X, y) | |
# ~~~~~~~~~~~~~~~~~~~~~~ | |
# The following lines demonstrate that the pointers are not stored | |
# correctly and no data is available to the c side for subsequent training | |
# >> indicates the output of running the program | |
dset |> println | |
>> DataSet() | |
dset.data |> println | |
>> Ptr{fann_train_data} @0x000000000469dd90 | |
unsafe_load(dset.data).input |> println | |
>> Ptr{Ptr{Float64}} @0x00000000055e39b0 | |
pointer_to_array(unsafe_load(dset.data).input, 4) |> println | |
>> [Ptr{Float64} @0x0000000002600e70, | |
Ptr{Float64} @0x0000000002600e78, | |
Ptr{Float64} @0x0000000002600e80, | |
Ptr{Float64} @0x0000000002600e88] | |
arr = pointer_to_array(unsafe_load(dset.data).input, 4) | |
pointer_to_array(arr[1], 2) |> println | |
>> [0.0,0.0] | |
# Crucially this last line gives [0.0, 0.0] and not the first column of X, which is | |
# [1.0, 1.0]. This means that the dataset contains only zeros. Hence, when a neural | |
# network is trained on this data it will have the wrong behaviour. | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A few reference web pages | |
~ definition of fann_create_train | |
http://leenissen.dk/fann/html/files/fann_train-h.html#fann_create_train |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment