@DhairyaLGandhi
Last active January 24, 2020 11:35

Flux has made major strides in the couple of years since its release, so this piece talks a little about what has changed.

flux-logo

Starting with a bit of housekeeping: this piece introduces some basic guidelines for Julia programming, along with a few neat tricks, and should hopefully help with your understanding of the language. Another task is to clarify what Flux and its ecosystem aren't. Flux isn't strictly a deep learning library, although it does define most of the primitives for deep learning. It is essentially a framework for differentiable programming.

For a TL;DR, differentiable programming ($\partial$P) is a way of treating arbitrary programs as differentiable. Put simply, it is a generalisation of the way we treat deep learning as consisting of a forward pass and a backward pass. It applies the chain rule (see the equation below) to every operation that takes place in a program. The record of the operations applied to every element of the code, and the order in which they happen, is the code itself! Specifically, it is the AST of the code.

$$ \frac{ d{ (f \circ g)(x) }}{ d{x}} = \frac{\partial f}{\partial g} \times \frac{\partial g}{\partial x} $$

It replaces the standard neural network with basically any other code, and the model subtly melts away from being a singular entity (think Sequential from Keras) into a part of the general logic that you wish to implement. The adjoints are defined just as they are for any other differentiable function, again generalised from the mathematical principles to logic consistent with what a deconstruction of the forward actions would look like.
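To make the chain rule concrete, here is a minimal sketch in plain Julia, with no Flux or Zygote involved. The names `square`, `sine` and `compose` are hypothetical helpers made up for this illustration; each primitive returns its value together with a closure that propagates a sensitivity backwards.

```julia
# Each primitive returns (value, pullback). The pullback maps an incoming
# sensitivity Δ to the sensitivity with respect to the primitive's input.
square(x) = (x^2, Δ -> Δ * 2x)
sine(x) = (sin(x), Δ -> Δ * cos(x))

# Compose f ∘ g: run the forward passes in order, then chain the pullbacks
# in reverse order, which is the chain rule.
function compose(f, g)
    return function (x)
        y, back_g = g(x)
        z, back_f = f(y)
        return (z, Δ -> back_g(back_f(Δ)))
    end
end

h = compose(sine, square)   # h(x) = sin(x^2)
value, back = h(1.5)
grad = back(1.0)            # derivative of sin(x^2) at x = 1.5
```

The forward pass records just enough (the closures) to run the backward pass later, which is the same forward/backward split that deep learning frameworks use.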

How the adjoints are defined

Consider a regular function

function f(arg1::CustomType, arg2, arg3...)
  transform1 = f1(arg1)
  transform2 = f2(transform1, arg3...)
  result = f3(transform1, transform2)
  result
end
# ... going down in the abstraction ...

function f1(arg)
  result = operations_with_some_concrete_types(arg)
end

The function f can accept arguments of any type, including user-defined ones. When we call this function, it executes a bevy of other functions, ultimately ending with some basic operations involving concrete types (be they Arrays, Numbers, Symbols, etc.); let's call them primitives. Let's now teach Julia how to differentiate operations involving these primitives. This would involve defining the adjoint of sqrt for a real number, for example.

$$ \frac{d{\sqrt{x}}}{d{x}} = \frac{1}{2\sqrt{x}} $$

Which can be expressed as

@adjoint Base.sqrt(x::Real) = Base.sqrt(x), Δ -> ( Δ * inv(2 * sqrt(x)),)

The process and intuition behind writing appropriate adjoints is a topic for a different post.
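As a quick sanity check of that rule, the sketch below reproduces the same value-plus-pullback pair in plain Julia and compares it against a finite difference. `sqrt_with_pullback` is a hypothetical stand-in written for this example; it is not how Zygote registers the adjoint internally.

```julia
# Hypothetical stand-in for the adjoint above: the value plus a pullback
# that returns a one-element tuple, one entry per argument.
sqrt_with_pullback(x::Real) = (sqrt(x), Δ -> (Δ * inv(2 * sqrt(x)),))

y, back = sqrt_with_pullback(4.0)
analytic = back(1.0)[1]

# Central finite difference as an independent check.
h = 1e-6
numeric = (sqrt(4.0 + h) - sqrt(4.0 - h)) / (2h)
```

The pullback returns a tuple because a function may have several arguments, each of which gets its own gradient.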

If we keep track of the higher-level operations and define the adjoints on the primitives, we can essentially "solve back": we accumulate the gradients from all the transforms (with the help of the adjoints of the primitives) while maintaining some structure, such as constructing NamedTuples with the appropriate keys. In this way we can express any operation as differentiable. The backward pass would flow something like f3 --> f2 --> f1 --> operations_with_some_concrete_types. This way we can traverse our code (specifically, its intermediate representation) and generate the backward passes on the fly.
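The "solve back" idea can be sketched by hand for a function shaped like `f` above. Everything here is hypothetical: the three primitives are toy arithmetic chosen so the result is easy to verify, and the pullback is written manually, whereas a source-to-source tool would generate it from the intermediate representation.

```julia
# Hypothetical primitives, each returning (value, pullback).
f1(x) = (2x, Δ -> 2Δ)
f2(t1, c) = (t1 + c, Δ -> (Δ, Δ))
f3(t1, t2) = (t1 * t2, Δ -> (Δ * t2, Δ * t1))

function f_with_pullback(x, c)
    # Forward pass, in program order.
    t1, back1 = f1(x)
    t2, back2 = f2(t1, c)
    result, back3 = f3(t1, t2)

    # Backward pass, in reverse order: f3, then f2, then f1.
    function pullback(Δ)
        Δt1_from_f3, Δt2 = back3(Δ)
        Δt1_from_f2, Δc = back2(Δt2)
        # t1 was used twice, so its two contributions are accumulated.
        Δx = back1(Δt1_from_f3 + Δt1_from_f2)
        return (x = Δx, c = Δc)
    end
    return result, pullback
end

result, back = f_with_pullback(3.0, 1.0)
grads = back(1.0)
```

Here the function computes `2x * (2x + c)`, so the gradients should match `8x + 2c` and `2x`; returning them as a NamedTuple keeps the keys aligned with the inputs.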

The cool part about this approach is that if we were to define the adjoints for the primitives or the base functions of a programming language, we can get any arbitrary program to be differentiated, and even support custom types and packages, almost for free. Add in an ideal optimising compiler, and these backwards passes become efficient too!

To give an example, the forward pass can be thought of as the process of tying your shoelaces, and the backwards pass is when we untie them by pulling the two ends apart.

For a lot of this to work as expected, though, the base language on top of which this entire machinery is built must expose meaningful expressions of its intermediate representation, which can be used to infer the backward passes on the fly. This is precisely what Julia does, given its history of hackability. Flux takes this hackability and runs with it, to the point that the entire library is focussed on inviting people into its source code and on having them extend it with their own layers, definitions, optimisers and what have you. This is a tough ask, since it means anticipating which assumptions are safe and which aren't, but it's definitely worth it, since it allows users to gracefully add complexity as required.

A post will be up later talking about implementing a differentiable programming solution and another explaining the guts of what makes Flux and Zygote tick.

A Basic Optimisation Loop

For now, let's start with the classic example of optimising a random array to a different random array. It's just to illustrate how a simple iterative optimisation loop is expressed in Flux.

using Flux, Zygote  # Momentum and mse come from Flux; Params and gradient from Zygote

z = rand(3,3)
z′ = rand(3,3)

loss(x) = Flux.mse(z * x, z′ * x)
opt = Momentum()
ps = Params([z])  # z is an implicit parameter, and thus needs to be wrapped in the `Params` type.

for i = 1:10^5
  x = rand(3)
  gs = gradient(ps) do
    loss(x)
  end
  Flux.Optimise.update!(opt, ps, gs)
end

z ≈ z′ # true

And just like that, we have moved z close to z′!
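For readers who want to see what such a loop amounts to without any library, here is a rough hand-written equivalent. It assumes the loss is the mean of the squared residuals, derives the gradient with respect to `z` by hand, and uses a textbook momentum update; `fit!`, `η`, `ρ` and `steps` are names and values chosen for this sketch, not taken from Flux.

```julia
using Random

# Assumed loss: mean((z*x .- z′*x).^2). Its gradient with respect to z is
# (2/n) * residual * x', where residual = (z - z′) * x and n = length(residual).
function fit!(z, z′; η = 0.01, ρ = 0.9, steps = 10^4)
    v = zero(z)                                    # momentum buffer
    for _ in 1:steps
        x = rand(size(z, 2))
        residual = (z .- z′) * x
        g = (2 / length(residual)) .* (residual * x')
        v .= ρ .* v .+ η .* g                      # accumulate velocity
        z .-= v                                    # step against the gradient
    end
    return z
end

Random.seed!(1)
z = rand(3, 3)
z′ = rand(3, 3)
initial_gap = maximum(abs.(z .- z′))
fit!(z, z′)
final_gap = maximum(abs.(z .- z′))
```

The update mutates `z` in place, which mirrors how the optimiser in the Flux version changes the array that was wrapped in `Params`.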

Adapting this to a custom type

Now, let's express this in terms of our own custom struct. For simplicity's sake, I am going to keep the fields of the struct Arrays, but they could be anything really.

import Base: +, -, *, /
import Base: isapprox
using MacroTools: @forward

mutable struct GG{T}
  a::T
  b::T
end

GG(a) = GG(a, a)

for op in (:+, :*, :-, :/)

  @eval @inline $(op)(a::GG, b::GG) = GG(broadcast($op, a.a, b.a), 
					broadcast($op, a.b, b.b))

  @eval @inline $(op)(a::GG, b) = GG(broadcast($op, a.a, b), 
				     broadcast($op, a.b, b))

  # Keep the argument order here: `-` and `/` are not commutative.
  @eval @inline $(op)(b, a::GG) = GG(broadcast($op, b, a.a),
                                     broadcast($op, b, a.b))
end


@forward GG.a Base.size

Here, we've declared the struct and defined some basic operations on it, including how it interacts with other types. Notice how we make use of Julia's excellent broadcasting infrastructure, and a bit of code interpolation to avoid repeating definitions for all the operations we want it to support, (:+, :*, :-, :/) in this case. @inline also hints to the Julia compiler that these operations can be inlined easily, and that it should try this optimisation if possible.
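The interpolation trick is easier to see on a smaller case. The sketch below uses a hypothetical immutable struct `Duo` and only two operators; it also shows why the argument order matters when the scalar sits on the left of a non-commutative operator.

```julia
import Base: +, -

# Hypothetical two-field wrapper, immutable for brevity.
struct Duo{T}
    a::T
    b::T
end

for op in (:+, :-)
    @eval begin
        # Duo with Duo: apply the operation field by field.
        $(op)(x::Duo, y::Duo) = Duo(broadcast($op, x.a, y.a), broadcast($op, x.b, y.b))
        # Scalar on the left: the scalar stays the first argument.
        $(op)(s::Number, x::Duo) = Duo(broadcast($op, s, x.a), broadcast($op, s, x.b))
    end
end

d = Duo([1, 2], [3, 4]) - Duo([1, 1], [1, 1])
e = 10 - Duo([1, 2], [3, 4])
```

Each pass through the loop splices the operator symbol into the method definitions, so the two iterations define four methods in total.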

And just to hit the nail on the head, let's define some more primitives that could come in handy while optimising. These are operations that a lot of folks are already used to from mathematical computing, but we will extend them to arbitrary structs that don't immediately seem "optimisable", in a manner of speaking.

Base.zero(a::GG) = GG(zero(a.a), zero(a.b))
Base.length(::GG) = 1
Base.:^(a::GG, i) = GG(a.a .^ i, a.b .^ i)

import Statistics: mean

mean(a::GG) = mean(a.a) + mean(a.b)
Base.sum(a::GG) = sum(a.a) + sum(a.b)

Base.isapprox(a::GG{T}, b::GG{T}) where T = all([isapprox(a.a, b.a), isapprox(a.b, b.b)])

One last thing that might be necessary to take advantage of Flux's optimisers is to teach it what to do with the GG struct. We can extend it to just call update on all the fields of the struct.

function Flux.Optimise.update!(opt, x::T, gs, fs = fieldnames(T)) where {T<:GG}
  gs = gs.x  # the gradient arrives wrapped in a container (e.g. a `Ref`); `.x` unwraps the per-field gradients
  for f in fs
    Flux.Optimise.update!(opt, getfield(x,f), getfield(gs,f))
  end
end
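The same field-by-field idea can be shown without any optimiser object. In this sketch, `Pairing` and `descend!` are hypothetical names, and a plain descent step stands in for whatever rule the optimiser would apply.

```julia
# Hypothetical mutable container with two array fields.
mutable struct Pairing
    a::Vector{Float64}
    b::Vector{Float64}
end

# Plain descent step applied field by field. `grads` is a NamedTuple whose
# keys match the field names of `x`.
function descend!(x, grads::NamedTuple; η = 0.1)
    for name in fieldnames(typeof(x))
        field = getfield(x, name)
        field .-= η .* getfield(grads, name)   # in place, so `x` sees the change
    end
    return x
end

p = Pairing([1.0, 2.0], [3.0, 4.0])
descend!(p, (a = [1.0, 1.0], b = [10.0, 0.0]))
```

Because the arrays are updated in place, nothing needs to be reassigned on the struct itself.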

And with that, we should be ready to do our optimisation.

Let's define two instances of our GG struct that we'd like to optimise.

a = GG(rand(3,3), rand(3,3))
b = GG(rand(3,3), rand(3,3))

And we will use the same Momentum optimiser and mean-squared-error loss.

opt = Momentum()
for i = 1:10^5
  gs = gradient(a) do x
    sum((x-b) * (x-b)) / prod(size(x))
  end
  Flux.Optimise.update!(opt, a, gs[1])
end

a ≈ b # true

With this, we have optimised one struct towards another. Now we can take this concept and apply it to structs more complex than a simple random array.

Another thing to note here is that no call to Params is needed at all in this case. This is because all of our parameters have been made explicit by passing a.

To give some context on the earlier discussion: operations such as sum, prod, size, - and so on are visible to Flux as valid operations on the parameters (a), and it looks into the implementations that we use for these transforms to come up with valid adjoint methods. Think of it as the pulling motion from the shoelace example. Using these, it accumulates the gradients from all the operations and finally returns them, keeping the structure of the parameters intact. This allows us to treat them as instances of the same type as usual, and finally optimise on them.
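To illustrate what "keeping the structure intact" means, here is a hand-derived version of the gradient for a loss of this shape, checked against a finite difference. `Two`, `loss` and `grad` are hypothetical names for this sketch; the real gradient above is produced automatically rather than written out like this.

```julia
# Hypothetical immutable stand-in for a struct with two matrix fields.
struct Two
    a::Matrix{Float64}
    b::Matrix{Float64}
end

# Squared differences summed over both fields, divided by the number of
# elements in one field.
loss(x::Two, t::Two) = (sum(abs2, x.a .- t.a) + sum(abs2, x.b .- t.b)) / length(x.a)

# Hand-derived gradient, returned as a NamedTuple that mirrors the struct.
grad(x::Two, t::Two) = (a = 2 .* (x.a .- t.a) ./ length(x.a),
                        b = 2 .* (x.b .- t.b) ./ length(x.a))

x = Two([1.0 2.0; 3.0 4.0], [0.0 0.0; 0.0 0.0])
t = Two([0.0 0.0; 0.0 0.0], [1.0 1.0; 1.0 1.0])
g = grad(x, t)

# Forward finite difference on a single entry of field `a`.
h = 1e-6
bumped = Two(x.a .+ [h 0.0; 0.0 0.0], x.b)
numeric = (loss(bumped, t) - loss(x, t)) / h
```

The gradient has one entry per field, with the same shape as that field, which is what lets an update routine walk the struct and its gradient side by side.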

Optimising Colours

With the example done, let's try optimising colours. This is going to get fun! This example is taken from some of our work in the differentiable programming examples that we present here.

using Colors, Zygote  # RGB and colordiff come from Colors.jl

target = RGB(1, 0, 0)
colour = RGB(1, 1, 1)

function update_colour(c, Δ, η = 0.001)
   return RGB(
        c.r - η*Δ.r,
        c.g - η*Δ.g,
        c.b - η*Δ.b,
    )
end

for idx in 1:51
    global colour, target

    # Calculate gradients
    grads = Zygote.gradient(colour) do y
        colordiff(target, y)
    end
    # Update colour
    colour = update_colour(colour, grads[1])
    if idx % 5 == 1
        @info idx, colour
    end
end

Here our struct is just the RGB type taken from the Colors.jl package. Again, the trick is to have meaningful operations defined on our type, based on the operations we will hit while calculating our loss function. The function colordiff already gives us the distance between two colours. Notice that here we employ a simple Descent optimiser, which with enough iterations will just diverge.
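The shape of that loop can be reproduced in plain Julia with some loud simplifications: colours are NamedTuples rather than `RGB` values, squared Euclidean distance in RGB space is an assumed stand-in for `colordiff` (which measures perceptual difference), and the gradient is written by hand. The step size and the names `distance`, `distance_grad` and `descend` are choices made for this sketch.

```julia
# Assumed stand-in loss: squared Euclidean distance between two colours.
distance(c, t) = (c.r - t.r)^2 + (c.g - t.g)^2 + (c.b - t.b)^2
distance_grad(c, t) = (r = 2 * (c.r - t.r), g = 2 * (c.g - t.g), b = 2 * (c.b - t.b))

# Functional update: build a new colour rather than mutating the old one.
update_colour(c, Δ, η) = (r = c.r - η * Δ.r, g = c.g - η * Δ.g, b = c.b - η * Δ.b)

function descend(colour, target; η = 0.1, steps = 51)
    for _ in 1:steps
        colour = update_colour(colour, distance_grad(colour, target), η)
    end
    return colour
end

target = (r = 1.0, g = 0.0, b = 0.0)
initial = (r = 1.0, g = 1.0, b = 1.0)
final = descend(initial, target)
```

As in the original loop, each iteration produces a fresh colour value, so the update is a pure function of the current colour and its gradient.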
