In the Python module:
let autograd = Python.import("jax.grad")
extension PythonObject {
@differentiable(wrt: args, vjp: _vjpDynamicallyCall)
@discardableResult
➜ retro git:(c-api) ✗ make retro-c | |
[ 70%] Built target zip | |
[ 70%] Built target pce-submodule | |
[ 70%] Built target pce | |
[ 70%] Built target gba-submodule | |
[ 70%] Built target gba | |
[ 70%] Built target nes-submodule | |
[ 75%] Built target nes | |
[ 75%] Built target gb-submodule | |
[ 75%] Generating retro/cores/gambatte_libretro.dylib |
sil hidden @AD__$s11leak_simple6apply2_2to10TensorFlow0E0VySfGx_AGtAD5LayerRzAG5InputRtzAG6OutputRtzlF__adjoint_src_0_wrt_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Layer, τ_0_0.Input == Tensor<Float>, τ_0_0.Output == Tensor<Float>> (@guaranteed Tensor<Float>, @guaranteed _AD__$s11leak_simple6apply2_2to10TensorFlow0E0VySfGx_AGtAD5LayerRzAG5InputRtzAG6OutputRtzlF__Type__src_0_wrt_1<τ_0_0>) -> @owned Tensor<Float> { | |
// %0 // users: %13, %46, %2 | |
// %1 // user: %23 | |
bb0(%0 : $Tensor<Float>, %1 : $_AD__$s11leak_simple6apply2_2to10TensorFlow0E0VySfGx_AGtAD5LayerRzAG5InputRtzAG6OutputRtzlF__Type__src_0_wrt_1<τ_0_0>): | |
retain_value %0 : $Tensor<Float> // id: %2 | |
%3 = alloc_stack $Tensor<Float> // users: %25, %13, %47, %45, %4 | |
%4 = begin_access [init] [static] [no_nested_conflict] %3 : $*Tensor<Float> // users: %8, %10 | |
%5 = metatype $@thin Tensor<Float>.Type // user: %7 | |
// function_ref |
Hi all, @dan-zheng and I wrote a proposal to introduce static callables to Swift. This proposal is also available as a gist here. We'd love to hear your feedback.
// Function-as-a-differentiable-type rule: | |
// Tangent space: ((T...) -> U...)' = AnyDerivative | |
// Cotangent space: ((T...) -> U...)'* = AnyDerivative | |
// Why? Because when a function value is varying, what's varying is its context. | |
// In general cases, we need this to be a constrained existential with an | |
// `AdditiveArithmetic` conformance for its `.zero` and `+`, and `Differentiable` | |
// for being able to transpose between differential and a pullback. | |
// New associated function type calculation rules: | |
// original: (T...) -> (U...) |
// Function-as-a-differentiable-type rule: | |
// Tangent space: ((T...) -> U...)' = Any | |
// Cotangent space: ((T...) -> U...)'* = Any | |
// Why? Because when a function value is varying, what's varying is its context. | |
// In general cases, we need this to be a constrained existential with an | |
// `AdditiveArithmetic` conformance for its `.zero` and `+`, and `Differentiable` | |
// for being able to transpose between differential and a pullback. | |
// New associated function type calculation rules: | |
// original: (T...) -> (U...) |
// Type-erased box. | |
fileprivate class AnyDerivativeBox : Differentiable & AdditiveArithmetic { | |
public typealias TangentVector = AnyDerivativeBox | |
public typealias CotangentVector = AnyDerivativeBox | |
public typealias AllDifferentiableVariables = AnyDerivativeBox | |
public static func == (lhs: AnyDerivativeBox, rhs: AnyDerivativeBox) -> Bool { | |
fatalError("Must override") | |
} | |
public static var zero: Self { |
// AD__$s4test15MNISTClassifierV7applied2to10TensorFlow0E0VySfGAI_tF__adjoint_src_0_wrt_0_1 | |
sil hidden @AD__$s4test15MNISTClassifierV7applied2to10TensorFlow0E0VySfGAI_tF__adjoint_src_0_wrt_0_1 : $@convention(method) (@guaranteed Tensor<Float>, @guaranteed _AD__$s4test15MNISTClassifierV7applied2to10TensorFlow0E0VySfGAI_tF__Type__src_0_wrt_0_1) -> (@owned MNISTClassifier.AllDifferentiableVariables, @owned Tensor<Float>) { | |
// %0 // users: %41, %4, %2 | |
// %1 // users: %29, %25, %21, %17, %13, %11, %7, %3 | |
bb0(%0 : $Tensor<Float>, %1 : $_AD__$s4test15MNISTClassifierV7applied2to10TensorFlow0E0VySfGAI_tF__Type__src_0_wrt_0_1): | |
retain_value %0 : $Tensor<Float> // id: %2 | |
%3 = struct_extract %1 : $_AD__$s4test15MNISTClassifierV7applied2to10TensorFlow0E0VySfGAI_tF__Type__src_0_wrt_0_1, #_AD__$s4test15MNISTClassifierV7applied2to10TensorFlow0E0VySfGAI_tF__Type__src_0_wrt_0_1.pullback_7 // user: %4 | |
%4 = apply %3(%0) : |