Skip to content

Instantly share code, notes, and snippets.

@marcrasi
Created November 15, 2018 01:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcrasi/c0832db2cf71b5f5d8709aab89c96be9 to your computer and use it in GitHub Desktop.
Save marcrasi/c0832db2cf71b5f5d8709aab89c96be9 to your computer and use it in GitHub Desktop.
// NOTE(review): the attribute syntax here comes from an early (2018) Swift for
// TensorFlow toolchain; confirm it against the compiler actually in use.
/// A type exposing a single `Float -> Float` method `f`.
///
/// The requirement is marked `@differentiable(reverse)`; conformers in this
/// file (e.g. `MultiplyConstant`) mark their `f` differentiable as well.
protocol Proto {
/// Maps `x` to a `Float`; each conformer decides the function.
@differentiable(reverse)
func f(_ x: Float) -> Float
}
/// Calls the `Proto` requirement `f` on `instance` with `input`.
///
/// Exposing `f` as a free function of `(instance, input)` lets `gradFWrtX`
/// hand it to `#gradient` and differentiate with respect to the second
/// parameter (`input`).
func callF<T>(_ instance: T, _ input: Float) -> Float where T: Proto {
  return instance.f(input)
}
/// Returns the derivative of `callF(t, x)` with respect to `x`, evaluated at
/// the given `t` and `x`.
///
/// `wrt: .1` selects `callF`'s second parameter (index 1, the `Float` input),
/// and the result is coerced to `(T, Float) -> Float`, so the produced function
/// returns only that single partial derivative.
func gradFWrtX<T: Proto>(_ t: T, at x: Float) -> Float {
// The `as` coercion gives `#gradient` a concrete function type, presumably
// needed to pin the generic `callF` to this particular `T`.
return (#gradient(callF, wrt: .1) as (T, Float) -> Float)(t, x)
}
/// A `Proto` conformer whose `f` scales its input by a fixed `constant`.
struct MultiplyConstant : Proto {
/// The factor that `f` multiplies its input by.
let constant: Float
/// Returns `constant * x`.
///
/// The `adjoint:` argument registers `fAdj` as the hand-written reverse-mode
/// derivative, presumably used in place of one synthesized from this body.
@differentiable(reverse, adjoint: fAdj)
func f(_ x: Float) -> Float {
return constant * x
}
// NOTE(review): the `(x, origResult, seed)` parameter order presumably mirrors
// the adjoint signature the prototype toolchain expects; verify before changing it.
/// Reverse-mode derivative (adjoint) of `f`.
///
/// Since d(constant * x)/dx = constant, the incoming `seed` is scaled by
/// `constant`.
///
/// - Parameters:
///   - x: The original input to `f`; unused, as the derivative does not depend on it.
///   - origResult: The value `f` returned; unused here.
///   - seed: The incoming derivative to propagate back through `f`.
/// - Returns: `constant * seed`, the derivative with respect to `x`.
func fAdj(_ x: Float, _ origResult: Float, _ seed: Float) -> Float {
return constant * seed
}
}
// Demo: differentiate f(x) = 10 * x with respect to x at x = 0.
// The mathematically expected derivative is the constant, 10.
let tenTimes = MultiplyConstant(constant: 10)
let slopeAtZero = gradFWrtX(tenTimes, at: 0)
print(slopeAtZero)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment