@jzstark
Created September 5, 2019 15:51

How I think Algodiff Works

To train a network, we need to first use train_generic:

let train_generic ?state ?params ?(init_model=true) nn x y =
    if init_model = true then init nn;
    let f = forward nn in
    let b = backward nn in
    let u = update nn in
    let s = save nn in
    let p = match params with
      | Some p -> p
      | None   -> Optimise.Params.default ()
    in
    Optimise.minimise_network ?state p f b u s x y

We can see that it is actually a wrapper around minimise_network:

let minimise_network ?state params forward backward update save x y =
    ...
    let iterate i =
      let xt, yt = bach_fun x y i in
      let yt', ws = forward xt in
      let loss = loss_fun yt yt' in
      let loss = Maths.(loss / (_f (Mat.row_num yt |> float_of_int))) in
      let reg = ...
      let loss = Maths.(loss + reg) in
      let ws, gs' = backward loss in
      loss |> primal', ws, gs'
    in

    ...

    while Checkpoint.(state.stop = false) do
      let loss', ws, gs' = iterate Checkpoint.(state.current_batch) in
      ...
    done

What are the forward and backward functions here?

let forward nn x = mktag (tag ()) nn; run x nn, mkpar nn

let backward nn y = reverse_prop (_f 1.) y; mkpri nn, mkadj nn 

These two operations are called at each iteration. Now we can see the process from a high level.
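The per-iteration cycle above (forward, loss, backward, update) can be sketched without Owl on a one-parameter model. Everything here is hypothetical: `w`, `forward`, `backward`, `update` and `minimise` are illustrative stand-ins, the gradient is written by hand rather than obtained from Algodiff, and a fixed epoch count replaces the Checkpoint stopping logic.

```ocaml
(* Hypothetical one-parameter "network": y' = w * x, loss = (y' - y)^2.
   This only mirrors the shape of the loop, not Owl's implementation. *)
let w = ref 0.

(* Forward pass: compute the prediction from the current weight. *)
let forward x = !w *. x

(* Backward pass: hand-written d loss / d w for the squared error. *)
let backward x y y' = 2. *. (y' -. y) *. x

(* Update step: plain gradient descent with learning rate lr. *)
let update lr g = w := !w -. (lr *. g)

(* Repeat forward -> loss -> backward -> update for a fixed number of epochs
   and return the last loss that was computed. *)
let minimise ~epochs ~lr x y =
  let loss = ref infinity in
  for _i = 1 to epochs do
    let y' = forward x in
    loss := (y' -. y) ** 2.;
    update lr (backward x y y')
  done;
  !loss

(* Fit w so that w * 2 = 6, i.e. w should approach 3. *)
let final_loss = minimise ~epochs:100 ~lr:0.1 2. 6.
let () = Printf.printf "w = %g, loss = %g\n" !w final_loss
```

In the real code the gradient is not derived by hand; that is exactly the job of the forward and backward functions discussed next.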

Forward Phase

First, distinguish between two kinds of neurons: those that contain weights to update, such as Conv2D, which I call Type A neurons; and the rest, which only do calculation, such as MaxPool2D, which I call Type B neurons. Hereafter I'll use "layer" and "neuron" interchangeably. Each neuron normally contains several nodes, and each node has type t:

type t =
  | F   of A.elt                                  (* constructor of float numbers *)
  | Arr of A.arr                                  (* constructor of ndarrays *)
  | DF  of t * t * int                            (* primal, tangent, tag *)
  | DR  of t * t ref * trace_op * int ref * int   (* primal, adjoint, op, fanout, tag *)

For example, if the run function of one neuron is Maths.(relu (x - _f a)), it will generate 3 nodes. I don't think we need the DF value here. Hereafter I take the primal value of a DR to be the value of the operation itself, and the adjoint value to be the gradient.

(wait, what is primal anyway? weight? output value? The graph is definitely held in op, not here...)
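One way to probe the "what is primal" question is a toy scalar version of the type, with the three nodes of relu (x - 2) built by hand. This is a sketch under assumptions, not Owl's code: the type is cut down to floats, `Sub_D_C` and `Relu_D` are named after the pattern in the excerpt, and `value_of` and `adjoint` are hypothetical helpers. In this reading the primal is the forward value of the node, and the inputs live in the op field.

```ocaml
(* Toy scalar stand-in for the node type; not Owl's definition. *)
type t =
  | F of float
  | DR of t * t ref * trace_op * int ref * int  (* primal, adjoint, op, fanout, tag *)

and trace_op =
  | Noop
  | Sub_D_C of t * t  (* DR input minus a constant *)
  | Relu_D of t

(* Hypothetical helper: strip DR wrappers until a plain float is reached. *)
let rec value_of = function
  | F x -> x
  | DR (p, _, _, _, _) -> value_of p

(* Hypothetical helper: read the accumulated gradient of a node. *)
let adjoint = function
  | DR (_, a, _, _, _) -> value_of !a
  | F _ -> 0.

(* Three nodes for relu (x - 2) with x = 3, written out by hand. *)
let x = DR (F 3., ref (F 0.), Noop, ref 0, 1)
let d = DR (F 1., ref (F 0.), Sub_D_C (x, F 2.), ref 0, 1)
let y = DR (F 1., ref (F 0.), Relu_D d, ref 0, 1)

let () =
  Printf.printf "primal of y = %g, adjoint of y = %g\n" (value_of y) (adjoint y)
```

The adjoint stays at zero until a backward pass runs; only the primal and the op are filled in while the graph is being built.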

  1. mktag (tag ()) nn: for each layer in a neural network, if it is Type A, change each of its parameters to a DR value by calling make_reverse; if it is a Type B neuron, do nothing.
let make_reverse p i = DR (p, ref (zero p), Noop, ref 0, i)

Note that the DR values created here are ALL Noop operations, which means they are nothing more than placeholders at this stage.

  2. run x nn: connect all the existing operations into a graph by running the network layer by layer, regardless of whether a layer is Type A or B. The whole graph is accumulated into the output node. Note that the run function of each neuron uses operations from Algodiff.Maths rather than normal math operations. Let's look at an example:
(* module Conv2D *)
let run x l = Maths.((conv2d ~padding:l.padding x l.w l.stride) + l.b)

Here both l.w and l.b are already set to DR placeholders. x is a t output value from the previous neuron. How is the conv2d operation implemented in Algodiff, then?

and conv2d ?padding a b s =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(conv2d ?padding a b s)
    | _            -> error_binop "conv2d" a b
  in
  let fd a b = conv2d ?padding a b s in
  ...
  let r_d_c a b = Conv2D_D_C (a, b, s) in
  op_d_d_d a b ff fd _ _ _ _ r_d_c _

Here a and b are the input and the kernel respectively, both of type t. For simplicity, we only look at the case where the input is a DR and the kernel is a constant Arr. (One question: wait, l.w was already set to a DR in the previous step, so how could it be an Arr now? What has changed it?)

and op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d =
  match a, b with
  | DR (ap, _, _, _, ai), Arr _bp -> let cp = fd ap b in DR (cp, ref (zero cp), r_d_c a b, ref 0, ai)
  |...

So what Maths.conv2d does is this: first compute the result value from the primal of the input DR (let cp = fd ap b, i.e. let cp = conv2d ?padding ap b s); then build a new DR node around it, with the gradient initialised to zero (ref (zero cp)), the operation set to Conv2D_D_C (r_d_c a b) instead of Noop, and the tag carried over unchanged (ai).

What I don't quite understand is the function fd: it calls conv2d recursively; why? I temporarily interpret it as "calculating the result value", but that is quite likely wrong.

The translation of a Type B neuron is similar. For example, for the MaxPool2D neuron:

let run x l = Maths.(max_pool2d l.padding x l.kernel l.stride)

and max_pool2d padding a b s =
  let ff = function
    | Arr a    -> Arr A.(max_pool2d ~padding a b s)
    | _        -> error_uniop "max_pool2d" a
  in
  let fd a = max_pool2d padding a b s in
  let df _cp _ap _at = failwith "max_pool2d:df" in
  let r a = Maxpool2D_D (a, padding, b, s) in
  op_d_d a ff fd df r

and op_d_d a ff fd df r =
  match a with
  | DF (ap, at, ai)      -> ...
  | DR (ap, _, _, _, ai) -> let cp = fd ap in DR (cp, ref (zero cp), r a, ref 0, ai)
  | ap                   -> ff ap

If the input is a DR, then this operation similarly adds a DR node to the graph; otherwise ff is called, and an Arr node is added.
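The ff / fd / r pattern can be tried out on a toy scalar type. This is a sketch under assumptions rather than Owl's implementation: floats replace ndarrays, `F 0.` replaces `zero cp`, the DF case is dropped, and `relu` and `describe` are written here only for illustration. One reading it suggests for fd: the primal ap has type t, so fd re-enters the same operation on it, and that call falls through to ff once the value is no longer wrapped in a DR.

```ocaml
(* Toy scalar stand-in for the node type; not Owl's definition. *)
type t =
  | F of float
  | DR of t * t ref * trace_op * int ref * int  (* primal, adjoint, op, fanout, tag *)

and trace_op =
  | Noop
  | Relu_D of t

(* Same shape as make_reverse in the excerpt, with F 0. in place of zero p. *)
let make_reverse p i = DR (p, ref (F 0.), Noop, ref 0, i)

(* Unary lifting: a DR input yields a new DR that records the op via r;
   anything else is a plain value and goes straight to ff. *)
let op_d_d a ff fd r =
  match a with
  | DR (ap, _, _, _, ai) ->
    let cp = fd ap in
    DR (cp, ref (F 0.), r a, ref 0, ai)
  | ap -> ff ap

let rec relu a =
  let ff = function
    | F x -> F (if x > 0. then x else 0.)
    | DR _ -> failwith "relu: ff is never called on a DR"
  in
  (* fd re-enters relu on the primal, peeling one DR layer per call. *)
  let fd ap = relu ap in
  let r a = Relu_D a in
  op_d_d a ff fd r

(* Hypothetical helper for printing what came back. *)
let describe = function
  | F x -> Printf.sprintf "F %g" x
  | DR (F p, _, Relu_D _, _, _) -> Printf.sprintf "DR with primal %g and op Relu_D" p
  | DR _ -> "other DR"

let c = relu (F (-2.))               (* constant in, constant out *)
let y = relu (make_reverse (F 3.) 1) (* DR in, new DR out *)

let () =
  print_endline (describe c);
  print_endline (describe y)
```

Note that the placeholder passed in is not modified: the result is a fresh DR whose op field points back at it.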

After a forward pass is finished, we get one output DR value. But it's much more than an output value; it actually contains a whole computation graph in its op:

and trace_op =
  | Conv1D_D_C  of t * t * int array
  | Maxpool2D_D of t * padding * int array * int array
  ...

The output of each run function is accumulated in this way into a graph of ts.
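To see that the output value really carries the whole graph, one can walk the op fields from the output back to the leaves. The sketch below assumes a toy scalar type with only two operations; `leaf`, `add`, `mul_c` and `ops_of` are hypothetical helpers, not Owl functions.

```ocaml
(* Toy scalar stand-in for the node type; not Owl's definition. *)
type t =
  | F of float
  | DR of t * t ref * trace_op * int ref * int  (* primal, adjoint, op, fanout, tag *)

and trace_op =
  | Noop
  | Add_D_D of t * t  (* DR + DR *)
  | Mul_D_C of t * t  (* DR * constant *)

let rec value_of = function
  | F x -> x
  | DR (p, _, _, _, _) -> value_of p

(* Hypothetical constructors that compute the primal and record the op. *)
let leaf x = DR (F x, ref (F 0.), Noop, ref 0, 1)
let add a b = DR (F (value_of a +. value_of b), ref (F 0.), Add_D_D (a, b), ref 0, 1)
let mul_c a c = DR (F (value_of a *. value_of c), ref (F 0.), Mul_D_C (a, c), ref 0, 1)

(* Depth-first walk from an output node, listing the op stored in each DR. *)
let rec ops_of = function
  | F _ -> []
  | DR (_, _, op, _, _) -> (
    match op with
    | Noop -> [ "Noop" ]
    | Add_D_D (a, b) -> ("Add_D_D" :: ops_of a) @ ops_of b
    | Mul_D_C (a, _) -> "Mul_D_C" :: ops_of a)

let w = leaf 2.
let b = leaf 1.
let y = add (mul_c w (F 3.)) b  (* y = w * 3 + b *)
let trace = ops_of y

let () = print_endline (String.concat " -> " trace)
```

Only `y` is passed to `ops_of`, yet both parameter placeholders show up at the end of the walk as Noop leaves.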

  3. The final step, mkpar nn, is simple: return the parameters of each layer in an array, which has type t array array.

Backward Phase

We already have the graph in the output DR value y from the forward pass; now let's apply the backward step to it. Note that the backward step is actually applied to a loss value, which appends some extra nodes to the end of the output graph from the forward pass, but let's ignore that for now.

We start from reverse_prop (_f 1.) y, which simply consists of two steps:

let reverse_prop v x =
  reverse_reset x;
  reverse_push v x

let reverse_reset x =
  let rec reset xs =
    match xs with
    | [] -> ()
    | x :: t -> (
      match x with
      | DR (_ap, aa, ao, af, _ai) -> (
        aa := reset_zero !aa;
        (* elided in this excerpt: the check on af that the else below belongs to *)
        match ao with
        | Noop -> reset t
          ....)
        else reset t
        )
      | _ -> reset t)
  in
  reset [x]

let reverse_push v x =
  let open Maths in
  let rec push xs =
    match xs with
    | []          -> ()
    | (v, x) :: t -> (
        match x with
        | DR (ap, aa, ao, af, _ai) -> (
            aa := Maths.(!aa + v);
            match ao with
            | Noop -> push t
            | Conv2D_D_C (a, b, s)  -> push ((conv2d_backward_input a b s !aa, a) :: t)
            | ....
          )
        | _ -> push t
      )
  in
  push [(v, x)]
  1. There is no magic in the reverse_reset function. Starting from the root node y, for each node: 1) set my own gradient to 0; 2) add my parents to the stack; 3) process the first element of the stack until it is empty.

  2. reverse_push is a little more complex, but similar. Starting from (v, y), where v for y is 1, for each node: 1) update my gradient by adding v to my current gradient; 2) calculate the v for my parents; 3) add (v, parent) to the stack; 4) process the first element of the stack until it is empty. In both steps, if the node is not a DR, just ignore it.

  3. After one backward step, the gradient of each node is updated.
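The two passes can be sketched end to end on a toy scalar type. This is an illustration under assumptions, not Owl's code: only Add_D_D and Mul_D_D exist, adjoints are floats wrapped in `F`, and the helpers `leaf`, `add`, `mul`, `value_of` and `adjoint` are hypothetical. The sketch also uses the fanout counter: the reset pass counts how many times each node is consumed, and the push pass only propagates to a node's parents once all of its consumers have contributed.

```ocaml
(* Toy scalar stand-in for the node type; not Owl's definition. *)
type t =
  | F of float
  | DR of t * t ref * trace_op * int ref * int  (* primal, adjoint, op, fanout, tag *)

and trace_op =
  | Noop
  | Add_D_D of t * t
  | Mul_D_D of t * t

let rec value_of = function
  | F x -> x
  | DR (p, _, _, _, _) -> value_of p

let leaf x = DR (F x, ref (F 0.), Noop, ref 0, 1)
let add a b = DR (F (value_of a +. value_of b), ref (F 0.), Add_D_D (a, b), ref 0, 1)
let mul a b = DR (F (value_of a *. value_of b), ref (F 0.), Mul_D_D (a, b), ref 0, 1)

(* Pass 1: zero every adjoint and count how often each node is consumed. *)
let reverse_reset x =
  let rec reset = function
    | [] -> ()
    | DR (_, aa, ao, af, _) :: t ->
      aa := F 0.;
      af := !af + 1;
      if !af = 1 then (
        match ao with
        | Noop -> reset t
        | Add_D_D (a, b) | Mul_D_D (a, b) -> reset (a :: b :: t))
      else reset t
    | F _ :: t -> reset t
  in
  reset [ x ]

(* Pass 2: add v to the adjoint; push to the parents only after the
   last consumer of this node has reported (fanout back to 0). *)
let reverse_push v x =
  let rec push = function
    | [] -> ()
    | (v, DR (_, aa, ao, af, _)) :: t ->
      aa := F (value_of !aa +. v);
      af := !af - 1;
      if !af = 0 then (
        let g = value_of !aa in
        match ao with
        | Noop -> push t
        | Add_D_D (a, b) -> push ((g, a) :: (g, b) :: t)
        | Mul_D_D (a, b) -> push ((g *. value_of b, a) :: (g *. value_of a, b) :: t))
      else push t
    | (_, F _) :: t -> push t
  in
  push [ (v, x) ]

let reverse_prop v x =
  reverse_reset x;
  reverse_push v x

let adjoint = function
  | DR (_, a, _, _, _) -> value_of !a
  | F _ -> 0.

let w = leaf 3.
let y = add (mul w w) w  (* y = w * w + w, so dy/dw = 2w + 1 *)
let () = reverse_prop 1. y
let () = Printf.printf "dy/dw at w = 3: %g\n" (adjoint w)
```

Here `w` is consumed three times, so its adjoint receives three contributions (3, 3 and 1) before the walk stops at its Noop.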

The rest is easy: mkpri nn and mkadj nn collect, in arrays, the weight values and the gradients of every neuron that contains weights.
