@tanmayb123
Created February 28, 2019 22:02
/// A densely-connected layer with an optional bias term.
///
/// Because the `if useBias` branch is not differentiable control flow,
/// `applied(to:in:)` registers a custom VJP that dispatches by hand to one of
/// two differentiable helper methods.
@_fixed_layout
public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
    public var weight: Tensor<Scalar>
    public var bias: Tensor<Scalar>
    @noDerivative var useBias: Bool

    public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
    @noDerivative public let activation: Activation

    public init(
        weight: Tensor<Scalar>,
        bias: Tensor<Scalar>,
        useBias: Bool,
        activation: @escaping Activation
    ) {
        self.weight = weight
        self.bias = bias
        self.useBias = useBias
        self.activation = activation
    }

    /// The bias-ful branch: `activation(input • weight + bias)`.
    @differentiable
    private func applyingWithBias(to input: Tensor<Scalar>) -> Tensor<Scalar> {
        return activation(matmul(input, weight) + bias)
    }

    /// The bias-free branch: `activation(input • weight)`.
    @differentiable
    private func applyingWithoutBias(to input: Tensor<Scalar>) -> Tensor<Scalar> {
        return activation(matmul(input, weight))
    }

    @differentiable(vjp: _vjpApplied(to:in:))
    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
        if useBias {
            return applyingWithBias(to: input)
        } else {
            return applyingWithoutBias(to: input)
        }
    }

    /// Returns the value and pullback of whichever branch `applied(to:in:)`
    /// takes, sidestepping differentiation of the `if` statement itself.
    @usableFromInline
    func _vjpApplied(to input: Tensor<Scalar>, in context: Context)
        -> (Tensor<Scalar>,
            (Tensor<Scalar>) -> (Dense<Scalar>.CotangentVector, Tensor<Scalar>)) {
        if useBias {
            return valueWithPullback(at: input) { $0.applyingWithBias(to: $1) }
        } else {
            return valueWithPullback(at: input) { $0.applyingWithoutBias(to: $1) }
        }
    }
}
public extension Dense where Scalar.RawSignificand: FixedWidthInteger {
    /// Creates a layer with Glorot-uniform weights drawn from `generator`.
    /// When `useBias` is false, `bias` is stored as a scalar zero so the
    /// stored property still has a value.
    init<G: RandomNumberGenerator>(
        inputSize: Int,
        outputSize: Int,
        activation: @escaping Activation = identity,
        generator: inout G,
        useBias: Bool
    ) {
        self.init(weight: Tensor<Scalar>(glorotUniform: [Int32(inputSize), Int32(outputSize)],
                                         generator: &generator),
                  bias: useBias ? Tensor(zeros: [Int32(outputSize)]) : Tensor(0),
                  useBias: useBias, activation: activation)
    }

    /// Convenience initializer backed by the global Philox generator.
    init(inputSize: Int, outputSize: Int, activation: @escaping Activation = identity,
         useBias: Bool) {
        self.init(inputSize: inputSize, outputSize: outputSize, activation: activation,
                  generator: &PhiloxRandomNumberGenerator.global, useBias: useBias)
    }
}

public extension Dense {
    /// Creates a layer with Glorot-uniform weights from an explicit seed,
    /// defaulting to a randomly drawn one.
    init(
        inputSize: Int,
        outputSize: Int,
        activation: @escaping Activation = identity,
        seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
                                Int64.random(in: Int64.min..<Int64.max)),
        useBias: Bool
    ) {
        self.init(weight: Tensor(glorotUniform: [Int32(inputSize), Int32(outputSize)],
                                 seed: seed),
                  bias: useBias ? Tensor(zeros: [Int32(outputSize)]) : Tensor(0),
                  useBias: useBias, activation: activation)
    }
}
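
A minimal usage sketch, assuming the early-2019 Swift for TensorFlow toolchain this gist targets; `relu`, `Context(learningPhase:)`, and `gradient(at:in:)` come from that toolchain's API, not from the gist itself:

import TensorFlow

// Hypothetical example: a 4 -> 2 layer with the bias term disabled.
var layer = Dense<Float>(inputSize: 4, outputSize: 2, activation: relu, useBias: false)
let context = Context(learningPhase: .training)

let input = Tensor<Float>(ones: [1, 4])
let output = layer.applied(to: input, in: context)

// Differentiation works even though `applied(to:in:)` branches on `useBias`,
// because the custom VJP selects the matching branch's pullback by hand.
// In the bias-free branch, `bias` is unused and receives a zero cotangent.
let grads = gradient(at: layer) { layer in
    layer.applied(to: input, in: context).sum()
}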