@altescy
Last active October 6, 2020 14:54
scikit-learn API Implementation of Replicated Softmax (RSM)
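
Replicated Softmax (Salakhutdinov & Hinton, NIPS 2009) is an undirected topic model: a Restricted Boltzmann Machine whose visible layer is a softmax over the vocabulary, replicated once per word, so that every document shares one weight matrix and a document of length D has its hidden biases scaled by D. The class below trains the model with k-step contrastive divergence (CD-k) and follows the scikit-learn estimator API: fit learns the weights from a document-word count matrix, and transform returns the per-document hidden activation probabilities P(h=1|n(V)).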
# -*- coding: utf-8 -*-
import time
import numpy as np
from scipy.special import expit
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.utils import gen_even_slices
from sklearn.utils.extmath import safe_sparse_dot
from sklearn.utils import shuffle
from sklearn.utils.validation import check_is_fitted
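
# Notation (following the RSM paper): n(V) denotes the word-count vector of a
# document and D = sum(n(V)) is its length; the hidden-unit biases are scaled
# by D in _mean_hiddens below.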

def softmax(X):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    ret = np.exp(X - X.max(axis=1).reshape(-1, 1))
    ret = ret / ret.sum(axis=1).reshape(-1, 1)
    return ret

class RSM(BaseEstimator, TransformerMixin):
    """Replicated Softmax model trained with k-step contrastive divergence."""

    def __init__(self, n_components=128, learning_rate=0.01, batch_size=10,
                 n_iter=10, k_cd_steps=1, momentum=0.9, verbose=False,
                 random_state=None):
        self.n_components = n_components
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_iter = n_iter
        self.k_cd_steps = k_cd_steps
        self.momentum = momentum
        self.verbose = verbose
        self.random_state = random_state

    def transform(self, X):
        """Compute the hidden layer activation probabilities, P(h=1|n(V)=X)."""
        check_is_fitted(self, "components_")
        X = check_array(X, accept_sparse='csr', dtype=np.float64)
        D = np.asarray(X.sum(axis=1)).ravel()  # document lengths
        return self._mean_hiddens(X, D)

    def _mean_hiddens(self, v, D):
        """Computes the probabilities P(h=1|n(V))."""
        p = safe_sparse_dot(v, self.components_.T) + np.outer(D, self.intercept_hidden_)
        return expit(p)

    def _mean_visibles(self, h):
        """Computes the probabilities P(V=1|h)."""
        p = np.dot(h, self.components_) + self.intercept_visible_
        return softmax(p)

    def _sample_hiddens(self, v, D, rng):
        """Sample from the distribution P(h|n(V))."""
        p = self._mean_hiddens(v, D)
        return (rng.uniform(size=p.shape) < p), p

    def _sample_visibles(self, h, D, rng):
        """Sample from the distribution P(n(V)|h)."""
        p = self._mean_visibles(h)
        # multinomial expects an integer number of trials, so cast each
        # document length (float64 after check_array) back to int.
        return np.asarray([rng.multinomial(int(d), p_)
                           for d, p_ in zip(D, p)]), p

    def gibbs(self, v):
        """Perform one Gibbs sampling step."""
        check_is_fitted(self, "components_")
        if not hasattr(self, "random_state_"):
            self.random_state_ = check_random_state(self.random_state)
        # The samplers need the document lengths and return (sample, mean) pairs.
        D = np.asarray(v.sum(axis=1)).ravel()
        h_, _ = self._sample_hiddens(v, D, self.random_state_)
        v_, _ = self._sample_visibles(h_, D, self.random_state_)
        return v_

    def _cd_k(self, h_pos, D, rng):
        """Run k alternating Gibbs steps (contrastive divergence)."""
        h_neg = h_pos
        for _ in range(self.k_cd_steps):
            v_neg, _ = self._sample_visibles(h_neg, D, rng)
            h_neg, h_mean_neg = self._sample_hiddens(v_neg, D, rng)
        return v_neg, h_mean_neg

    def _fit(self, v_pos, rng):
        """One CD-k update on a mini-batch; returns the per-sample squared
        reconstruction error."""
        if not hasattr(self, "lr_"):
            self.lr_ = self.learning_rate / v_pos.shape[0]
        D = np.asarray(v_pos.sum(axis=1)).ravel()
        # Positive phase: hidden statistics driven by the data.
        h_pos, h_mean_pos = self._sample_hiddens(v_pos, D, rng)
        # Negative phase: statistics from the k-step Gibbs reconstruction.
        v_neg, h_mean_neg = self._cd_k(h_pos, D, rng)
        hv_pos = safe_sparse_dot(v_pos.T, h_mean_pos, dense_output=True).T
        hv_neg = np.dot(h_mean_neg.T, v_neg)
        # Momentum-smoothed gradient estimates.
        self.update_components_ = (self.momentum * self.update_components_ +
                                   hv_pos - hv_neg)
        self.update_intercept_hidden_ = (self.momentum * self.update_intercept_hidden_ +
                                         h_mean_pos.sum(axis=0) -
                                         h_mean_neg.sum(axis=0))
        self.update_intercept_visible_ = (self.momentum * self.update_intercept_visible_ +
                                          np.asarray(v_pos.sum(axis=0)).ravel() -
                                          v_neg.sum(axis=0))
        self.components_ += self.lr_ * self.update_components_
        self.intercept_hidden_ += self.lr_ * self.update_intercept_hidden_
        self.intercept_visible_ += self.lr_ * self.update_intercept_visible_
        return np.square(v_pos - v_neg).sum() / v_pos.shape[0]

    def fit(self, X, y=None):
        """Fit the model to the document-word count matrix X."""
        X = check_array(X, accept_sparse='csr', dtype=np.float64)
        n_samples, n_features = X.shape
        rng = check_random_state(self.random_state)
        self.components_ = rng.random_sample(size=(self.n_components, n_features))
        self.intercept_hidden_ = rng.random_sample(self.n_components)
        self.intercept_visible_ = rng.random_sample(n_features)
        self.update_components_ = np.zeros((self.n_components, n_features))
        self.update_intercept_hidden_ = np.zeros(self.n_components)
        self.update_intercept_visible_ = np.zeros(n_features)
        n_batches = int(np.ceil(float(n_samples) / self.batch_size))
        batch_slices = list(gen_even_slices(n_batches * self.batch_size,
                                            n_batches, n_samples))
        begin = time.time()
        for iteration in range(self.n_iter):
            X = shuffle(X, random_state=rng)
            error = 0
            for batch_slice in batch_slices:
                batch_size = batch_slice.stop - batch_slice.start
                self.lr_ = self.learning_rate / batch_size
                error += self._fit(X[batch_slice], rng)
            if self.verbose:
                end = time.time()
                print("[%s] Iteration %d, reconstruction-error = %.2f,"
                      " time = %.2fs"
                      % (type(self).__name__, iteration + 1,
                         error / n_batches, end - begin))
                begin = end
        return self
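
Below is a minimal usage sketch, not part of the original gist: the toy documents and the default CountVectorizer settings are illustrative assumptions, chosen only to show the fit/transform round trip on a sparse word-count matrix.

from sklearn.feature_extraction.text import CountVectorizer

# Illustrative toy corpus (an assumption, not the author's data).
docs = [
    "the cat sat on the mat",
    "the dog chased the cat",
    "stars and planets orbit in space",
    "the telescope observed distant planets",
]
X = CountVectorizer().fit_transform(docs)  # sparse document-word count matrix

rsm = RSM(n_components=2, n_iter=20, batch_size=2, verbose=True, random_state=0)
H = rsm.fit_transform(X)  # hidden activations P(h=1|n(V)), shape (4, 2)
V = rsm.gibbs(X)          # one Gibbs step: resampled word counts per document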