Skip to content

Instantly share code, notes, and snippets.

@ryanrhymes
Last active May 18, 2018 18:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ryanrhymes/2e7c902812a7ae0547e24f7ea743c7e6 to your computer and use it in GitHub Desktop.
Save ryanrhymes/2e7c902812a7ae0547e24f7ea743c7e6 to your computer and use it in GitHub Desktop.
CIFAR10 CNN Example

CIFAR10 VGG Example

This example demonstrates how to build a VGG-like convolutional neural network for the CIFAR-10 dataset.

#!/usr/bin/env owl
(* This example demonstrates how to build a VGG-like convolutional neural
 * network for the CIFAR-10 dataset.
 *)
open Owl
open Neural.S
open Neural.S.Graph
(* [make_network input_shape] builds a small VGG-style classifier graph:
   input -> normalisation -> two convolutional stages -> a 512-unit ReLU
   fully-connected layer -> a 10-way linear layer with softmax activation.

   NOTE(review): the first stage hard-codes 3 input channels, so
   [input_shape] is presumably expected to end in 3 (e.g. [|32;32;3|]);
   confirm before reusing with other inputs. *)
let make_network input_shape =
  (* One stage: a 3x3 convolution (default padding), a 3x3 convolution with
     VALID padding, a 2x2/stride-2 max-pool, then 10% dropout. [c_in] and
     [c_out] fill the channel slots of the kernel arrays. *)
  let vgg_block c_in c_out node =
    node
    |> conv2d [|3;3;c_in;c_out|] [|1;1|] ~act_typ:Activation.Relu
    |> conv2d [|3;3;c_out;c_out|] [|1;1|] ~act_typ:Activation.Relu
         ~padding:VALID
    |> max_pool2d [|2;2|] [|2;2|] ~padding:VALID
    |> dropout 0.1
  in
  input input_shape
  |> normalisation
  |> vgg_block 3 32
  |> vgg_block 32 64
  |> fully_connected 512 ~act_typ:Activation.Relu
  |> linear 10 ~act_typ:(Activation.Softmax 1)
  |> get_network
(* [train ()] loads CIFAR-10 training data, builds the network from
   [make_network], prints its structure, and trains it with mini-batches of
   100, an Adagrad learning rate of 0.01, and a checkpoint every epoch.
   Returns whatever [Graph.train] returns.

   NOTE(review): no call to [train ()] is visible in this file, so running
   the script as shown only defines it; confirm whether an entry point such
   as [let _ = train ()] was intended. *)
let train () =
  (* Argument 1 presumably selects the first CIFAR-10 training batch; the
     middle component of the returned triple is unused here — confirm its
     meaning against Owl's Dataset module. *)
  let images, _, labels = Dataset.load_cifar_train_data 1 in
  let nn = make_network [|32;32;3|] in
  Graph.print nn;
  let batch = Batch.Mini 100 in
  let learning_rate = Learning_Rate.Adagrad 0.01 in
  let checkpoint = Checkpoint.Epoch 1. in
  (* The trailing positional float (10.) is presumably the number of
     epochs — confirm against Owl's Params.config. *)
  let params = Params.config ~batch ~learning_rate ~checkpoint 10. in
  Graph.train ~params nn images labels
2e7c902812a7ae0547e24f7ea743c7e6
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment