Using Swift's `Python` interop module:
// Import the top-level `jax` module. ("jax.grad" names a function, not an
// importable module, so `Python.import("jax.grad")` would fail at runtime.)
// With this binding, `autograd.vjp(...)` below resolves to `jax.vjp`.
let autograd = Python.import("jax")
/// Bridges Python callables into Swift's automatic differentiation:
/// a dynamically-called Python function becomes differentiable w.r.t.
/// its argument array, with the pullback supplied by the Python AD library.
extension PythonObject {
/// Calls the wrapped Python callable with `args`.
///
/// Registered as differentiable with respect to `args`; the custom VJP is
/// `_vjpDynamicallyCall` below.
@differentiable(wrt: args, vjp: _vjpDynamicallyCall)
@discardableResult
func dynamicallyCall(
withArguments args: [PythonObject] = []
) -> PythonObject {
// `try!`: a Python exception here becomes a Swift runtime trap.
return try! throwing.dynamicallyCall(withArguments: args as [PythonConvertible])
}
/// VJP for `dynamicallyCall(withArguments:)`: obtains a value/pullback pair
/// from the Python AD library and adapts the pullback to Swift's array
/// cotangent representation (`Array<PythonObject>.DifferentiableView`).
func _vjpDynamicallyCall(_ args: [PythonObject])
-> (PythonObject, (PythonObject) -> Array<PythonObject>.CotangentVector) {
// NOTE(review): the destructuring below assumes the Python helper returns
// a (pullback, value) pair in that order — autograd's `make_vjp` convention.
// `jax.vjp` returns (value, pullback) instead; confirm against whichever
// module `autograd` is actually bound to.
let vjp = autograd.vjp(self)
let result = vjp.dynamicallyCall(withArguments: args)
let (pb, y) = result.tuple2
// Wrap the single cotangent produced by the Python pullback into the
// one-element array cotangent Swift expects for `[PythonObject]`.
return (y, { v in Array<PythonObject>.DifferentiableView([pb(v)]) })
}
}
// MARK: - AdditiveArithmetic

/// Additive-identity support so `PythonObject` values can be summed by AD.
extension PythonObject : AdditiveArithmetic {
/// The additive identity, expressed as the Python integer `0`
/// (via `ExpressibleByIntegerLiteral`).
public static var zero: PythonObject {
let additiveIdentity: PythonObject = 0
return additiveIdentity
}
}
// MARK: - Differentiable

/// Makes `PythonObject` its own tangent and cotangent space, so Python
/// values can flow directly through Swift's differentiation machinery.
extension PythonObject : Differentiable {
public typealias TangentVector = PythonObject
public typealias CotangentVector = PythonObject
public typealias AllDifferentiableVariables = PythonObject
/// Every `PythonObject` is, in its entirety, differentiable state.
public var allDifferentiableVariables: PythonObject {
get { return self }
set { self = newValue }
}
/// Moves `self` along `direction` using Python's `+` operator.
public func moved(along direction: PythonObject) -> PythonObject {
let displaced = self + direction
return displaced
}
/// Identity mapping between cotangent and tangent representations.
public func tangentVector(from cotangent: PythonObject) -> PythonObject {
return self
}
}
It works!
// Differentiate `cos` (from autograd's numpy wrapper) at x = 1.0.
let np = Python.import("autograd.numpy")
let (value, pullback) = valueWithPullback(at: [1.0]) { xs in
return np.cos.dynamicallyCall(withArguments: xs)
}
print(value)
// Pull back a seed of 3.0: 3 * d/dx cos(x) at 1 = -3 * sin(1) = -2.524...
let cotangent = pullback(3.0)
print(cotangent.base)
0.5403023058681398
[-2.5244129544236893]
The reason we need to call `np.cos.dynamicallyCall` right now is that AD does not know
how to differentiate array literal construction yet. We need to teach AD to peer through
`_allocateUninitializedArray`. After this, we should be able to do anything we want:
// With jax.numpy, direct calls work: differentiate cos(x) + cos(x) at x = 1.0.
let np = Python.import("jax.numpy")
let (value, pullback) = valueWithPullback(at: 1.0) { x in
return np.cos(x) + np.cos(x)
}
print(value)
// A seed of 3.0 flows back through both cos terms.
let gradient = pullback(3.0)
print(gradient)
Or, of course, mix differentiable Swift code with Python code!
`jax.vjp` is variadic, so we need to manually call `dynamicallyCall(withArguments:)`.