Last active: February 27, 2020 14:50.
Save dsyme/924686f34bb103344213f8f5f4bbc721 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
namespace DiffSharp | |
open DiffSharp.Backend | |
open DiffSharp.Util | |
/// A tensor value supporting nested forward- and reverse-mode automatic
/// differentiation:
/// - Tensor:  a constant wrapping a backend RawTensor.
/// - TensorF: a forward-mode value carrying a tangent (derivative) and a
///            nesting tag identifying its differentiation level.
/// - TensorR: a reverse-mode value carrying a mutable adjoint, the operation
///            that produced it (the backpropagation tape link), a fanout
///            counter, and a nesting tag.
[<CustomEquality; CustomComparison>]
type Tensor =
    | Tensor of primalRaw:RawTensor
    | TensorF of primal:Tensor * derivative:Tensor * nestingTag:uint32
    | TensorR of primal:Tensor * derivative:(Tensor ref) * parentOperation:TensorOp * fanout:(uint32 ref) * nestingTag:uint32

    /// The primal value one nesting level down; a constant Tensor is its own primal.
    member t.Primal =
        match t with
        | Tensor(_) -> t
        | TensorF(tp,_,_) -> tp
        | TensorR(tp,_,_,_,_) -> tp

    /// The innermost RawTensor, obtained by unwrapping all nesting levels.
    member t.PrimalRaw =
        match t with
        | Tensor(tp) -> tp
        | TensorF(tp,_,_) -> tp.PrimalRaw
        | TensorR(tp,_,_,_,_) -> tp.PrimalRaw

    /// Number of differentiation wrappers around the underlying constant.
    member t.Depth =
        let rec depth x d =
            match x with
            | Tensor(_) -> d
            | TensorF(tp,_,_) -> depth tp (d + 1)
            | TensorR(tp,_,_,_,_) -> depth tp (d + 1)
        depth t 0

    /// The derivative at the outermost level: the tangent of a TensorF or the
    /// current adjoint of a TensorR. Only a TensorR's adjoint can be set.
    member t.Derivative
        with get() =
            match t with
            | Tensor(_) -> failwith "Cannot get derivative of constant Tensor"
            | TensorF(_,td,_) -> td
            | TensorR(_,td,_,_,_) -> !td
        and set(value) =
            match t with
            | Tensor(_) -> failwith "Cannot set derivative of constant Tensor"
            | TensorF(_) -> failwith "Cannot set derivative of TensorF"
            | TensorR(_,td,_,_,_) -> td := value

    /// Reverse-mode bookkeeping: the number of pending uses of this node in
    /// the computation graph. Maintained by ReverseReset / ReversePush.
    member t.Fanout
        with get() =
            match t with
            | Tensor(_) -> failwith "Cannot get fanout of constant Tensor"
            | TensorF(_) -> failwith "Cannot get fanout of TensorF"
            | TensorR(_,_,_,f,_) -> !f
        and set(value) =
            match t with
            | Tensor(_) -> failwith "Cannot set fanout of constant Tensor"
            | TensorF(_) -> failwith "Cannot set fanout of TensorF"
            | TensorR(_,_,_,f,_) -> f := value

    /// Wraps t for forward-mode differentiation with the given tangent, which
    /// must have the same shape as t.
    member t.ForwardDiff(derivative:Tensor, ?tag:uint32) =
        let tag = defaultArg tag GlobalNestingLevel.Current
        if t.Shape = derivative.Shape then TensorF(t, derivative, tag) else failwithf "Expecting derivative of same shape with primal. primal: %A, derivative: %A" t derivative

    /// Wraps t as a reverse-mode leaf (NewT) with a zero adjoint and zero fanout.
    member t.ReverseDiff(?tag:uint32) =
        let tag = defaultArg tag GlobalNestingLevel.Current
        TensorR(t, ref (t.Zero()), NewT, ref 0u, tag)

    /// Strips all differentiation wrappers, returning a constant Tensor.
    member t.NoDiff() = Tensor(t.PrimalRaw)
    member t.Shape = t.PrimalRaw.Shape
    member t.Dim = t.PrimalRaw.Dim
    member t.Nelement = t.PrimalRaw.Nelement
    member t.ToArray() = t.PrimalRaw.ToArray()
    member t.ToValue() = t.PrimalRaw.ToValue()
    member t.Zero() = Tensor(t.PrimalRaw.Zero())
    member t.Create(value) = Tensor(t.PrimalRaw.Create(value))

    /// Equality is on the underlying raw values only; derivatives and nesting
    /// tags are ignored.
    override t.Equals(other) =
        match other with
        | :? Tensor as tensor -> t.PrimalRaw.Equals(tensor.PrimalRaw)
        | _ -> false

    /// Tolerance-based comparison of the underlying raw values.
    member t.ApproximatelyEqual(tensor:Tensor, ?tolerance) =
        let tolerance = defaultArg tolerance 0.01
        t.PrimalRaw.ApproximatelyEquals(tensor.PrimalRaw, tolerance)

    /// FIX: hash only the underlying raw value, mirroring Equals. The previous
    /// implementation also hashed derivatives and nesting tags, so two tensors
    /// that compared equal could yield different hash codes — a violation of
    /// the Equals/GetHashCode contract that breaks hashed collections.
    override t.GetHashCode() = t.PrimalRaw.GetHashCode()

    interface System.IComparable with
        /// Ordering is defined for scalar (Dim = 0) tensors only.
        override t.CompareTo(other) =
            match other with
            | :? Tensor as tensor ->
                if t.Dim = tensor.Dim && t.Dim = 0 then
                    t.PrimalRaw.CompareTo(tensor.PrimalRaw)
                else
                    failwith "Cannot compare non-scalar Tensors"
            | _ -> failwith "Cannot compare Tensor with another type"

    /// True when both tensors are the same union case (constant/forward/reverse).
    member t1.IsSameDiffType(t2:Tensor) =
        match t1, t2 with
        | Tensor(_), Tensor(_) -> true
        | Tensor(_), TensorF(_) -> false
        | Tensor(_), TensorR(_) -> false
        | TensorF(_), Tensor(_) -> false
        | TensorF(_), TensorF(_) -> true
        | TensorF(_), TensorR(_) -> false
        | TensorR(_), Tensor(_) -> false
        | TensorR(_), TensorF(_) -> false
        | TensorR(_), TensorR(_) -> true

    static member op_Explicit(tensor:Tensor):'a = downcast tensor.PrimalRaw.ToValue()
    static member ZerosLike(tensor:Tensor) = Tensor(tensor.PrimalRaw.Zeros(tensor.Shape))
    static member ZerosLike(tensor:Tensor, shape:seq<int>) = Tensor(tensor.PrimalRaw.Zeros(shape |> Array.ofSeq))
    static member OnesLike(tensor:Tensor) = Tensor(tensor.PrimalRaw.Ones(tensor.Shape))
    static member OnesLike(tensor:Tensor, shape:seq<int>) = Tensor(tensor.PrimalRaw.Ones(shape |> Array.ofSeq))
    static member Zeros(shape:seq<int>, ?dtype:DType, ?device:Device, ?backend:Backend) =
        Tensor(RawTensor.Zeros(shape|>Seq.toArray, ?dtype=dtype, ?device=device, ?backend=backend))
    static member Ones(shape:seq<int>, ?dtype:DType, ?device:Device, ?backend:Backend) =
        Tensor(RawTensor.Ones(shape|>Seq.toArray, ?dtype=dtype, ?device=device, ?backend=backend))
    static member Create(value:obj, ?dtype:DType, ?device:Device, ?backend:Backend) =
        Tensor(RawTensor.Create(value, ?dtype=dtype, ?device=device, ?backend=backend))

    /// Lifts a unary op through the AD wrappers: fRaw computes on raw tensors,
    /// fTensor recurses on the primal, dfTensorFwd yields the forward tangent,
    /// and dfTensorRev records the reverse-mode tape entry.
    static member inline OpUnary(a, fRaw, fTensor, dfTensorFwd, dfTensorRev) =
        match a with
        | Tensor(ap) -> Tensor(fRaw(ap))
        | TensorF(ap,ad,at) -> let cp = fTensor(ap) in TensorF(cp, dfTensorFwd(cp,ap,ad), at)
        | TensorR(ap,_,_,_,at) -> let cp = fTensor(ap) in TensorR(cp, ref (a.Zero()), dfTensorRev(a), ref 0u, at)

    /// Lifts a binary op through the AD wrappers. Suffixes TT/TC/CT select
    /// which argument is differentiated (T) vs treated as constant (C). When
    /// both arguments carry wrappers, the larger nesting tag is the outer
    /// level; mixing TensorF and TensorR at the same level is an error.
    static member inline OpBinary(a, b, fRaw, fTensor, dfTensorFwdTT, dfTensorFwdTC, dfTensorFwdCT, dfTensorRevTT, dfTensorRevTC, dfTensorRevCT) =
        match a, b with
        | Tensor(ap), Tensor(bp) -> Tensor(fRaw(ap, bp))
        | Tensor(_), TensorF(bp,bd,bt) -> let cp = fTensor(a,bp) in TensorF(cp, dfTensorFwdCT(cp,bp,bd), bt)
        | Tensor(_), TensorR(bp,_,_,_,bt) -> let cp = fTensor(a,bp) in TensorR(cp, ref (a.Zero()), dfTensorRevCT(a,b), ref 0u, bt)
        | TensorF(ap,ad,at), Tensor(_) -> let cp = fTensor(ap,b) in TensorF(cp, dfTensorFwdTC(cp,ap,ad), at)
        | TensorF(ap,ad,at), TensorF(bp,bd,bt) when at=bt -> let cp = fTensor(ap,bp) in TensorF(cp, dfTensorFwdTT(cp,ap,ad,bp,bd), at)
        | TensorF(ap,ad,at), TensorF(_,_,bt) when at>bt -> let cp = fTensor(ap,b) in TensorF(cp, dfTensorFwdTC(cp,ap,ad), at)
        | TensorF(_,_,at), TensorF(bp,bd,bt) when at<bt -> let cp = fTensor(a,bp) in TensorF(cp, dfTensorFwdCT(cp,bp,bd), bt)
        | TensorF(_,_,at), TensorR(_,_,_,_,bt) when at=bt -> failwith "Cannot have TensorF and TensorR in the same nesting level"
        | TensorF(ap,ad,at), TensorR(_,_,_,_,bt) when at>bt -> let cp = fTensor(ap,b) in TensorF(cp, dfTensorFwdTC(cp,ap,ad), at)
        | TensorF(_,_,at), TensorR(bp,_,_,_,bt) when at<bt -> let cp = fTensor(a,bp) in TensorR(cp, ref (a.Zero()), dfTensorRevCT(a,b), ref 0u, bt)
        | TensorR(ap,_,_,_,at), Tensor(_) -> let cp = fTensor(ap,b) in TensorR(cp, ref (a.Zero()), dfTensorRevTC(a,b), ref 0u, at)
        | TensorR(_,_,_,_,at), TensorF(_,_,bt) when at=bt -> failwith "Cannot have TensorR and TensorF in the same nesting level"
        | TensorR(ap,_,_,_,at), TensorF(_,_,bt) when at>bt -> let cp = fTensor(ap, b) in TensorR(cp, ref (a.Zero()), dfTensorRevTC(a,b), ref 0u, at)
        | TensorR(_,_,_,_,at), TensorF(bp,bd,bt) when at<bt -> let cp = fTensor(a,bp) in TensorF(cp, dfTensorFwdCT(cp, bp, bd), bt)
        | TensorR(ap,_,_,_,at), TensorR(bp,_,_,_,bt) when at=bt -> let cp = fTensor(ap,bp) in TensorR(cp, ref (a.Zero()), dfTensorRevTT(a,b), ref 0u, at)
        | TensorR(ap,_,_,_,at), TensorR(_,_,_,_,bt) when at>bt -> let cp = fTensor(ap,b) in TensorR(cp, ref (a.Zero()), dfTensorRevTC(a,b), ref 0u, at)
        | TensorR(_,_,_,_,at), TensorR(bp,_,_,_,bt) when at<bt -> let cp = fTensor(a,bp) in TensorR(cp, ref (a.Zero()), dfTensorRevCT(a,b), ref 0u, bt)
        | _ -> failwith "Unexpected combination of Tensors" // Won't happen, added for suppressing "incomplete matches" warning

    /// Runs reverse-mode backpropagation from t, seeding with value (default:
    /// ones of t's shape). When zeroDerivatives is true, adjoints are zeroed
    /// during the reset pass before accumulation.
    member t.Reverse(?value:Tensor, ?zeroDerivatives:bool) =
        let value = defaultArg value (Tensor.OnesLike(t))
        let zeroDerivatives = defaultArg zeroDerivatives true
        if value.Shape <> t.Shape then failwithf "Expecting an adjoint value of shape %A, but received of shape %A" t.Shape value.Shape
        t.ReverseReset(zeroDerivatives)
        t.ReversePush(value)

    member inline t.Backward(value) = t.Reverse(value)

    /// First pass of reverse mode: walks the graph counting each TensorR's
    /// fanout (number of uses) and optionally zeroing its adjoint. A node's
    /// parents are visited only on its first encounter (fanout becomes 1).
    member t.ReverseReset(zeroDerivatives:bool) =
        let rec reset (ts: Tensor list) =
            match ts with
            | [] -> ()
            | t :: tt ->
                match t with
                | TensorR(_,_,o,_,_) ->
                    if zeroDerivatives then t.Derivative <- t.Zero()
                    t.Fanout <- t.Fanout + 1u
                    if t.Fanout = 1u then
                        match o with
                        | NewT -> reset tt
                        | OpExtensionT(inps, _) -> reset (inps@tt)
                    else reset tt
                | _ -> reset tt
        reset [t]

    /// Second pass of reverse mode: accumulates adjoints. A node propagates to
    /// its parents only after all of its uses have pushed into it (fanout
    /// decremented back to 0), so each tape entry backpropagates exactly once.
    member t.ReversePush(value:Tensor) =
        let rec push (ts:(Tensor*Tensor) list) =
            match ts with
            | [] -> ()
            | (v, t) :: tt ->
                match t with
                | TensorR(_,_,o,_,_) ->
                    t.Derivative <- t.Derivative + v
                    t.Fanout <- t.Fanout - 1u
                    if t.Fanout = 0u then
                        match o with
                        | NewT -> push tt
                        | OpExtensionT(inps, backpropf) -> push (List.zip (backpropf t) inps @ tt)
                    else push tt
                | _ -> push tt
        push [(value, t)]

    /// Builds a differentiable unary function from an extension's compute and
    /// gradient definitions.
    static member Extension(ext: UnaryExtension) =
        (fun a ->
            Tensor.OpUnary(a, ext.Compute, Tensor.Extension ext, ext.GradForward,
                (fun a -> OpExtensionT([a], (fun t -> [ext.GradReverse (t,a)])))
            ))

    /// Builds a differentiable binary function from an extension's compute and
    /// gradient definitions.
    static member Extension(ext: BinaryExtension) =
        (fun (a, b) ->
            Tensor.OpBinary(a, b, ext.Compute, Tensor.Extension ext,
                ext.GradForwardTT,
                (fun (cp,a,ad) -> ext.GradForwardTC(cp,a,ad,b)),
                (fun (cp,b,bd) -> ext.GradForwardCT(cp,a,b,bd)),
                (fun (a,b) -> OpExtensionT([a;b], (fun t -> let ra, rb = ext.GradReverseTT (t,a,b) in [ra; rb]))),
                (fun (a,b) -> OpExtensionT([a;b], (fun t -> let ra = ext.GradReverseTC (t,a,b) in [ra]))),
                (fun (a,b) -> OpExtensionT([a;b], (fun t -> let rb = ext.GradReverseCT (t,a,b) in [rb])))
            ))
/// The operation that produced a TensorR node — the link in the reverse-mode tape.
and TensorOp =
    /// A leaf created by ReverseDiff; backpropagation stops here.
    | NewT
    /// An extension operation: its input tensors, and a backprop function
    /// mapping the node's adjoint to the adjoints of those inputs (same order).
    | OpExtensionT of inputs: Tensor list * backpropf: ((* tangent: *) Tensor -> (* revgrad: *) Tensor list)
/// Defines an extension implementing a unary function and its gradients.
and UnaryExtension =
    /// Compute the function f(a) on raw tensors.
    abstract Compute: a: RawTensor -> RawTensor
    /// Forward-mode gradient: given the result fa = f(a), the primal a, and
    /// its tangent da, return the tangent of the result.
    abstract GradForward: fa: Tensor * a: Tensor * da: Tensor -> Tensor
    /// Reverse-mode gradient (adjoint): given the result's adjoint t and the
    /// input a, return the adjoint contribution to a.
    abstract GradReverse: t: Tensor * a: Tensor -> Tensor
/// Defines an extension implementing a binary function and its gradients.
/// Suffixes: TT = both arguments differentiated, TC = only the first,
/// CT = only the second (the other is treated as constant).
and BinaryExtension =
    /// Compute the function f(a, b) on raw tensors.
    abstract Compute: a: RawTensor * b: RawTensor -> RawTensor
    /// Forward-mode gradient when both arguments carry tangents da and db;
    /// fab is the result f(a, b).
    abstract GradForwardTT: fab: Tensor * a: Tensor * da: Tensor * b: Tensor * db: Tensor -> Tensor
    /// Forward-mode gradient when only the first argument carries a tangent da.
    abstract GradForwardTC: fab: Tensor * a: Tensor * da: Tensor * b: Tensor -> Tensor
    /// Forward-mode gradient when only the second argument carries a tangent db.
    abstract GradForwardCT: fab: Tensor * a: Tensor * b: Tensor * db: Tensor -> Tensor
    /// Reverse-mode gradients: given the result's adjoint t, return the
    /// adjoint contributions to (a, b).
    abstract GradReverseTT: t: Tensor * a: Tensor * b: Tensor -> Tensor * Tensor
    /// Reverse-mode gradient for a only (b treated as constant).
    abstract GradReverseTC: t: Tensor * a: Tensor * b: Tensor -> Tensor
    /// Reverse-mode gradient for b only (a treated as constant).
    abstract GradReverseCT: t: Tensor * a: Tensor * b: Tensor -> Tensor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.