Last active
May 10, 2019 10:23
-
-
Save fissoreg/7ce7dc20fe54cd5dd18e0464981f40eb to your computer and use it in GitHub Desktop.
Noisy Rectified Linear Units (NReLU) for Boltzmann.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Boltzmann | |
using MLDatasets.MNIST: testdata | |
using ImageView | |
include("relu.jl") | |
"""
    plot_weights(W, imsize, padding=10)

Show the rows of `W` as a grid of images and return the tiled matrix.

Each row `W[i, :]` is reshaped (column-major) to `imsize == (h, w)` and min-max
scaled to `[0, 1]` on its own. Tiles are laid out row by row in a near-square
grid, separated by `padding` zero-valued pixels. The resulting matrix is
displayed with `imshow` and also returned.
"""
function plot_weights(W, imsize, padding=10)
    dat = _tile_weights(W, imsize, padding)
    imshow(dat)
    return dat
end

"""
    _tile_weights(W, imsize, padding)

Build the mosaic displayed by `plot_weights` without showing it.

Returns a `Float64` matrix of size `(rows * (h + padding), cols * (w + padding))`,
where `rows = max(1, floor(Int, sqrt(n)))`, `cols = cld(n, rows)` and
`n = size(W, 1)`. A tile whose entries are all equal has no contrast to scale,
so it is left at zero instead of producing `0/0` NaNs.
"""
function _tile_weights(W, imsize, padding)
    h, w = imsize
    n = size(W, 1)
    # max(1, ...) keeps the grid well-defined (and cld finite) when W has no rows
    rows = max(1, floor(Int, sqrt(n)))
    cols = cld(n, rows)
    halfpad = div(padding, 2)
    dat = zeros(rows * (h + padding), cols * (w + padding))
    for i in 1:n
        tile = reshape(view(W, i, :), imsize)
        r, c = divrem(i - 1, cols) .+ 1
        # anchor each tile by its top-left corner so odd paddings still fit exactly h×w
        top = (r - 1) * (h + padding) + halfpad
        left = (c - 1) * (w + padding) + halfpad
        lo, hi = extrema(tile)
        if hi > lo
            # true min-max scaling: subtract the minimum before dividing by the range
            dat[top+1:top+h, left+1:left+w] .= (tile .- lo) ./ (hi - lo)
        end
    end
    return dat
end
# Demo: fit an RBM with Degenerate visible units and NReLU hidden units to the
# MNIST test split, then display the weights of the first 64 hidden units.
X, y = testdata()  # the test split is small enough that no downsampling is needed
# NOTE(review): assumes testdata() returns 28×28×N images, so each column of the
# reshaped matrix is one flattened image — confirm against the MLDatasets version
X = Float64.(reshape(X, 28 * 28, :))
# min-max scale into [0, 1]; subtracting the minimum keeps this correct even if
# the data does not start at zero
lo, hi = extrema(X)
X = (X .- lo) ./ (hi - lo)
m = RBM(Degenerate, NReLU, 28 * 28, 100)
fit(m, X, n_epochs=100, randomize=true)
# W is (hidden × visible) — hid_means computes W * vis — so each row is one hidden unit
plot_weights(m.W[1:64, :], (28, 28))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reference: V. Nair and G. E. Hinton, "Rectified Linear Units Improve Restricted
# Boltzmann Machines", ICML 2010: http://www.cs.toronto.edu/~hinton/absps/reluICML.pdf
import Boltzmann: sample, hid_means, vis_means, logistic | |
"""
    NReLU

Marker type selecting noisy rectified linear hidden units in an `RBM`
(e.g. `RBM(Degenerate, NReLU, n_vis, n_hid)`); the `hid_means` and `sample`
methods for these units dispatch on it.
"""
struct NReLU
end
"""
    hid_means(rbm::RBM{T,V,NReLU}, vis::Array{T,2})

Return the hidden-unit means for NReLU hidden units: the pre-activation
`rbm.W * vis .+ rbm.hbias` with every negative entry replaced by zero.
The returned matrix is freshly allocated and clamped in place.
"""
function hid_means(rbm::RBM{T,V,NReLU}, vis::Array{T, 2}) where {T,V}
    preact = rbm.W * vis .+ rbm.hbias
    # only strictly negative entries become zero; NaN and -0.0 pass through unchanged
    @. preact = ifelse(preact < 0, zero(preact), preact)
    return preact
end
"""
    sample(::Type{NReLU}, means::Array{T,2}) where T

Draw one noisy rectified sample per entry of `means` for NReLU units.

For an entry `m > 0` the sample is `max(0, m + ε)` with `ε ~ N(0, logistic(m))`.
`logistic(m)` is the noise *variance* (Nair & Hinton, ICML 2010), so unit-normal
draws are scaled by its square root. Entries with `m <= 0` are set to zero.
The array produced by `logistic(means)` is overwritten with the samples and
returned.
"""
function sample(::Type{NReLU}, means::Array{T, 2}) where T
    # begin with the per-unit noise variance; each entry is replaced below
    samples = logistic(means)
    for j = 1:size(means, 2), i = 1:size(means, 1)
        m = means[i, j]
        if m > 0
            # scale by sqrt: logistic(m) is a variance, not a standard deviation
            noisy = m + sqrt(samples[i, j]) * randn()
            # rectify after adding noise so an NReLU sample is never negative
            samples[i, j] = T(max(0, noisy))
        else
            # presumably `means` comes from the NReLU hid_means, which is already
            # clamped at zero, so the negative pre-activation is unavailable here
            samples[i, j] = zero(T)
        end
    end
    return samples
end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.