Automatic differentiation roadmap (unfinished)

Swift Automatic Differentiation Roadmap

Richard Wei

January 2019

Progress

December 2017: Design from scratch

Nothing exists other than prior work on source code / IR transformation: Tangent, DLVM, etc.

This is the beginning of the design of a language-integrated automatic differentiation feature, and perhaps the first ever. Problems to figure out span all parts of Swift: syntax, type checking, the standard library, generics, resilience, performance, etc.

March 2018: Primitive differentiation

What worked

  • Forwarding primitive registration: @differentiable(reverse, adjoint: ...).

  • Differential operators: #gradient, #valueAndGradient.

    Primitive derivatives are registered for arithmetic operators in the Swift standard library and common math operators in the TensorFlow library. (A sketch of the registration syntax follows this list.)
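
For illustration, primitive registration might have looked roughly like the sketch below. multiply and adjointMultiply are hypothetical names, and the adjoint's exact parameter convention (original arguments, original result, seed) is an assumption rather than the verbatim historical signature.

// Hypothetical primitive whose derivative is registered up front.
@differentiable(reverse, adjoint: adjointMultiply)
func multiply(_ x: Float, _ y: Float) -> Float {
  return x * y
}

// The adjoint takes the original arguments, the original result, and a seed
// (the incoming cotangent), and returns the gradient w.r.t. each parameter:
// d(x * y)/dx = y and d(x * y)/dy = x, each scaled by the seed.
func adjointMultiply(_ x: Float, _ y: Float, originalValue: Float,
                     seed: Float) -> (Float, Float) {
  return (seed * y, seed * x)
}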

What didn't work

  • Differentiating a nested function call is not supported.

    #gradient({ (x: Float) -> Float in
      return x + x
    })
  • Conditional differentiability cannot be expressed.

    Generic vector types are often compatible with integer scalar types, but integers are not differentiable types. We want Vector to be differentiable only when Scalar conforms to FloatingPoint.

    extension Vector where Scalar : FloatingPoint {
      @differentiable(adjoint: ... where Scalar : FloatingPoint)
      static func * (x: Vector, y: Vector) -> Vector {
        ...
      }
    }

September 2018: The differentiation transform

What worked

  • Reverse-mode differentiation transform.

    • Primal and adjoint functions emitted by the compiler are treated as the canonical representation of a differentiable function. (A sketch of this primal/adjoint split follows this list.)
    • For nested function calls, nested primal value structures are emitted for efficiency.
  • Methods can be differentiated as indirect calls.

    #gradient({ (p: Point, x: Float) in p.foo(x) })
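
The following is a hand-written sketch, not actual compiler output, of the primal/adjoint split for a single function f(x) = x * x. The names are illustrative: the primal returns checkpointed intermediate values alongside the result, and the adjoint consumes them to compute the gradient.

// Structure of checkpointed intermediate values for f(x) = x * x.
struct FooPrimalValues {
  var x: Float
}

// The primal computes the original result and records the values the
// adjoint will need.
func fooPrimal(_ x: Float) -> (checkpoints: FooPrimalValues, result: Float) {
  return (FooPrimalValues(x: x), x * x)
}

// The adjoint consumes the checkpoints and the incoming seed (cotangent):
// d(x * x)/dx = 2 * x, scaled by the seed.
func fooAdjoint(_ checkpoints: FooPrimalValues, _ seed: Float) -> Float {
  return seed * 2 * checkpoints.x
}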

What didn't work

  • Differentiating an opaque closure.

    We can't inspect the body of an opaque closure. The early semantics of automatic differentiation relied on serialization, which has multiple problems, the most prominent of which is the lack of modular compilation and verification.

    func foo(f: (Float) -> Float) {
      gradient(of: f) // error: cannot differentiate opaque closures
    }
  • Differentiating with respect to structure types whose stored properties are not all differentiable.

  • Derived conformances to VectorNumeric for user-defined structure types.
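
For context, the kind of user-defined structure this refers to looks roughly like the sketch below. DenseLayer is hypothetical; the idea is that the compiler would synthesize the VectorNumeric requirements memberwise from the stored properties, which did not work yet at this point.

// Hypothetical model structure. A derived conformance would synthesize
// requirements such as `+`, `-`, and `zero` field-by-field, assuming each
// stored property is itself VectorNumeric.
struct DenseLayer : VectorNumeric {
  var weight: Tensor<Float>
  var bias: Tensor<Float>
}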

December 2018: Generalized types and generalized differentiability

What worked

  • Differentiable types via Differentiable protocol conformance.

    The Differentiable protocol has associated types TangentVector and CotangentVector, which correspond to the type of directional derivatives and the type of gradients, respectively. Allowing derivatives to live in a different space than the original values makes it possible to improve many existing systems, e.g. quantized training, differentiating w.r.t. type-safe orthogonal weight matrices, and differentiating w.r.t. a structure whose stored properties are not all Differentiable types. (A rough sketch of the protocol's shape follows this list.)

  • Differentiable function types (e.g. @autodiff (T) -> U).

    An @autodiff function value bundles the original function with its associated derivative functions, and @autodiff function types participate in function subtyping, so differentiable functions can be stored and passed around as values. This:

    • Enabled differentiating opaque differentiable functions.
    • Supported full opacity across module boundaries.
  • Differentiable protocol requirements (@differentiable in protocols).

    • Enabled machine learning libraries that use gradient-based optimization methods.
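
A rough sketch of the Differentiable protocol's shape as described above; the exact requirements and constraints here are an approximation, not the verbatim standard library definition.

protocol Differentiable {
  /// The type of directional derivatives (consumed by JVPs).
  associatedtype TangentVector : VectorNumeric
  /// The type of gradients (produced by VJPs).
  associatedtype CotangentVector : VectorNumeric
  /// Moves `self` along the given tangent direction.
  func moved(along direction: TangentVector) -> Self
}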

What didn't work

  • Generic functions cannot be differentiated. The limitation has two aspects:

    • Differentiating generic functions is not supported by the differentiation transform. In the following program, although every type along the differentiation path is loadable in SIL, these generic functions carry archetypes in their bodies and require extra care in cloning generic signatures and mapping original bound generic types to the derivative's declaration context.

      func foo<T>(x: Tensor<T>) -> Tensor<T> {
        return x
      }
      gradient(at: 2, in: foo) // error: cannot differentiate generic functions yet
    • Nested calls to a generic function fall through activity analysis and result in zero derivatives. This is not just a matter of what is supported and what is not; it is a correctness issue.

      func foo<T>(x: T) -> T {
        return x
      }
      gradient(at: 2) { x in
        return foo(x + 1)
      } // Zero!

January 2019: JVPs/VJPs, side effects, and generics

What worked

  • Conditional function differentiability under generic constraints.

    This works very well with clear diagnostics.

    @differentiable(where T : Differentiable)
    func square<T>(_ x: Vector<T>) -> Vector<T> {
      return x * x
    }
  • Differentiating generic functions with parameters or results of known size.

    func foo<T>(x: Vector<T>) -> Vector<T> {
      return x + x
    }
    func bar<T>(x: Vector<T>) -> Vector<T> {
      return foo(x) + Vector(1)
    }
    gradient(at: 2, in: bar) // 2
  • Differentiating functions with mutable local variables.

    gradient(at: 2 as Float) { x in
      var a = x
      a = a + x
      return a + x
    } // 3

What didn't work

  • Memory management is buggy.

    Both memory leaks and use-after-frees exist. Memory leaks are blocking large workloads.

  • Functions with an inout parameter are not differentiable.

    This is mainly due to the lack of formal semantics defining the type signature of the JVP/VJP of a function with an inout parameter, as well as the corresponding transformation rules.

  • Semantics of differentiable curried functions are not complete.

    An @autodiff attribute should be considered as applying to each function arrow, signifying that the parameters to the left of that arrow are the differentiation variables; it should not apply to parameters at inner curry levels. This rule should be applied whenever an @autodiff function gets curried. In SIL, curry thunks for @autodiff functions should fall out of this rule.

    func curry<T : Differentiable, U : Differentiable, V : Differentiable>(
      f: @autodiff (T, U) -> V
    ) -> @autodiff (T) -> @autodiff (U) -> V {
      return { x in { y in f(x, y) } }
    }

Next steps

The next steps of Swift automatic differentiation consist of six large components:

  • control flow support,

  • higher-order differentiation,

  • forward-mode differentiation,

  • retroactive derivative registration,

  • side-effecting function derivatives, and

  • user experience enhancements in all areas.

Control flow support

TBD.

Higher-order differentiation

Higher-order differentiation requires the differentiation transform to emit and lower autodiff_function instructions iteratively through a worklist, until every raw autodiff_function instruction has been lowered to one carrying its associated derivative functions.
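
At the user level, higher-order differentiation would enable nesting differential operators, for example (illustrative only, using the gradient(at:) operator shown earlier):

// Second derivative of x * x * x at x = 2.
// The inner gradient yields 3 * x * x; differentiating that gives 6 * x = 12.
let secondDerivative = gradient(at: 2 as Float) { x in
  gradient(at: x) { y in y * y * y }
}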

Forward-mode differentiation

Forward-mode differentiation can be considered a generalization of PrimalGen that calls the JVP instead of the VJP, along with a differential function emitter that simply chains the differentials in forward order.
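
Below is a minimal sketch of what chaining differentials in forward order means, written outside the compiler. JVP here is a hypothetical stand-in for the compiler-emitted JVP functions, and tangent types are identified with the value types for simplicity.

// A JVP returns the value together with a differential that maps an input
// tangent to an output tangent.
typealias JVP<T, U> = (T) -> (value: U, differential: (T) -> U)

// Composing two JVPs: run both primal computations, then chain the
// differentials in forward order (dg after df).
func compose<T, U, V>(_ f: @escaping JVP<T, U>,
                      _ g: @escaping JVP<U, V>) -> JVP<T, V> {
  return { x in
    let (y, df) = f(x)
    let (z, dg) = g(y)
    return (value: z, differential: { dx in dg(df(dx)) })
  }
}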

Retroactive derivative registration

Retroactive derivative registration enables the user to retroactively make any accessible function differentiable by providing a JVP or VJP. This feature is a lot like extensions in Swift, but applies to functions rather than types.

Retroactive registration solves an important problem that initial users of Swift for TensorFlow face today: not being able to differentiate standard library math functions.

Here is an example usage of retroactive derivative registration, in its full cross-module flexibility.

In the Swift standard library:

public func log(_ x: Float) -> Float {
  // Implementation elided; note that no derivative is registered here.
  ...
}

In third-party library module MyLibrary:

@differentiating(Swift.log(_:))
@usableFromInline
func logVJP(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  return (value: log(x), pullback: { v in v / x })
}
}

In user module MyApplication:

derivative { x in log(x) } // Uses `MyLibrary.logVJP`.

Side-effecting function derivatives

User experience enhancements in all areas

Robustness improvements

  • Non-differentiability diagnostics in all cases where differentiation is invoked: @differentiable attributes, protocol conformances, differential operators, etc.