namespace DiffSharp

open DiffSharp.Backend
open DiffSharp.Util

[<CustomEquality; CustomComparison>]
type Tensor =
    /// A constant tensor wrapping a raw backend tensor
    | Tensor of primalRaw: RawTensor
    /// A forward-mode tensor carrying a primal, a tangent derivative and a nesting tag
    | TensorF of primal: Tensor * derivative: Tensor * nestingTag: uint32
    /// A reverse-mode tensor carrying a primal, a mutable adjoint, the operation that produced it, a fanout counter and a nesting tag
    | TensorR of primal: Tensor * derivative: (Tensor ref) * parentOperation: TensorOp * fanout: (uint32 ref) * nestingTag: uint32
    member t.Primal =
        match t with
        | Tensor(_) -> t
        | TensorF(tp,_,_) -> tp
        | TensorR(tp,_,_,_,_) -> tp

    member t.PrimalRaw =
        match t with
        | Tensor(tp) -> tp
        | TensorF(tp,_,_) -> tp.PrimalRaw
        | TensorR(tp,_,_,_,_) -> tp.PrimalRaw

    member t.Depth =
        let rec depth x d =
            match x with
            | Tensor(_) -> d
            | TensorF(tp,_,_) -> depth tp (d + 1)
            | TensorR(tp,_,_,_,_) -> depth tp (d + 1)
        depth t 0
    member t.Derivative
        with get() =
            match t with
            | Tensor(_) -> failwith "Cannot get derivative of constant Tensor"
            | TensorF(_,td,_) -> td
            | TensorR(_,td,_,_,_) -> !td
        and set(value) =
            match t with
            | Tensor(_) -> failwith "Cannot set derivative of constant Tensor"
            | TensorF(_) -> failwith "Cannot set derivative of TensorF"
            | TensorR(_,td,_,_,_) -> td := value

    member t.Fanout
        with get() =
            match t with
            | Tensor(_) -> failwith "Cannot get fanout of constant Tensor"
            | TensorF(_) -> failwith "Cannot get fanout of TensorF"
            | TensorR(_,_,_,f,_) -> !f
        and set(value) =
            match t with
            | Tensor(_) -> failwith "Cannot set fanout of constant Tensor"
            | TensorF(_) -> failwith "Cannot set fanout of TensorF"
            | TensorR(_,_,_,f,_) -> f := value
    member t.ForwardDiff(derivative: Tensor, ?tag: uint32) =
        let tag = defaultArg tag GlobalNestingLevel.Current
        if t.Shape = derivative.Shape then TensorF(t, derivative, tag)
        else failwithf "Expecting derivative of the same shape as the primal. primal: %A, derivative: %A" t derivative

    member t.ReverseDiff(?tag: uint32) =
        let tag = defaultArg tag GlobalNestingLevel.Current
        TensorR(t, ref (t.Zero()), NewT, ref 0u, tag)

    member t.NoDiff() = Tensor(t.PrimalRaw)
    member t.Shape = t.PrimalRaw.Shape
    member t.Dim = t.PrimalRaw.Dim
    member t.Nelement = t.PrimalRaw.Nelement
    member t.ToArray() = t.PrimalRaw.ToArray()
    member t.ToValue() = t.PrimalRaw.ToValue()
    member t.Zero() = Tensor(t.PrimalRaw.Zero())
    member t.Create(value) = Tensor(t.PrimalRaw.Create(value))
    override t.Equals(other) =
        match other with
        | :? Tensor as tensor -> t.PrimalRaw.Equals(tensor.PrimalRaw)
        | _ -> false

    member t.ApproximatelyEqual(tensor: Tensor, ?tolerance) =
        let tolerance = defaultArg tolerance 0.01
        t.PrimalRaw.ApproximatelyEquals(tensor.PrimalRaw, tolerance)

    override t.GetHashCode() =
        match t with
        | Tensor(tp) -> hash tp
        | TensorF(tp,td,tt) -> hash (tp, td, tt)
        | TensorR(tp,td,_,_,tt) -> hash (tp, !td, tt)

    interface System.IComparable with
        override t.CompareTo(other) =
            match other with
            | :? Tensor as tensor ->
                if t.Dim = tensor.Dim && t.Dim = 0 then
                    t.PrimalRaw.CompareTo(tensor.PrimalRaw)
                else
                    failwith "Cannot compare non-scalar Tensors"
            | _ -> failwith "Cannot compare Tensor with another type"
    member t1.IsSameDiffType(t2: Tensor) =
        match t1, t2 with
        | Tensor(_), Tensor(_)
        | TensorF(_), TensorF(_)
        | TensorR(_), TensorR(_) -> true
        | _ -> false
    static member op_Explicit(tensor: Tensor) : 'a = downcast tensor.PrimalRaw.ToValue()

    static member ZerosLike(tensor: Tensor) = Tensor(tensor.PrimalRaw.Zeros(tensor.Shape))
    static member ZerosLike(tensor: Tensor, shape: seq<int>) = Tensor(tensor.PrimalRaw.Zeros(shape |> Array.ofSeq))
    static member OnesLike(tensor: Tensor) = Tensor(tensor.PrimalRaw.Ones(tensor.Shape))
    static member OnesLike(tensor: Tensor, shape: seq<int>) = Tensor(tensor.PrimalRaw.Ones(shape |> Array.ofSeq))

    static member Zeros(shape: seq<int>, ?dtype: DType, ?device: Device, ?backend: Backend) =
        Tensor(RawTensor.Zeros(shape |> Seq.toArray, ?dtype=dtype, ?device=device, ?backend=backend))
    static member Ones(shape: seq<int>, ?dtype: DType, ?device: Device, ?backend: Backend) =
        Tensor(RawTensor.Ones(shape |> Seq.toArray, ?dtype=dtype, ?device=device, ?backend=backend))
    static member Create(value: obj, ?dtype: DType, ?device: Device, ?backend: Backend) =
        Tensor(RawTensor.Create(value, ?dtype=dtype, ?device=device, ?backend=backend))
    // Dispatch helpers for defining differentiable operations.
    // Suffix convention in the argument names: T = actively differentiated tensor, C = constant;
    // dfTensorFwd* supply forward-mode (tangent) rules, dfTensorRev* supply the reverse-mode parent operation.
    static member inline OpUnary(a, fRaw, fTensor, dfTensorFwd, dfTensorRev) =
        match a with
        | Tensor(ap) -> Tensor(fRaw(ap))
        | TensorF(ap,ad,at) -> let cp = fTensor(ap) in TensorF(cp, dfTensorFwd(cp,ap,ad), at)
        | TensorR(ap,_,_,_,at) -> let cp = fTensor(ap) in TensorR(cp, ref (a.Zero()), dfTensorRev(a), ref 0u, at)

    static member inline OpBinary(a, b, fRaw, fTensor, dfTensorFwdTT, dfTensorFwdTC, dfTensorFwdCT, dfTensorRevTT, dfTensorRevTC, dfTensorRevCT) =
        match a, b with
        | Tensor(ap), Tensor(bp) -> Tensor(fRaw(ap, bp))
        | Tensor(_), TensorF(bp,bd,bt) -> let cp = fTensor(a,bp) in TensorF(cp, dfTensorFwdCT(cp,bp,bd), bt)
        | Tensor(_), TensorR(bp,_,_,_,bt) -> let cp = fTensor(a,bp) in TensorR(cp, ref (a.Zero()), dfTensorRevCT(a,b), ref 0u, bt)
        | TensorF(ap,ad,at), Tensor(_) -> let cp = fTensor(ap,b) in TensorF(cp, dfTensorFwdTC(cp,ap,ad), at)
        | TensorF(ap,ad,at), TensorF(bp,bd,bt) when at=bt -> let cp = fTensor(ap,bp) in TensorF(cp, dfTensorFwdTT(cp,ap,ad,bp,bd), at)
        | TensorF(ap,ad,at), TensorF(_,_,bt) when at>bt -> let cp = fTensor(ap,b) in TensorF(cp, dfTensorFwdTC(cp,ap,ad), at)
        | TensorF(_,_,at), TensorF(bp,bd,bt) when at<bt -> let cp = fTensor(a,bp) in TensorF(cp, dfTensorFwdCT(cp,bp,bd), bt)
        | TensorF(_,_,at), TensorR(_,_,_,_,bt) when at=bt -> failwith "Cannot have TensorF and TensorR in the same nesting level"
        | TensorF(ap,ad,at), TensorR(_,_,_,_,bt) when at>bt -> let cp = fTensor(ap,b) in TensorF(cp, dfTensorFwdTC(cp,ap,ad), at)
        | TensorF(_,_,at), TensorR(bp,_,_,_,bt) when at<bt -> let cp = fTensor(a,bp) in TensorR(cp, ref (a.Zero()), dfTensorRevCT(a,b), ref 0u, bt)
        | TensorR(ap,_,_,_,at), Tensor(_) -> let cp = fTensor(ap,b) in TensorR(cp, ref (a.Zero()), dfTensorRevTC(a,b), ref 0u, at)
        | TensorR(_,_,_,_,at), TensorF(_,_,bt) when at=bt -> failwith "Cannot have TensorR and TensorF in the same nesting level"
        | TensorR(ap,_,_,_,at), TensorF(_,_,bt) when at>bt -> let cp = fTensor(ap,b) in TensorR(cp, ref (a.Zero()), dfTensorRevTC(a,b), ref 0u, at)
        | TensorR(_,_,_,_,at), TensorF(bp,bd,bt) when at<bt -> let cp = fTensor(a,bp) in TensorF(cp, dfTensorFwdCT(cp,bp,bd), bt)
        | TensorR(ap,_,_,_,at), TensorR(bp,_,_,_,bt) when at=bt -> let cp = fTensor(ap,bp) in TensorR(cp, ref (a.Zero()), dfTensorRevTT(a,b), ref 0u, at)
        | TensorR(ap,_,_,_,at), TensorR(_,_,_,_,bt) when at>bt -> let cp = fTensor(ap,b) in TensorR(cp, ref (a.Zero()), dfTensorRevTC(a,b), ref 0u, at)
        | TensorR(_,_,_,_,at), TensorR(bp,_,_,_,bt) when at<bt -> let cp = fTensor(a,bp) in TensorR(cp, ref (a.Zero()), dfTensorRevCT(a,b), ref 0u, bt)
        | _ -> failwith "Unexpected combination of Tensors" // Won't happen; here to suppress the incomplete-match warning
    member t.Reverse(?value: Tensor, ?zeroDerivatives: bool) =
        let value = defaultArg value (Tensor.OnesLike(t))
        let zeroDerivatives = defaultArg zeroDerivatives true
        if value.Shape <> t.Shape then failwithf "Expecting an adjoint value of shape %A, but received one of shape %A" t.Shape value.Shape
        t.ReverseReset(zeroDerivatives)
        t.ReversePush(value)

    member inline t.Backward(value) = t.Reverse(value)
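
    // Usage sketch (illustrative, not part of the original gist): given some differentiable
    // operation f : Tensor -> Tensor built with the Op*/Extension machinery below, a reverse-mode
    // gradient is obtained as
    //     let x = Tensor.Create(3.0).ReverseDiff()
    //     let y = f x
    //     y.Reverse()              // seeds the adjoint with Tensor.OnesLike(y), then resets and pushes
    //     let dfdx = x.Derivative  // adjoint accumulated at the TensorR leaf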
    // First pass of reverse mode: optionally zero the adjoints and count each node's fanout
    // (the number of adjoint contributions it will receive during ReversePush).
    member t.ReverseReset(zeroDerivatives: bool) =
        let rec reset (ts: Tensor list) =
            match ts with
            | [] -> ()
            | t :: tt ->
                match t with
                | TensorR(_,_,o,_,_) ->
                    if zeroDerivatives then t.Derivative <- t.Zero()
                    t.Fanout <- t.Fanout + 1u
                    if t.Fanout = 1u then
                        match o with
                        | NewT -> reset tt
                        | OpExtensionT(inps, _) -> reset (inps @ tt)
                    else reset tt
                | _ -> reset tt
        reset [t]
    // Second pass of reverse mode: accumulate adjoints into each TensorR node and, once a node's
    // fanout drops to zero, propagate to its inputs via the recorded operation's backprop function.
    member t.ReversePush(value: Tensor) =
        let rec push (ts: (Tensor * Tensor) list) =
            match ts with
            | [] -> ()
            | (v, t) :: tt ->
                match t with
                | TensorR(_,_,o,_,_) ->
                    t.Derivative <- t.Derivative + v
                    t.Fanout <- t.Fanout - 1u
                    if t.Fanout = 0u then
                        match o with
                        | NewT -> push tt
                        | OpExtensionT(inps, backpropf) -> push (List.zip (backpropf t) inps @ tt)
                    else push tt
                | _ -> push tt
        push [(value, t)]
    static member Extension(ext: UnaryExtension) =
        (fun a ->
            Tensor.OpUnary(a, ext.Compute, Tensor.Extension ext, ext.GradForward,
                (fun a -> OpExtensionT([a], (fun t -> [ext.GradReverse (t,a)])))))

    static member Extension(ext: BinaryExtension) =
        (fun (a, b) ->
            Tensor.OpBinary(a, b, ext.Compute, Tensor.Extension ext,
                ext.GradForwardTT,
                (fun (cp,a,ad) -> ext.GradForwardTC(cp,a,ad,b)),
                (fun (cp,b,bd) -> ext.GradForwardCT(cp,a,b,bd)),
                (fun (a,b) -> OpExtensionT([a;b], (fun t -> let ra, rb = ext.GradReverseTT (t,a,b) in [ra; rb]))),
                (fun (a,b) -> OpExtensionT([a;b], (fun t -> let ra = ext.GradReverseTC (t,a,b) in [ra]))),
                (fun (a,b) -> OpExtensionT([a;b], (fun t -> let rb = ext.GradReverseCT (t,a,b) in [rb])))))
and TensorOp =
    | NewT
    | OpExtensionT of inputs: Tensor list * backpropf: ((* output node: *) Tensor -> (* input adjoints: *) Tensor list)
/// Defines an extension implementing a unary function and its gradients
and UnaryExtension =
    /// Compute the function f(a) on raw tensors
    abstract Compute: a: RawTensor -> RawTensor
    /// Compute the forward-mode (tangent) gradient of the function
    abstract GradForward: fa: Tensor * a: Tensor * da: Tensor -> Tensor
    /// Compute the reverse-mode gradient (adjoint) of the function
    abstract GradReverse: t: Tensor * a: Tensor -> Tensor

/// Defines an extension implementing a binary function and its gradients
and BinaryExtension =
    /// Compute the function f(a, b) on raw tensors
    abstract Compute: a: RawTensor * b: RawTensor -> RawTensor
    /// Compute the forward-mode (tangent) gradient of the function, for the cases where both
    /// arguments (TT), only the first (TC), or only the second (CT) are being differentiated
    abstract GradForwardTT: fab: Tensor * a: Tensor * da: Tensor * b: Tensor * db: Tensor -> Tensor
    abstract GradForwardTC: fab: Tensor * a: Tensor * da: Tensor * b: Tensor -> Tensor
    abstract GradForwardCT: fab: Tensor * a: Tensor * b: Tensor * db: Tensor -> Tensor
    /// Compute the reverse-mode gradients (adjoints) of the function for the same three cases
    abstract GradReverseTT: t: Tensor * a: Tensor * b: Tensor -> Tensor * Tensor
    abstract GradReverseTC: t: Tensor * a: Tensor * b: Tensor -> Tensor
    abstract GradReverseCT: t: Tensor * a: Tensor * b: Tensor -> Tensor
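
// Illustrative sketch (not part of the original gist): implementing a unary extension for the
// elementwise square function and turning it into a differentiable operation via Tensor.Extension.
// It assumes RawTensor exposes an elementwise MulTT and that elementwise (+) and (*) on Tensor are
// defined elsewhere in the library; those names are assumptions, not confirmed by this file.
module private TensorExtensionSketch =

    let square : Tensor -> Tensor =
        Tensor.Extension(
            { new UnaryExtension with
                // Raw forward computation: f(a) = a * a (MulTT is an assumed RawTensor member)
                member _.Compute(a) = a.MulTT(a)
                // Forward-mode rule: d(a * a) = (a + a) * da
                member _.GradForward(fa, a, da) = (a + a) * da
                // Reverse-mode rule: the adjoint of a is (a + a) times the output adjoint,
                // read from t.Derivative as propagated by ReversePush above
                member _.GradReverse(t, a) = (a + a) * t.Derivative })

    // Reverse-mode usage of the sketch: d(x * x)/dx at x = 3.0
    let demo () =
        let x = Tensor.Create(3.0).ReverseDiff()
        let y = square x
        y.Reverse()
        x.Derivative   // expected to hold 6.0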