@qnkhuat
Last active February 2, 2022 04:03
MNIST reader in Clojure
;; MNIST data can be downloaded from: http://yann.lecun.com/exdb/mnist/
(ns cnn.core
  (:import java.io.File
           java.io.FileInputStream))
(defn u4->int [arr]
  ;; convert 4 unsigned bytes (0-255), most significant first (big-endian,
  ;; as in the IDX header), to an int
  (loop [r 0 i 3 x arr]
    (if (= i -1)
      r
      (recur
       (bit-or
        r
        (bit-shift-left (first x) (* i 8)))
       (dec i)
       (rest x)))))
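(defn to-unsigned
  [b]
  ;; map a signed Java byte (-128..127) back to its unsigned value (0..255)
  (bit-and b 0xff))
(defn read-n-bytes
  [^java.io.InputStream is n]
  (let [data (byte-array n)
        ;; .read may return fewer than n bytes; .readNBytes (Java 9+) keeps
        ;; reading until n bytes have arrived or the stream ends
        nread (.readNBytes is data 0 (int n))]
    (when-not (= nread n)
      (throw (ex-info (format "Unexpected end of file, expected %d bytes, got %d" n nread) {})))
    ;; Clojure uses signed bytes, whereas MNIST stores unsigned bytes
    (mapv to-unsigned data)))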
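(defn read-images
  ;; returns a lazy seq of images, each a seq of rows*cols pixels (0-255)
  [path]
  ;; with-open closes the file stream once the data has been read
  (with-open [is (FileInputStream. (File. path))]
    (let [magic-number (u4->int (read-n-bytes is 4))
          _ (when-not (= magic-number 2051)
              (throw (ex-info "Magic number should be 2051 for image file"
                              {:path path :magic-number magic-number})))
          n (u4->int (read-n-bytes is 4))
          rows (u4->int (read-n-bytes is 4))
          cols (u4->int (read-n-bytes is 4))
          data-len (* n rows cols)
          data-arr (read-n-bytes is data-len)]
      (when-not (= data-len (count data-arr))
        (throw (ex-info (format "Incorrect data length, should be: %d, got: %d" data-len (count data-arr)) {})))
      ;; data-arr is already a realized vector, so partitioning it lazily
      ;; is safe after the stream is closed
      (partition (* rows cols) data-arr))))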
(defn read-labels
  ;; returns a vector of n labels (digits 0-9), in the same order as the images
  [path]
  (with-open [is (FileInputStream. (File. path))]
    (let [magic-number (u4->int (read-n-bytes is 4))
          _ (when-not (= magic-number 2049)
              (throw (ex-info "Magic number should be 2049 for label file"
                              {:path path :magic-number magic-number})))
          n (u4->int (read-n-bytes is 4))
          data-arr (read-n-bytes is n)]
      (when-not (= n (count data-arr))
        (throw (ex-info (format "Incorrect data length, should be: %d, got: %d" n (count data-arr)) {})))
      data-arr)))
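;; Benchmark on a Mac M1 Max with 32 GB RAM, reading 47,040,000 pixel bytes
;; (60,000 images of 28 x 28). Each run used a different variant of the
;; signed-to-unsigned conversion in read-n-bytes:
;; "Elapsed time: 1564.851583 msecs" (r/map, i.e. clojure.core.reducers/map, without type hint)
;; "Elapsed time: 1475.194334 msecs" (r/map with type hint)
;; "Elapsed time: 1143.02050 msecs" (mapv instead of r/map)
;; "Elapsed time: 6397.154542 msecs" (map instead of mapv)
(comment
  ;; evaluate at the REPL; wrapped in comment so loading the namespace does not run it
  (time (read-images "../mnist/train-images-idx3-ubyte")))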
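u4->int decodes the big-endian 32-bit integers of the IDX header one byte at a time, most significant byte first. The same decoding can be written as a fold: shift the accumulator left by 8 bits and OR in the next byte. The sketch below is a hypothetical alternative (the name be-bytes->int is not part of the original gist) and assumes its input bytes are already unsigned (0-255), as read-n-bytes returns them.

```clojure
(defn be-bytes->int
  "Hypothetical reduce-based equivalent of u4->int: fold unsigned bytes
  (0-255), most significant byte first, into a single integer."
  [bs]
  (reduce (fn [acc b] (bit-or (bit-shift-left acc 8) b)) 0 bs))

;; The magic number of an IDX image file is stored as the bytes 00 00 08 03:
(be-bytes->int [0 0 8 3]) ;; => 2051
```

It is meant for the 4-byte header fields; inputs longer than 8 bytes would shift bits out of the 64-bit long.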
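The timings in the benchmark vary mostly with how the signed bytes are converted and collected into a Clojure sequence. One way to avoid building a vector of roughly 47 million boxed numbers is to keep the pixels in the primitive byte array and convert a byte to unsigned only when it is read. The sketch below is not the gist's method; it assumes the whole file fits in memory (the training images are about 47 MB), reads it with java.nio.file.Files/readAllBytes, and decodes the header with java.nio.ByteBuffer, whose default byte order is big-endian. The names read-idx and pixel are hypothetical.

```clojure
(import '(java.io File)
        '(java.nio ByteBuffer)
        '(java.nio.file Files))

(defn read-idx
  "Hypothetical helper: read a whole IDX file into memory and return
  {:magic m, :dims [d0 d1 ...], :data raw-bytes}. Header ints are
  big-endian, which is ByteBuffer's default byte order."
  [^String path]
  (let [raw   (Files/readAllBytes (.toPath (File. path)))
        buf   (ByteBuffer/wrap raw)
        magic (.getInt buf)
        ;; the low byte of the magic number is the number of dimensions
        ;; (3 for image files: count, rows, cols; 1 for label files)
        ndims (bit-and magic 0xff)
        dims  (vec (repeatedly ndims #(.getInt buf)))
        ;; everything after the header is the payload, one byte per item
        data  (byte-array (.remaining buf))]
    (.get buf data)
    {:magic magic :dims dims :data data}))

(defn pixel
  "Hypothetical accessor: unsigned value (0-255) of pixel (row, col)
  in image i of an image file returned by read-idx."
  [idx i row col]
  (let [[_ rows cols] (:dims idx)
        ^bytes data   (:data idx)]
    ;; convert from signed to unsigned only for the byte being read
    (bit-and (aget data (+ (* i rows cols) (* row cols) col)) 0xff)))
```

Assuming "../mnist/train-images-idx3-ubyte" is the standard MNIST training set, (read-idx "../mnist/train-images-idx3-ubyte") would return :magic 2051 and :dims [60000 28 28].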