Skip to content

Instantly share code, notes, and snippets.

@fissoreg
Last active May 10, 2019 10:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fissoreg/7ce7dc20fe54cd5dd18e0464981f40eb to your computer and use it in GitHub Desktop.
Save fissoreg/7ce7dc20fe54cd5dd18e0464981f40eb to your computer and use it in GitHub Desktop.
Noisy Rectified Linear Units (NReLU) for Boltzmann.jl
using Boltzmann
using MLDatasets.MNIST: testdata
using ImageView
include("relu.jl")
"""
    plot_weights(W, imsize, padding=10; show_image=true)

Tile the rows of the weight matrix `W` into a single grayscale mosaic.

Each row `W[i, :]` is reshaped to `imsize = (h, w)` (column-major, as `reshape` does),
min-max normalised to `[0, 1]` and written into a grid with `floor(sqrt(n))` rows and
`ceil(n / rows)` columns, where `n = size(W, 1)`. Tiles are separated by `padding`
zero-valued pixels.

# Arguments
- `W`: matrix with one filter per row; each row must have `h * w` elements.
- `imsize`: `(h, w)` tuple giving the shape of a single tile.
- `padding`: number of blank pixels between adjacent tiles (odd values are allowed).

# Keywords
- `show_image`: when `true` (the default), display the mosaic with `imshow`.

# Returns
The mosaic as a `Matrix{Float64}` of size `(rows * (h + padding), cols * (w + padding))`.

# Throws
- `ArgumentError` if `W` has no rows.
"""
function plot_weights(W, imsize, padding=10; show_image::Bool=true)
    h, w = imsize
    n = size(W, 1)
    n > 0 || throw(ArgumentError("W must have at least one row"))
    rows = isqrt(n)          # == floor(sqrt(n)), computed exactly on integers
    cols = cld(n, rows)      # == ceil(n / rows)
    halfpad = div(padding, 2)
    dat = zeros(rows * (h + padding), cols * (w + padding))
    for i in 1:n
        wim = reshape(W[i, :], imsize)
        lo, hi = extrema(wim)
        span = hi - lo
        # Min-max normalise to [0, 1]; a constant tile would divide by zero, so it is
        # left blank instead.
        wim = iszero(span) ? zero(wim) : (wim .- lo) ./ span
        r, c = divrem(i - 1, cols) .+ 1
        # Anchor each tile by its top-left corner and take exactly h×w pixels, so the
        # slice matches `wim` even when `padding` is odd.
        top = (r - 1) * (h + padding) + halfpad
        left = (c - 1) * (w + padding) + halfpad
        dat[top+1:top+h, left+1:left+w] = wim
    end
    show_image && imshow(dat)
    return dat
end
# Train on the MNIST test split: it is small enough to use without downsampling.
X, y = testdata()
X = Float64.(reshape(X, 28 * 28, :))   # one flattened 28×28 image per column
X ./= maximum(X) - minimum(X)           # rescale by the dynamic range of the data

# RBM with `Degenerate` visible units, `NReLU` hidden units and 100 hidden units.
m = RBM(Degenerate, NReLU, 784, 100)
fit(m, X; n_epochs=100, randomize=true)

# Rows of `m.W` belong to hidden units; show the first 64 of them as image tiles.
plot_weights(m.W[1:64, :], (28, 28))
# reference: http://www.cs.toronto.edu/~hinton/absps/reluICML.pdf
import Boltzmann: sample, hid_means, vis_means, logistic
"""
    NReLU

Singleton tag type for noisy rectified linear hidden units.

It carries no data. It is passed as the hidden-unit type of an `RBM` (e.g.
`RBM(Degenerate, NReLU, ...)`), so that the `hid_means` and `sample` methods
specialised on `NReLU` are selected by dispatch.
"""
struct NReLU end
"""
    hid_means(rbm::RBM{T,V,NReLU}, vis::Array{T, 2}) where {T,V}

Return the rectified hidden pre-activations `max.(rbm.W * vis .+ rbm.hbias, 0)`.

`vis` is multiplied on the right of `rbm.W`, so it presumably holds one visible
vector per column (verify against callers). The result has one column per
column of `vis` and the same element type `T`. NaN entries propagate unchanged.
"""
function hid_means(rbm::RBM{T,V,NReLU}, vis::Array{T, 2}) where {T,V}
    # `zero(T)` (not the literal `0`) keeps the broadcast type-stable for any float T;
    # the bias add and the clamp are fused into a single pass.
    return max.(rbm.W * vis .+ rbm.hbias, zero(T))
end
"""
    sample(::Type{NReLU}, means::Array{T, 2}) where T

Draw noisy-ReLU samples elementwise from the unit means `means`.

For an entry `x > 0`, the sample is `max(0, x + ε)` with `ε ~ N(0, v)` and
*variance* `v = logistic(x) = 1 / (1 + exp(-x))`. The noise standard deviation is
therefore `sqrt(v)`, and the result is clamped at zero so samples are never negative.
Entries with `x <= 0` (e.g. already clamped by `hid_means`) are set to zero.

Returns a new array with the same size and element type as `means`.
"""
function sample(::Type{NReLU}, means::Array{T, 2}) where T
    samples = similar(means)
    for k in eachindex(means, samples)
        x = means[k]
        if x > 0
            variance = 1 / (1 + exp(-x))  # logistic(x): a variance, not a std
            # Base `randn` gives standard-normal noise without needing Distributions.
            samples[k] = T(max(0, x + sqrt(variance) * randn()))
        else
            samples[k] = zero(T)
        end
    end
    return samples
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment