Simple backprop neural net
#lang racket
(require math/matrix)
(define (sigmoid x) (/ (exp x) (+ (exp x) 1)))
(define (sigmoid~ x) (* (sigmoid x) (- 1 (sigmoid x))))
(define (tanh~ x) (/ 1 (expt (cosh x) 2)))
(define activation sigmoid)
(define activation~ sigmoid~)
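; The activation is the logistic sigmoid, 1 / (1 + e^-x); sigmoid~ is its
; derivative s(x) * (1 - s(x)).  Reference values: (sigmoid 0) = 0.5 and
; (sigmoid~ 0) = 0.25.  To experiment with tanh instead (tanh itself ships
; with #lang racket via racket/math), the swap would look roughly like:
;   (define activation tanh)
;   (define activation~ tanh~)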
; Calculates output and intermediate output for each layer
(define (feedforward input weights biases)
  (if (empty? weights)
      (list)
      ; Compute the intermediate output, z, and the final, out
      (let* ([z (matrix+ (matrix* (first weights) input) (first biases))]
             [out (matrix-map activation z)])
        ; Return it + the result from the next layer. If the next is the last it
        ; will be the final output, otherwise the intermediate output (z) for
        ; that layer.
        ;
        ; This makes the final return a list of (z out) pairs, one per layer:
        ;
        ; ((z1 o1) (z2 o2) ... (zn on))
        (cons
         (list z out)
         ; Compute the next layer with the output from this one
         (feedforward out (rest weights) (rest biases))))))
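; Illustrative sketch of what feedforward returns, assuming a single 2x2
; identity layer with a zero bias (identity-matrix comes from math/matrix):
;   (feedforward (col-matrix (1 0))
;                (list (identity-matrix 2))
;                (list (col-matrix (0 0))))
;   => (list (list z1 o1))
; where z1 is (col-matrix (1 0)) and o1 is its elementwise sigmoid,
; approximately (col-matrix (0.731 0.5)).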
(define (output-error z actual expected)
  (matrix-map * (matrix- actual expected) (matrix-map activation~ z)))
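; output-error is the output-layer delta for a squared-error cost, i.e. the
; elementwise product of (actual - expected) and activation~(z).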
(define (propagate-error outputs weights next-error)
  (if (empty? outputs)
      (list next-error)
      (cons
       next-error
       (propagate-error
        (rest outputs)
        (rest weights)
        (matrix-map *
                    (matrix* (first weights) next-error)
                    (matrix-map activation~ (first (first outputs))))))))
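; propagate-error builds the per-layer deltas from back to front: each new
; delta is (first weights) times the previous delta, scaled elementwise by
; activation~ of that layer's z.  Textbook backprop would use the transpose,
; (matrix-transpose (first weights)); with every layer being 2x2 here the
; shapes line up either way.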
; NOTE: outputs, weights need to be (reverse)d
(define (backprop weights outputs expected-output)
  (propagate-error (rest outputs)
                   weights
                   (let ([out (first outputs)])
                     (output-error (first out) (second out) expected-output))))
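; backprop expects the deepest layer first (hence the reversals in
; gradient-descent) and returns the deltas in that same deepest-first order.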
(define (adjust-weights weight out err epsilon)
  (let-values ([(w h) (matrix-shape weight)])
    (build-matrix w h
                  (lambda (j k)
                    (- (matrix-ref weight j k)
                       (* (matrix-ref out k 0) (matrix-ref err j 0) epsilon))))))
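; adjust-weights takes one gradient step per entry:
;   W[j][k] <- W[j][k] - epsilon * out[k] * err[j]
; i.e. it subtracts epsilon times the outer product of err and out.  An
; equivalent matrix form would be roughly
;   (matrix- weight (matrix-scale (matrix* err (matrix-transpose out)) epsilon))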
(define (adjust-parameters weights biases outputs errors)
  (let ([epsilon 1])
    (if (empty? outputs)
        (list)
        (cons
         (list
          (adjust-weights (first weights) (second (first outputs)) (first errors) epsilon)
          (matrix- (first biases) (matrix-scale (first errors) epsilon)))
         (adjust-parameters
          (rest weights)
          (rest biases)
          (rest outputs)
          (rest errors))))))
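; adjust-parameters walks the layers in parallel and returns, for each
; layer, a (new-weight new-bias) pair; epsilon is the learning rate, fixed
; at 1 here.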
(define (gradient-descent weights biases inputs expected-output (n 1))
  (let* ([outputs (feedforward inputs weights biases)]
         [errors (reverse (backprop (reverse weights) (reverse outputs) expected-output))]
         [parameters (reverse (adjust-parameters weights biases outputs errors))]
         [weights (map (lambda (x) (first x)) parameters)]
         [biases (map (lambda (x) (second x)) parameters)])
    (if (<= n 1)
        (values weights biases)
        (gradient-descent weights biases inputs expected-output (sub1 n)))))
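; One call to gradient-descent is one training step: forward pass, backprop
; for the per-layer deltas, then a parameter update; it recurses n times and
; returns the trained weights and biases as two values, e.g.
;   (gradient-descent weights biases inputs expected-output 256)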
(define (matrix-size m)
  (call-with-values (lambda () (matrix-shape m)) *))
(define (parameters-in-net weights biases)
  (+ (* (length weights) (matrix-size (first weights)))
     (* (length biases) (matrix-size (first biases)))))
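; matrix-size is rows * columns; parameters-in-net assumes every weight
; matrix and every bias has the same shape as the first one, which holds
; for this all-2x2 net: 5 * 4 + 5 * 2 = 30 parameters.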
(let* ([weights (list (matrix (((random) (random))
                               ((random) (random))))
                      (matrix (((random) (random))
                               ((random) (random))))
                      (matrix (((random) (random))
                               ((random) (random))))
                      (matrix (((random) (random))
                               ((random) (random))))
                      (matrix (((random) (random))
                               ((random) (random)))))]
       [biases (list (col-matrix ((random) (random)))
                     (col-matrix ((random) (random)))
                     (col-matrix ((random) (random)))
                     (col-matrix ((random) (random)))
                     (col-matrix ((random) (random))))]
       [inputs (col-matrix (1 1))]
       [expected-output (col-matrix (0.5 0.25))])
  (begin
    (display (format "Training model with ~a parameters" (parameters-in-net weights biases)))
    (newline)
    (let-values ([(weights biases) (gradient-descent weights biases inputs expected-output 256)])
      (begin
        (display (second (last (feedforward inputs weights biases))))
        (newline)))))

felipetavares commented Mar 6, 2021

❯ racket net.rkt
Training model with 30 parameters
#<array #(2 1) #[0.5000005829140142 0.2500826441823957]>
