Skip to content

Instantly share code, notes, and snippets.

@piyo7
Last active August 29, 2015 14:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save piyo7/b4f49821248b80f680a0 to your computer and use it in GitHub Desktop.
Save piyo7/b4f49821248b80f680a0 to your computer and use it in GitHub Desktop.
ディープラーニング勉強会 AutoEncoder ref: http://qiita.com/piyo7/items/60576759430910ffe5be
import scala.math
import scala.util.Random
import MatrixImplicits._ // 自作のimplicit classによるSeqラッパ
/** Autoencoder that corrupts its input at random before encoding it (the class name
  * suggests "denoising autoencoder").
  *
  * One weight matrix `W` is shared by both directions: `encode` applies `W` and `decode`
  * applies `W.T`.
  *
  * Every vector/matrix operator used below (`mXc`, `cXr`, `.T`, `map2`, and `+`, `-` and `*`
  * on `Seq`) is provided by `MatrixImplicits`, which is not part of this file. They are
  * presumably matrix-by-column product, column-by-row (outer) product, transpose, a map over
  * a 2-D `Seq`, and element-wise arithmetic -- TODO confirm against `MatrixImplicits`.
  *
  * @param train_N   divisor applied to every parameter update in `train`; presumably the
  *                  number of training samples -- confirm with callers
  * @param n_visible size of the visible (input) layer; number of columns of `W`
  * @param n_hidden  size of the hidden layer; number of rows of `W`
  * @param seed      seed of the random generator used for weight initialisation and for
  *                  input corruption
  */
class dA(train_N: Int, n_visible: Int, n_hidden: Int, seed: Int) {

  /** Random source. Declared before `W` because the initialiser of `W` draws from it. */
  val rng: Random = new Random(seed)

  /** Weights: `n_hidden` rows of `n_visible` values, each drawn independently from
    * `uniform(-1.0 / n_visible, 1.0 / n_visible)`.
    */
  var W: Seq[Seq[Double]] =
    Seq.fill(n_hidden, n_visible)(uniform(-1.0 / n_visible, 1.0 / n_visible))

  /** Hidden-layer bias, initialised to `n_hidden` zeros. */
  var hbias: Seq[Double] = Seq.fill(n_hidden)(0.0)

  /** Visible-layer bias, initialised to `n_visible` zeros. */
  var vbias: Seq[Double] = Seq.fill(n_visible)(0.0)

  /** Draws one value from `rng`, rescaled from `[0, 1)` to the range starting at `min` with
    * width `max - min`.
    */
  def uniform(min: Double, max: Double): Double = rng.nextDouble() * (max - min) + min

  /** Logistic function `1 / (1 + e^-x)`. */
  def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

  /** Returns a copy of `x` in which each element is multiplied by `0.0` when a fresh draw
    * from `rng` is below `p`, and by `1.0` otherwise. Exactly one draw is consumed per
    * element, in element order.
    *
    * @param x input vector
    * @param p threshold compared against each `rng.nextDouble()` draw
    */
  def corrupted(x: Seq[Double], p: Double): Seq[Double] =
    x.map(_ * (if (rng.nextDouble() < p) 0.0 else 1.0))

  /** Maps a visible vector to the hidden layer: `sigmoid` of `(W mXc x) + hbias`. */
  def encode(x: Seq[Double]): Seq[Double] = ((W mXc x) + hbias).map(sigmoid)

  /** Maps a hidden vector back to the visible layer: `sigmoid` of `(W.T mXc y) + vbias`. */
  def decode(y: Seq[Double]): Seq[Double] = ((W.T mXc y) + vbias).map(sigmoid)

  /** Performs one in-place update of `W`, `hbias` and `vbias` from a single sample.
    *
    * The sample is corrupted, encoded and decoded; the difference between the original
    * sample and the reconstruction drives all three updates, each scaled by
    * `learning_rate / train_N`.
    *
    * @param x                training sample
    * @param learning_rate    multiplier applied to every update
    * @param corruption_level threshold passed to `corrupted` as `p`
    * @throws IllegalArgumentException if `train_N` is not positive; dividing by zero would
    *                                  otherwise silently turn every parameter into
    *                                  Infinity or NaN
    */
  def train(x: Seq[Double], learning_rate: Double, corruption_level: Double): Unit = {
    require(train_N > 0, s"train_N must be positive to scale updates, got $train_N")

    val tilde_x = corrupted(x, corruption_level)
    val y = encode(tilde_x)
    val z = decode(y)

    // Reconstruction error against the uncorrupted sample.
    val L_vbias = x - z
    // Error carried back through W, multiplied by y * (1 - y) (the sigmoid derivative at y).
    val L_hbias = (W mXc L_vbias) * y * y.map(1.0 - _)

    // Scaling is written as `_ * learning_rate / train_N` on purpose: pre-computing
    // `learning_rate / train_N` would change floating-point rounding.
    vbias = vbias + L_vbias.map(_ * learning_rate / train_N)
    hbias = hbias + L_hbias.map(_ * learning_rate / train_N)
    W = W + ((L_hbias cXr tilde_x) + (y cXr L_vbias)).map2(_ * learning_rate / train_N)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment