@rxwei
Last active November 26, 2022 10:48
Autograd in Swift

Make Swift AD support all numpy functions via Autograd/JAX

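In the `Python` module (the module that defines `PythonObject`), make `dynamicallyCall(withArguments:)` differentiable and obtain its VJP from autograd:

```swift
// HIPS autograd: `autograd.make_vjp(f)(*args)` returns `(vjp, value)`.
let autograd = Python.import("autograd")

extension PythonObject {
  @differentiable(wrt: args, vjp: _vjpDynamicallyCall)
  @discardableResult
  func dynamicallyCall(
    withArguments args: [PythonObject] = []
  ) -> PythonObject {
    return try! throwing.dynamicallyCall(withArguments: args as [PythonConvertible])
  }

  func _vjpDynamicallyCall(_ args: [PythonObject])
    -> (PythonObject, (PythonObject) -> Array<PythonObject>.CotangentVector) {
    // Ask autograd for a VJP of `self`, which must be a Python callable.
    let vjpMaker = autograd.make_vjp(self)
    let result = vjpMaker.dynamicallyCall(withArguments: args)
    let (pb, y) = result.tuple2
    // make_vjp differentiates only with respect to the first argument
    // (argnum=0), so the cotangent array holds a single element.
    return (y, { v in Array<PythonObject>.DifferentiableView([pb(v)]) })
  }
}
```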
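The pullback returned by `_vjpDynamicallyCall` maps a cotangent `v` of the output back to a cotangent of the input; for `np.cos` that is `v * -sin(x)`. A Python-free sketch of the same `(value, pullback)` shape, written by hand with a hypothetical `vjpCos` helper and Foundation's `cos`/`sin` instead of being obtained from autograd:

```swift
import Foundation

// Hypothetical hand-written VJP for cos, mirroring the (value, pullback)
// pair that `_vjpDynamicallyCall` obtains from autograd.
func vjpCos(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  // d/dx cos(x) = -sin(x), so the pullback scales the cotangent v by -sin(x).
  return (value: cos(x), pullback: { v in -sin(x) * v })
}

let (value, pullback) = vjpCos(1.0)
print(value)          // cos(1)
print(pullback(3.0))  // -3 * sin(1)
```

These two numbers are what the autograd-backed version below should reproduce.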

```swift
extension PythonObject : AdditiveArithmetic {
  public static var zero: PythonObject {
    return 0
  }
}

extension PythonObject : Differentiable {
  public typealias TangentVector = PythonObject
  public typealias CotangentVector = PythonObject
  public typealias AllDifferentiableVariables = PythonObject

  public var allDifferentiableVariables: PythonObject {
    get { return self }
    set { self = newValue }
  }

  public func moved(along direction: PythonObject) -> PythonObject {
    return self + direction
  }

  // PythonObject is its own tangent and cotangent space, so the
  // cotangent converts to a tangent unchanged.
  public func tangentVector(from cotangent: PythonObject) -> PythonObject {
    return cotangent
  }
}
```
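In this conformance `PythonObject` serves as its own tangent and cotangent space, so `zero` is `0`, moving along a direction is addition, and turning a cotangent into a tangent is the identity. A toy, Python-free sketch of the same self-dual pattern, using a hypothetical `Scalar` wrapper over `Double` and the standard library's `AdditiveArithmetic` (here `moved(along:)` and `tangentVector(from:)` are plain methods, not the 2019 `Differentiable` requirements):

```swift
// Hypothetical scalar type whose tangent and cotangent spaces are the
// type itself, mirroring the PythonObject conformance above.
struct Scalar: AdditiveArithmetic, CustomStringConvertible {
  var value: Double

  static var zero: Scalar { return Scalar(value: 0) }
  static func + (lhs: Scalar, rhs: Scalar) -> Scalar {
    return Scalar(value: lhs.value + rhs.value)
  }
  static func - (lhs: Scalar, rhs: Scalar) -> Scalar {
    return Scalar(value: lhs.value - rhs.value)
  }

  var description: String { return "\(value)" }

  // Moving along a tangent direction is plain addition.
  func moved(along direction: Scalar) -> Scalar { return self + direction }

  // Self-dual: a cotangent is returned unchanged as a tangent.
  func tangentVector(from cotangent: Scalar) -> Scalar { return cotangent }
}

let p = Scalar(value: 1.5)
print(p.moved(along: Scalar(value: 0.25)))      // 1.75
print(p.tangentVector(from: Scalar(value: 3)))  // 3.0
```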

It works!

```swift
let np = Python.import("autograd.numpy")
let (value, pullback) = valueWithPullback(at: [1.0]) { x in
  np.cos.dynamicallyCall(withArguments: x)
}
print(value)
let g = pullback(3.0)
print(g.base)
```

Output:

```
0.5403023058681398
[-2.5244129544236893]
```
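These numbers are cos(1) and -3·sin(1), the exact VJP of cos at 1 with cotangent 3.0. A quick sanity check of the pullback value against a central finite difference, using a hypothetical `centralDifference` helper (step size `h` chosen for illustration):

```swift
import Foundation

// Central finite difference: f'(x) is approximated by (f(x + h) - f(x - h)) / (2h).
func centralDifference(_ f: (Double) -> Double, at x: Double, h: Double = 1e-6) -> Double {
  return (f(x + h) - f(x - h)) / (2 * h)
}

let x = 1.0
let seed = 3.0                                      // cotangent passed to `pullback`
let analytic = -sin(x) * seed                       // exact VJP of cos at x
let numeric = centralDifference({ cos($0) }, at: x) * seed
print(cos(x))    // the primal value
print(analytic)  // exact pullback(3.0)
print(numeric)   // finite-difference estimate, close to `analytic`
```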

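Right now we have to call `np.cos.dynamicallyCall(withArguments:)` explicitly because AD does not yet know how to differentiate array literal construction: the sugared call `np.cos(x)` passes its arguments as an array literal, which lowers to a call to `_allocateUninitializedArray`. Once AD is taught to peer through `_allocateUninitializedArray`, we should be able to write anything we want: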

```swift
let np = Python.import("jax.numpy")
let (value, pullback) = valueWithPullback(at: 1.0) { x in
  np.cos(x) + np.cos(x)
}
print(value)
let g = pullback(3.0)
print(g)
```

Or, of course, mix differentiable Swift code with Python code!
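For the `np.cos(x) + np.cos(x)` example above, `x` fans out to two uses, so reverse mode sums the cotangents flowing back from each use: the value should be 2·cos(1) and `pullback(3.0)` should be -6·sin(1). A Python-free sketch of that composition, with hypothetical hand-written `vjpCos` and `vjpSumOfCos` helpers standing in for what AD would generate:

```swift
import Foundation

// Hypothetical hand-written VJP for cos.
func vjpCos(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  return (value: cos(x), pullback: { v in -sin(x) * v })
}

// f(x) = cos(x) + cos(x), differentiated in reverse mode by hand.
func vjpSumOfCos(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  let (a, pullbackA) = vjpCos(x)
  let (b, pullbackB) = vjpCos(x)
  // The pullback of `+` sends v to both operands; since both operands
  // came from the same x, the two input cotangents are summed.
  return (value: a + b, pullback: { v in pullbackA(v) + pullbackB(v) })
}

let (value, pullback) = vjpSumOfCos(1.0)
print(value)          // 2 * cos(1)
print(pullback(3.0))  // -6 * sin(1)
```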

@rxwei (Author) commented on Jun 20, 2019:

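`jax.vjp` is variadic in its primals (`jax.vjp(fun, *primals)`), so we need to call `dynamicallyCall(withArguments:)` manually to pass `self` followed by the Swift argument array: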

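```swift
// Inside `extension PythonObject`; assumes `let jax = Python.import("jax")`.
@differentiating(dynamicallyCall, wrt: args)
func vjp(_ args: [PythonObject])
    -> (value: PythonObject,
        pullback: (PythonObject) -> [PythonObject].TangentVector) {
    // jax.vjp(f, *primals) returns (primals_out, vjp_fn).
    let (y, pb) = jax.vjp.dynamicallyCall(withArguments: [self] + args).tuple2
    return (value: y, pullback: { v in
        // vjp_fn returns a tuple holding one cotangent per primal.
        let dx = pb(v)
        return [PythonObject].TangentVector(Array(dx) ?? [dx])
    })
}
```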
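Because `vjp_fn` yields one cotangent per positional primal, the pullback for a multi-argument call returns an array of cotangents rather than a single value. A Python-free sketch of that shape for f(x, y) = x·y, using a hypothetical `vjpMul` that, like `dynamicallyCall(withArguments:)`, takes its arguments as a Swift array:

```swift
// Hypothetical multi-argument VJP mirroring the shape of jax.vjp:
// the primal output plus a pullback yielding one cotangent per argument.
func vjpMul(_ args: [Double]) -> (value: Double, pullback: (Double) -> [Double]) {
  precondition(args.count == 2, "expects exactly two arguments")
  let (x, y) = (args[0], args[1])
  // d(xy)/dx = y and d(xy)/dy = x.
  return (value: x * y, pullback: { v in [v * y, v * x] })
}

let (value, pullback) = vjpMul([2.0, 5.0])
print(value)          // 10.0
print(pullback(1.0))  // [5.0, 2.0]
```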
