Last active
October 6, 2020 14:54
-
-
Save altescy/22dd50b56854237dee454998d37d7a66 to your computer and use it in GitHub Desktop.
scikit-learn API Implementation of Replicated Softmax (RSM)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import time | |
import numpy as np | |
from scipy.special import expit | |
from sklearn.base import BaseEstimator | |
from sklearn.base import TransformerMixin | |
from sklearn.utils import check_array | |
from sklearn.utils import check_random_state | |
from sklearn.utils import gen_even_slices | |
from sklearn.utils.extmath import safe_sparse_dot | |
from sklearn.utils import shuffle | |
from sklearn.utils.validation import check_is_fitted | |
def softmax(X):
    """Row-wise softmax of a 2-D array.

    Subtracts each row's maximum before exponentiating for numerical
    stability; rows of the result sum to 1.
    """
    shifted = X - X.max(axis=1).reshape(-1, 1)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1).reshape(-1, 1)
class RSM(BaseEstimator, TransformerMixin):
    """Replicated Softmax Model (RSM), an undirected topic model.

    A restricted Boltzmann machine whose visible layer is a softmax over
    word counts, trained with k-step Contrastive Divergence (CD-k) and
    momentum SGD. Input rows are non-negative count vectors (e.g.
    bag-of-words documents).

    Parameters
    ----------
    n_components : int
        Number of binary hidden units.
    learning_rate : float
        SGD step size; divided by the mini-batch size at update time.
    batch_size : int
        Number of samples per mini-batch.
    n_iter : int
        Number of passes over the training data.
    k_cd_steps : int
        Number of alternating Gibbs steps per CD update.
    momentum : float
        Momentum coefficient applied to the parameter-update buffers.
    verbose : bool
        If True, print the reconstruction error after each iteration.
    random_state : int, RandomState instance or None
        Seeds parameter initialization, shuffling, and Gibbs sampling.
    """

    def __init__(self, n_components=128, learning_rate=0.01, batch_size=10,
                 n_iter=10, k_cd_steps=1, momentum=0.9, verbose=False, random_state=None):
        self.n_components = n_components
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_iter = n_iter
        self.k_cd_steps = k_cd_steps
        self.momentum = momentum
        self.verbose = verbose
        self.random_state = random_state

    def transform(self, X):
        """Compute the hidden layer activation probabilities, P(h=1|n(V)=X).

        Returns an array of shape (n_samples, n_components).
        """
        check_is_fitted(self, "components_")
        X = check_array(X, accept_sparse='csr', dtype=np.float64)
        # D is each document's total word count; it scales the hidden bias.
        D = np.asarray(X.sum(axis=1)).ravel()
        return self._mean_hiddens(X, D)

    def _mean_hiddens(self, v, D):
        """Computes the probabilities P(h=1|n(V))."""
        p = safe_sparse_dot(v, self.components_.T) + np.outer(D, self.intercept_hidden_)
        return expit(p)

    def _mean_visibles(self, h):
        """Computes the multinomial probabilities P(V=1|h) over visible units."""
        p = np.dot(h, self.components_) + self.intercept_visible_
        return softmax(p)

    def _sample_hiddens(self, v, D, rng):
        """Sample from the distribution P(h|n(V)).

        Returns (sample, mean): binary samples and their activation
        probabilities.
        """
        p = self._mean_hiddens(v, D)
        return ((rng.uniform(size=p.shape) < p), p)

    def _sample_visibles(self, h, D, rng):
        """Sample from the distribution P(n(V)|h).

        Row i is a multinomial draw of D[i] words. Returns (sample, mean).
        """
        p = self._mean_visibles(h)
        # int(d) is required: D comes from X.sum(axis=1) and is float64,
        # but multinomial needs an integer number of trials.
        samples = np.asarray([rng.multinomial(int(d), p_)
                              for d, p_ in zip(D, p)])
        return (samples, p)

    def gibbs(self, v):
        """Perform one Gibbs sampling step; returns sampled visible counts.
        """
        check_is_fitted(self, "components_")
        if not hasattr(self, "random_state_"):
            self.random_state_ = check_random_state(self.random_state)
        # BUG FIX: the samplers take (x, D, rng) and return (sample, mean)
        # tuples; the original passed the RNG in D's slot and fed the raw
        # tuple from _sample_hiddens into _sample_visibles.
        D = np.asarray(v.sum(axis=1)).ravel()
        h_, _ = self._sample_hiddens(v, D, self.random_state_)
        v_, _ = self._sample_visibles(h_, D, self.random_state_)
        return v_

    def _cd_k(self, h_pos, D, rng):
        """Run k alternating Gibbs steps (CD-k negative phase).

        Returns the final visible sample and hidden means.
        """
        h_neg = h_pos
        for _ in range(self.k_cd_steps):
            v_neg, v_mean_neg = self._sample_visibles(h_neg, D, rng)
            h_neg, h_mean_neg = self._sample_hiddens(v_neg, D, rng)
        return v_neg, h_mean_neg

    def _fit(self, v_pos, rng):
        """One CD-k parameter update on a mini-batch.

        Returns the batch's summed squared reconstruction error divided
        by the batch size.
        """
        if not hasattr(self, "lr_"):
            # Fallback for direct calls; fit() sets lr_ per batch.
            self.lr_ = self.learning_rate / v_pos.shape[0]
        D = np.asarray(v_pos.sum(axis=1)).ravel()
        h_pos, h_mean_pos = self._sample_hiddens(v_pos, D, rng)
        v_neg, h_mean_neg = self._cd_k(h_pos, D, rng)
        # Positive/negative phase statistics, shape (n_components, n_features).
        hv_pos = safe_sparse_dot(v_pos.T, h_mean_pos, dense_output=True).T
        hv_neg = np.dot(h_mean_neg.T, v_neg)
        # Momentum-smoothed gradient accumulators.
        self.update_components_ = (self.momentum * self.update_components_ +
                                   hv_pos - hv_neg)
        self.update_intercept_hidden_ = (self.momentum * self.update_intercept_hidden_ +
                                         h_mean_pos.sum(axis=0) - h_mean_neg.sum(axis=0))
        self.update_intercept_visible_ = (self.momentum * self.update_intercept_visible_ +
                                          np.asarray(v_pos.sum(axis=0)).ravel() -
                                          v_neg.sum(axis=0))
        self.components_ += self.lr_ * self.update_components_
        self.intercept_hidden_ += self.lr_ * self.update_intercept_hidden_
        self.intercept_visible_ += self.lr_ * self.update_intercept_visible_
        return np.square(v_pos - v_neg).sum() / v_pos.shape[0]

    def fit(self, X, y=None):
        """Fit the RSM to X, a matrix of non-negative counts.

        Returns self.
        """
        X = check_array(X, accept_sparse='csr', dtype=np.float64)
        n_samples, n_features = X.shape
        rng = check_random_state(self.random_state)
        # Uniform-random initialization of weights and biases.
        self.components_ = rng.random_sample(size=(self.n_components, n_features))
        self.intercept_hidden_ = rng.random_sample(self.n_components)
        self.intercept_visible_ = rng.random_sample(n_features)
        # Momentum buffers start at zero.
        self.update_components_ = np.zeros((self.n_components, n_features))
        self.update_intercept_hidden_ = np.zeros(self.n_components)
        self.update_intercept_visible_ = np.zeros(n_features)
        n_batches = int(np.ceil(float(n_samples) / self.batch_size))
        batch_slices = list(gen_even_slices(n_batches * self.batch_size,
                                            n_batches, n_samples))
        begin = time.time()
        for iteration in range(self.n_iter):
            X = shuffle(X, random_state=rng)
            error = 0
            for batch_slice in batch_slices:
                # The last slice may be short; scale the step accordingly.
                batch_size = batch_slice.stop - batch_slice.start
                self.lr_ = self.learning_rate / batch_size
                error += self._fit(X[batch_slice], rng)
            if self.verbose:
                end = time.time()
                print("[%s] Iteration %d, reconstruction-error = %.2f,"
                      " time = %.2fs"
                      % (type(self).__name__, iteration + 1,
                         error / n_batches, end - begin))
                begin = end
        return self
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment