// Gist by @marcrasi, created June 15, 2020 22:39.
// https://gist.github.com/marcrasi/51844c78bbac58d4afc23651fc33d512
import TensorFlow
// MARK: Example function and its transformed version.
/// Cubes `x` by multiplying it with itself twice through `WrapsFloat.product`.
/// Written with `product` (rather than `*`) so it mirrors `cubeT` line for line.
func cube(_ x: Float) -> Float {
  let square = x.product(x)
  return square.product(x)
}
/// Transformed `cube`: the same computation, generic over any `WrapsFloat`
/// (e.g. `Float` itself, or a `ValueWithTangent` as used by `example1`).
func cubeT<A: WrapsFloat>(_ x: A) -> A {
  let square = x.product(x)
  return square.product(x)
}
/// Returns a `Vector2` whose two components are both `x`.
func dup(_ x: Float) -> Vector2 {
  return .init(x, x)
}
/// Transformed `dup`: builds any `WrapsVector2` whose x and y components are both `x`.
/// The `where` clause is what lets one wrapped value fill both component slots.
func dupT<B: WrapsVector2>(_ x: B.WrappedX) -> B where B.WrappedX == B.WrappedY {
  let duplicated: B = .init(wrappedVector2: x, x)
  return duplicated
}
// MARK: - Immediate forward derivative.
/// A value paired with a tangent, used for immediate forward-mode differentiation.
struct ValueWithTangent<Value, Tangent> where Tangent: Vector {
  /// The primal value.
  var value: Value
  /// The tangent carried alongside `value`.
  var tangent: Tangent
}
/// Runs the transformed function `f` on `x`, seeded with the input tangent `t`,
/// and returns only the resulting output tangent.
///
/// - Parameters:
///   - x: The point at which to evaluate.
///   - t: The input tangent (direction).
///   - f: A transformed function operating on `ValueWithTangent`s.
/// - Returns: The tangent component of `f`'s output.
func differential<A, AT, B, BT>(
  at x: A,
  direction t: AT,
  of f: (ValueWithTangent<A, AT>) -> ValueWithTangent<B, BT>
) -> BT {
  let seeded = ValueWithTangent(value: x, tangent: t)
  let output = f(seeded)
  return output.tangent
}
/// Prints the forward-mode derivative of `cubeT` at 1 in direction 1.
public func example1() {
  // d/dx x^3 = 3 * x^2, so at x = 1 the derivative is 3 (printed as `3.0`).
  let derivative = differential(at: Float(1), direction: Float(1), of: cubeT)
  print(derivative)
  // Not working yet for `dupT`:
  //print(differential(at: Float(1), direction: Float(1), of: dupT) as ValueWithTangent<Vector2, Vector2>) // => Vector2(x: 1, y: 1)
}
// MARK: - Immediate higher order forward derivative.
/// Pushes a nested `ValueWithTangent` (a value-with-tangent whose value is itself a
/// value-with-tangent) through `cubeT` and prints the result.
///
/// NOTE(review): the outer tangent is a plain `Float` while the value is a
/// `ValueWithTangent<Float, Float>`, so `product` ends up scaling a `Float` tangent
/// by a non-`Float` scalar via `Float.scaleT(by:)`; confirm that path is supported.
public func example2() {
  let inner = ValueWithTangent(value: Float(1), tangent: Float(1))
  let direction = ValueWithTangent(value: inner, tangent: Float(1))
  let result = cubeT(direction)
  print(result)
}
//struct HigherOrderTangent<Value: ValueWithTangent>: ValueWithTangent {
// var value: Value
//
// typealias Tangent = Value.Tangent
// var tangent: Value.Tangent
//
// static func zeroTangent(_ value: Value) -> Tangent {
// return Value.zeroTangent(value.value)
// }
// static func sum(_ t1: Tangent, _ t2: Tangent) -> Tangent {
// return Value.sum(t1, t2)
// }
// static func scale(_ scalar: Float, _ t: Tangent) -> Tangent {
// return Value.scale(scalar, t)
// }
//}
//
//func immediateTangentWeirdInitializer(
// _ x: ImmediateTangent<Float>,
// _ tangent: Float
//) -> HigherOrderTangent<ImmediateTangent<Float>> {
// HigherOrderTangent(value: x, tangent: tangent)
//}
//
//extension HigherOrderTangent where Value == ImmediateTangent<Float> {
// var tangentT: ImmediateTangent<Float> {
// return value
// }
//}
//
//public func example2() {
// func dCube(_ x: Float, _ direction: Float) -> Float {
// return cubeT(ImmediateTangent(value: x, tangent: direction)).tangent
// }
//
// func dCubeT(_ x: ImmediateTangent<Float, Float>, _ direction: Float) -> ImmediateTangent2<Float, Float>.Tangent {
// return cubeT(ImmediateTangent2(value: x.value, tangent: ImmediateTangent2.Tangent(t1: x.tangent, t2: direction))).tangent
// }
//
// // when doing lifting, we need to define "WrapsImmediateTangent" and define the initializer and `.tangent` methods for the initializer!!!
//
//// func dCubeT(_ x: ImmediateTangent<Float>, _ direction: Float) -> ImmediateTangent<Float> {
//// return cubeT(immediateTangentWeirdInitializer(x, direction)).tangentT
//// }
//
//// let d2 = differential(at: Float(1), direction: Float(1)) {
//// return dCubeT($0, Float(1))
//// }
//// print(d2)
//
// print(dCubeT(ImmediateTangent(value: Float(1), tangent: Float(1)), Float(1)))
//}
//
//// MARK: - Immediate forward mode higher order differentiation.
//
//struct ImmediateTangentT<ValueT: ValueWithTangent>: ValueWithTangent {
// var valueT: ValueT
//
// typealias Value = ValueT.Value
// var value: Value { return valueT.value }
//
// typealias Tangent = ValueT.Tangent
// var tangent: ValueT.Tangent
//
// init(value: Value, tangent: Tangent) {
// self.valueT = ValueT(value: value, tangent: ValueT.zeroTangent(value))
// self.tangent = tangent
// }
//
// init(valueT: ValueT, tangent: Tangent) {
// self.valueT = valueT
// self.tangent = tangent
// }
//
// func tangentT() -> ImmediateTangentT<Tangent> where ValueT.Tangent: ValueWithTangent {
// fatalError()
// }
//
// static func zeroTangent(_ value: Value) -> Tangent {
// return ValueT.zeroTangent(value)
// }
// static func sum(_ t1: Tangent, _ t2: Tangent) -> Tangent {
// return ValueT.sum(t1, t2)
// }
// static func scale(_ scalar: Float, _ t: Tangent) -> Tangent {
// return ValueT.scale(scalar, t)
// }
//}
//
//func differentialT<A, B, ATV: ValueWithTangent>(
// at x: ATV,
// direction t: ImmediateTangent<A>.Tangent,
// of f: (ImmediateTangentT<ATV>) -> ImmediateTangentT<ImmediateTangent<B>>
//) -> ImmediateTangentT<ImmediateTangent<B>.Tangent>
//where ATV.Value == ImmediateTangent<A>.Value, ATV.Tangent == ImmediateTangent<A>.Tangent {
// return f(ImmediateTangentT(valueT: x, tangent: t)).tangentT()
//}
//
//
////func differentialT<AT: ValueWithTangent, BT: ValueWithTangent>(
//// at x: AT,
//// direction t: AT.Tangent,
//// of f: (ImmediateTangentT<AT>) -> ImmediateTangentT<BT>
////) -> ImmediateTangentT<BT.Tangent> {
//// return f(ImmediateTangentT(valueT: x, tangent: t)).tangentT()
////}
//
//func example2() {
// let t = differentialT(at: Float(1), direction: Float(1)) {
// differentialT(at: $0, direction: Float(1), of: cubeT)
// }
// print(t)
//}
// MARK: - Primitive functions and their "derivatives".
/// A type that supports the `Float`-like primitives used by transformed functions.
protocol WrapsFloat {
  /// Returns the product of `self` and `other`.
  func product(_ other: Self) -> Self
  /// Returns `self` scaled by the plain float `byFloat`.
  func scaled(byFloat: Float) -> Self
}
/// A plain `Float` trivially wraps itself: both primitives are ordinary multiplication.
extension Float: WrapsFloat {
  func product(_ other: Float) -> Float {
    self * other
  }

  func scaled(byFloat factor: Float) -> Float {
    self * factor
  }
}
/// Forward-mode "derivatives" of the `WrapsFloat` primitives.
extension ValueWithTangent: WrapsFloat where Value: WrapsFloat {
  /// Scales the value and its tangent by the same plain float.
  func scaled(byFloat: Float) -> Self {
    let scaledValue = value.scaled(byFloat: byFloat)
    let scaledTangent = tangent.scaleT(by: byFloat)
    return Self(value: scaledValue, tangent: scaledTangent)
  }

  /// Product rule: d(u * v) = u * dv + v * du.
  func product(_ other: Self) -> Self {
    let productValue = value.product(other.value)
    let productTangent = other.tangent.scaleT(by: value) + tangent.scaleT(by: other.value)
    // Prints a trace of every primitive product evaluated.
    print("Product of \(self) * \(other), tangent is \(productTangent)")
    return Self(value: productValue, tangent: productTangent)
  }
}
/// A two-component float vector, used as a value and also as its own tangent
/// (see `ValueWithTangent<Vector2, Vector2>` in `evaluateForwardMode`).
struct Vector2: Vector {
  var x, y: Float

  init(_ x: Float, _ y: Float) {
    self.x = x
    self.y = y
  }

  static var zero: Self { Self(0, 0) }

  /// Componentwise sum.
  static func + (_ lhs: Self, _ rhs: Self) -> Self {
    Self(lhs.x + rhs.x, lhs.y + rhs.y)
  }

  /// Scales each component via `Float.scaleT(by:)`.
  func scaleT<Scalar: WrapsFloat>(by scalar: Scalar) -> Self {
    let scaledX = x.scaleT(by: scalar)
    let scaledY = y.scaleT(by: scalar)
    return Self(scaledX, scaledY)
  }
}
/// A type that behaves like a `Vector2` whose components are `WrapsFloat`s.
protocol WrapsVector2 {
  /// Wrapped type of the x component.
  associatedtype WrappedX: WrapsFloat
  /// Wrapped type of the y component.
  associatedtype WrappedY: WrapsFloat

  /// Builds a value from its wrapped x and y components.
  init(wrappedVector2 x: WrappedX, _ y: WrappedY)

  /// The underlying plain `Vector2`.
  var wrappedVector2: Vector2 { get }
  /// The wrapped x component.
  var wrappedVector2x: WrappedX { get }
  /// The wrapped y component.
  var wrappedVector2y: WrappedY { get }
}
/// A plain `Vector2` wraps itself, with `Float` components.
extension Vector2: WrapsVector2 {
  init(wrappedVector2 x: Float, _ y: Float) {
    self.init(x, y)
  }

  var wrappedVector2: Vector2 { self }
  var wrappedVector2x: Float { x }
  var wrappedVector2y: Float { y }
}
/// Forward-mode lifting of `WrapsVector2`: each component is a value-with-tangent
/// pairing the corresponding value component with the corresponding tangent component.
extension ValueWithTangent: WrapsVector2
where Value: WrapsVector2,
      Tangent: WrapsVector2, Tangent.WrappedX: Vector, Tangent.WrappedY: Vector {
  typealias WrappedX = ValueWithTangent<Value.WrappedX, Tangent.WrappedX>
  typealias WrappedY = ValueWithTangent<Value.WrappedY, Tangent.WrappedY>

  /// Rebuilds the value from the components' values and the tangent from the components' tangents.
  init(wrappedVector2 x: WrappedX, _ y: WrappedY) {
    value = Value(wrappedVector2: x.value, y.value)
    tangent = Tangent(wrappedVector2: x.tangent, y.tangent)
  }

  var wrappedVector2: Vector2 { value.wrappedVector2 }

  var wrappedVector2x: WrappedX {
    let componentValue = value.wrappedVector2x
    let componentTangent = tangent.wrappedVector2x
    return WrappedX(value: componentValue, tangent: componentTangent)
  }

  var wrappedVector2y: WrappedY {
    let componentValue = value.wrappedVector2y
    let componentTangent = tangent.wrappedVector2y
    return WrappedY(value: componentValue, tangent: componentTangent)
  }
}
// MARK: - Other helpers.
/// A tangent vector space: addable, scalable by any `WrapsFloat`, with a zero element.
protocol Vector {
  /// The additive identity.
  static var zero: Self { get }
  /// Vector addition.
  static func + (_ lhs: Self, _ rhs: Self) -> Self
  /// Returns `self` scaled by `scalar`.
  func scaleT<Scalar: WrapsFloat>(by scalar: Scalar) -> Self
}
/// `Float` is its own tangent vector space.
extension Float: Vector {
  /// Returns `self` scaled by `scalar`.
  ///
  /// Only a `Float` scalar is supported. `scalar.scaled(byFloat:)` has type
  /// `Scalar`, not `Float`, so for any other `WrapsFloat` (e.g. a nested
  /// `ValueWithTangent`) the result cannot be expressed as a plain `Float`
  /// tangent. Those cases trap with a descriptive message instead.
  func scaleT<Scalar: WrapsFloat>(by scalar: Scalar) -> Self {
    guard let floatScalar = scalar as? Float else {
      preconditionFailure(
        "Float.scaleT(by:) can only scale by a Float scalar, got \(Scalar.self)")
    }
    return self * floatScalar
  }

  /// The additive identity.
  static var zero: Self { return 0 }
}
/// Componentwise vector-space structure on (value, tangent) pairs.
extension ValueWithTangent: Vector where Value: Vector {
  static var zero: Self {
    Self(value: Value.zero, tangent: Tangent.zero)
  }

  static func + (_ lhs: Self, _ rhs: Self) -> Self {
    let summedValue = lhs.value + rhs.value
    let summedTangent = lhs.tangent + rhs.tangent
    return Self(value: summedValue, tangent: summedTangent)
  }

  func scaleT<Scalar: WrapsFloat>(by scalar: Scalar) -> Self {
    let scaledValue = value.scaleT(by: scalar)
    let scaledTangent = tangent.scaleT(by: scalar)
    return Self(value: scaledValue, tangent: scaledTangent)
  }
}
import Foundation
// Original function:
/// Original (untransformed) function: `Vector2(multiplied(x, x), multiplied(x, y))`.
/// NOTE(review): `multiplied` only appears commented out in this file; confirm it is defined elsewhere.
func f(_ x: Float, _ y: Float) -> Vector2 {
  return Vector2(multiplied(x, x), multiplied(x, y))
}
/// Original (untransformed) function: `multiplied(x, x)`.
func g(_ x: Float) -> Float {
  let square = multiplied(x, x)
  return square
}
// Transformed function:
/// Transformed `f`: each primitive is replaced by its tangent-carrying counterpart
/// (`multipliedT`, `Vector2.initT`), taking `InputT` tangents to an `OutputT` tangent.
/// NOTE(review): `multipliedT` and `Vector2.initT` only appear commented out in this file.
func fT<InputT: TangentProtocol, OutputT: TangentProtocol>(
  _ x: ValueWithTangent<Float, InputT>,
  _ y: ValueWithTangent<Float, InputT>
) -> ValueWithTangent<Vector2, OutputT> {
  let xSquared: ValueWithTangent<Float, InputT> = multipliedT(x, x)
  let xTimesY: ValueWithTangent<Float, InputT> = multipliedT(x, y)
  return Vector2.initT(xSquared, xTimesY)
}
/// Transformed `g`: `multipliedT(x, x)`, producing an `OutputT` tangent.
func gT<InputT: TangentProtocol, OutputT: TangentProtocol>(
  _ x: ValueWithTangent<Float, InputT>
) -> ValueWithTangent<Float, OutputT> {
  let square: ValueWithTangent<Float, OutputT> = multipliedT(x, x)
  return square
}
// "Immediate" forward mode differentiation
/// "Immediate" forward mode: computes both partial derivatives of `fT` at (5, 10).
/// Each pass seeds the input tangents with one basis direction, and the resulting
/// output tangents are printed.
/// NOTE(review): `valueWithTangent(of:at:_:tangent:_:)` only appears commented out in this file.
public func evaluateForwardMode() {
  typealias Output = ValueWithTangent<Vector2, Vector2>
  let partialX: Output = valueWithTangent(of: fT, at: 5, 10, tangent: Float(1), Float(0))
  let partialY: Output = valueWithTangent(of: fT, at: 5, 10, tangent: Float(0), Float(1))
  print(partialX.tangent)
  print(partialY.tangent)
}
// Reverse mode differentiation
/// Reverse mode: obtains the pullback of `gT` at 2, applies it to a seed of 1, and prints the result.
public func evaluateReverseMode() {
  let pullback = valueWithPullback(of: gT, at: 2).pullback
  let gradient = pullback(1)
  print(gradient)
}
/// Reverse mode built on forward mode.
///
/// Runs `f` with tangents that are themselves `Pullback`s, seeding the input with
/// the identity pullback. The output tangent is then the pullback of `f` at `x`.
/// NOTE(review): `valueWithTangent(of:at:tangent:)` only appears commented out in this file.
func valueWithPullback<A, R>(
  of f: (ValueWithTangent<A, Pullback<A, A>>) -> ValueWithTangent<R, Pullback<A, R>>,
  at x: A
) -> (value: R, pullback: (R) -> A) {
  let identity = Pullback<A, A>({ (input: A) in input })
  let result = valueWithTangent(of: f, at: x, tangent: identity)
  return (value: result.value, pullback: result.tangent.pullback)
}
//
//// Second derivative!!
//
//public func evaluateSecondDerivative() {
// let g2: (ValueWithTangent<Float, Float>) -> ValueWithTangent<Float, ValueWithTangent<Float, SelfTangent<Float>>> = gT
// let d2 =
// secondDerivative(of: g2, at: Float(5), tangent1: Float(1), tangent2: Float(1))
// print(d2)
//}
//
//
/// The pullback of a function `(A) -> B`, stored as a closure `(B) -> A`.
/// Only supports types that are tangents of themselves (`A.TangentOf == A`), but I think this
/// restriction could be lifted with more work.
struct Pullback<A: TangentProtocol, B>: TangentProtocol where A.TangentOf == A {
  let pullback: (B) -> A

  init(_ pullback: @escaping (B) -> A) {
    self.pullback = pullback
  }

  /// Pointwise sum: the result maps `b` to `self.pullback(b).adding(other.pullback(b))`.
  func adding(_ other: Self) -> Self {
    return Pullback({ self.pullback($0).adding(other.pullback($0)) })
  }

  /// Pointwise scaling of the pulled-back value.
  func scaling(by scalar: Float) -> Self {
    return Pullback({ self.pullback($0).scaling(by: scalar) })
  }

  typealias TangentOf = B

  /// Lifts a component pullback to a pullback of the whole `B`,
  /// e.g. unprojecting `(Float) -> A` to `(Vector2) -> A`.
  ///
  /// `t` must be a `Pullback<A, K>`. The dynamic cast is done once, up front, rather
  /// than on every invocation. A type mismatch traps only when the returned pullback is
  /// actually called, and the trap carries a descriptive message.
  static func unprojected<K, T: TangentProtocol>(_ keyPath: WritableKeyPath<TangentOf, K>, _ t: T) -> Self {
    let castPullback = t as? Pullback<A, K>
    return Pullback({ (bt: B) -> A in
      guard let componentPullback = castPullback else {
        preconditionFailure(
          "Pullback.unprojected expected a Pullback<\(A.self), \(K.self)>, got \(T.self)")
      }
      return componentPullback.pullback(bt[keyPath: keyPath])
    })
  }

  /// The pullback that maps every input to `A.zero`.
  static var zero: Self {
    return Pullback({ _ in A.zero })
  }
}
//
//// Support code
//
//struct Vector2 {
// var x, y: Float
// init(_ x: Float, _ y: Float) {
// self.x = x
// self.y = y
// }
//}
//
//func multiplied(_ x: Float, _ y: Float) -> Float {
// return x * y
//}
//
/// A tangent space: supports addition, scaling by a `Float`, a zero element, and
/// "unprojecting" a component tangent into a tangent of a `TangentOf` value.
protocol TangentProtocol {
  /// The type this is a tangent of.
  associatedtype TangentOf

  /// The additive identity.
  static var zero: Self { get }

  func adding(_ other: Self) -> Self
  func scaling(by scalar: Float) -> Self

  /// Builds a tangent of `TangentOf` from the tangent `t` of the component at `keyPath`.
  static func unprojected<K, T: TangentProtocol>(_ keyPath: WritableKeyPath<TangentOf, K>, _ t: T) -> Self

  // Not currently required:
  //func projected<T>(_ keyPath: KeyPath<TangentOf, T>) -> T
}
//
//struct ValueWithTangent<Value, Tangent: TangentProtocol> where Tangent.TangentOf == Value {
// var value: Value
// var tangent: Tangent
// init(_ value: Value, _ tangent: Tangent) {
// self.value = value
// self.tangent = tangent
// }
//}
//
//extension ValueWithTangent: TangentProtocol where Value: TangentProtocol {
// func adding(_ other: Self) -> Self {
// return Self(value, tangent.adding(other.tangent))
// }
// func scaling(by scalar: Float) -> Self {
// return Self(value, tangent.scaling(by: scalar))
// }
//
// typealias TangentOf = Value.TangentOf
// //func projected<T>(_ keyPath: KeyPath<TangentOf, T>) -> T
// static func unprojected<K, T: TangentProtocol>(_ keyPath: WritableKeyPath<TangentOf, K>, _ t: T) -> Self {
//// print(keyPath)
//// print(t)
//// print(self)
//// fatalError("not implemented")
////
// // TODO: ????
// return ValueWithTangent(
// //Value.unprojected(keyPath, t),
// Value.zero,
// Tangent.unprojected(keyPath as! WritableKeyPath<Value, T>, t))
// }
//
// static var zero: Self {
// return ValueWithTangent(Value.zero, Tangent.zero)
// }
//}
//
//func multipliedT<T: TangentProtocol, OutputT: TangentProtocol>(
// _ x: ValueWithTangent<Float, T>,
// _ y: ValueWithTangent<Float, T>
//) -> ValueWithTangent<Float, OutputT> {
// let xy = x.value * y.value
// let xy_t = x.tangent.scaling(by: y.value).adding(y.tangent.scaling(by: x.value))
// return ValueWithTangent(xy, OutputT.unprojected(\Float.self, xy_t))
//}
//
//extension Vector2 {
// static func initT<InputT: TangentProtocol, OutputT: TangentProtocol>(
// _ x: ValueWithTangent<Float, InputT>,
// _ y: ValueWithTangent<Float, InputT>
// ) -> ValueWithTangent<Vector2, OutputT> {
// let vector2 = Vector2(x.value, y.value)
// let x_t_unprojected = OutputT.unprojected(\Vector2.x, x.tangent)
// let y_t_unprojected = OutputT.unprojected(\Vector2.y, y.tangent)
// let vector2_t = x_t_unprojected.adding(y_t_unprojected)
// return ValueWithTangent(vector2, vector2_t)
// }
//}
//
//extension Float: TangentProtocol {
// func adding(_ other: Self) -> Self {
// return self + other
// }
// func scaling(by scalar: Float) -> Self {
// return self * scalar
// }
//
// typealias TangentOf = Self
// static func unprojected<K, T: TangentProtocol>(
// _ keyPath: WritableKeyPath<TangentOf, K>,
// _ t: T
// ) -> Self {
// var result = Float(0)
// result[keyPath: keyPath] = t as! K
// return result
// }
//}
//
//extension Vector2: TangentProtocol {
// func adding(_ other: Self) -> Self {
// return Vector2(x + other.x, y + other.y)
// }
// func scaling(by scalar: Float) -> Self {
// return Vector2(scalar * x, scalar * y)
// }
//
// typealias TangentOf = Self
// static func unprojected<K, T: TangentProtocol>(
// _ keyPath: WritableKeyPath<TangentOf, K>,
// _ t: T
// ) -> Self {
// var result = Vector2(0, 0)
// result[keyPath: keyPath] = t as! K
// return result
// }
// static var zero: Self { return Vector2(0, 0) }
//}
//
//func valueWithTangent<A, R, AT, RT>(
// of f: (ValueWithTangent<A, AT>) -> ValueWithTangent<R, RT>,
// at x: A,
// tangent tx: AT
//) -> ValueWithTangent<R, RT> {
// return f(ValueWithTangent(x, tx))
//}
//
//struct SelfTangent<Tan: TangentProtocol>: TangentProtocol
//{
// var t: Tan
// func adding(_ other: Self) -> Self {
// return Self(t: t.adding(other.t))
// }
// func scaling(by scalar: Float) -> Self {
// return Self(t: t.scaling(by: scalar))
// }
//
// typealias TangentOf = Tan
// //func projected<T>(_ keyPath: KeyPath<TangentOf, T>) -> T
// static func unprojected<K, T: TangentProtocol>(_ keyPath: WritableKeyPath<TangentOf, K>, _ t: T) -> Self {
// return Self(t: Tan.unprojected(keyPath as! WritableKeyPath<Tan.TangentOf, K>, t))
// }
// static var zero: Self { Self(t: Tan.zero) }
//}
//
//func secondDerivative<A, R, AT, RT>(
// of f: (ValueWithTangent<A, AT>) -> ValueWithTangent<R, ValueWithTangent<RT, SelfTangent<RT>>>,
// at x: A,
// tangent1 tx1: AT,
// tangent2 tx2: AT
//) -> RT {
// func firstDerivative(_ x2: ValueWithTangent<A, AT>) -> ValueWithTangent<RT, SelfTangent<RT>> {
// return valueWithTangent(of: f, at: x, tangent: tx1).tangent
// }
// return valueWithTangent(of: firstDerivative, at: x, tangent: tx2).tangent.t
//}
//
//func valueWithTangent<A, B, R, AT, BT, RT>(
// of f: (ValueWithTangent<A, AT>, ValueWithTangent<B, BT>) -> ValueWithTangent<R, RT>,
// at x: A, _ y: B,
// tangent tx: AT, _ ty: BT
//) -> ValueWithTangent<R, RT> {
// return f(ValueWithTangent(x, tx), ValueWithTangent(y, ty))
//}
// End of gist.