/// Image-to-image style-transfer network assembled from convolution,
/// instance-normalisation, residual and transposed-convolution layers.
///
/// `C`, `H` and `W` are the channel count, height and width passed to the
/// layer constructors below. Literals such as `128I` are values of the same
/// `Int` type (presumably a symbolic dimension used by `ShapeCheck` -- confirm).
[<ShapeCheck>]
type NeuralStyles(C: Int, H: Int, W: Int) =
    inherit Model()

    /// Normalises `input` by the moments taken over dims 2 and 3, then applies
    /// a learned `scale` (initialised with 1.0) and `shift` (initialised with 0.0),
    /// each of shape `[channels; h; w]`.
    let instance_norm (channels: Int, h, w) =
        model {
            let! shift = Weight.uniform (Shape [| channels; h; w |], 0.0)
            and! scale = Weight.uniform (Shape [| channels; h; w |], 1.0)
            return (fun input ->
                let mu, sigma_sq = input.moments(dims= [2;3])
                // Small constant keeping the square root away from zero.
                let epsilon = 0.001
                let normalized = (input - mu) / sqrt (sigma_sq + epsilon)
                scale * normalized + shift)
        }

    /// 2-D convolution with a square `filter_size` kernel, followed by
    /// instance normalisation sized for `out_height` x `out_width`.
    let conv_layer (in_channels: Int, out_channels: Int, filter_size: Int, stride: Int, out_height, out_width) =
        model {
            let! filters = Weight.uniform (Shape [| out_channels; in_channels; filter_size; filter_size|], 0.1)
            and! inorm = instance_norm (out_channels, out_height, out_width)
            return (fun input ->
                dsharp.conv2d (input, filters, stride=stride, padding=filter_size/2)
                |> inorm)
        }

    /// Transposed 2-D convolution (used for upsampling in `model` below),
    /// followed by instance normalisation. Note the filter shape is
    /// `[in_channels; out_channels; ...]`, the reverse of `conv_layer`.
    let conv_transpose_layer (in_channels: Int, out_channels:Int, filter_size:Int, stride, out_height, out_width) =
        model {
            let! filters = Weight.uniform (Shape [| in_channels; out_channels; filter_size; filter_size|], 0.1)
            and! inorm = instance_norm (out_channels, out_height, out_width)
            return (fun input ->
                // For the 3x3 / stride-2 calls below this gives padding 1 and
                // outputPadding 1.
                dsharp.convTranspose2d(input, filters, stride=stride, padding=filter_size/stride, outputPadding=filter_size % stride)
                |> inorm)
        }

    /// Residual block over 128 channels: `input + conv2 (relu (conv1 input))`.
    /// `name` is accepted for call-site readability but is not used.
    let residual_block (filter_size, name, height, width) =
        model {
            let! conv1 = conv_layer (128I, 128I, filter_size, 1I, height, width)
            and! conv2 = conv_layer (128I, 128I, filter_size, 1I, height, width)
            return (fun input ->
                // `+` binds tighter than `|>`, so writing
                // `input + conv1 input |> dsharp.relu |> conv2` would feed the sum
                // through relu and conv2 and lose the skip connection. Build the
                // residual branch first and add the untouched input last.
                let residual = input |> conv1 |> dsharp.relu |> conv2
                input + residual)
        }

    /// Maps activations to `tanh(input) * 150 + 127.5`; the result can fall
    /// outside [0, 255] and is clamped by `clip` at the end of the pipeline.
    let to_pixel_value (input: Tensor) =
        dsharp.tanh input * 150.0 + (255.0 / 2.0)

    /// Clamps `input` to the range [`lower`, `upper`].
    let clip lower upper input =
        dsharp.clamp(input, lower, upper)

    /// The full network: three convolutions (the last two with stride 2),
    /// five residual blocks at H/4 x W/4, two stride-2 transposed convolutions,
    /// a final 9x9 convolution back to `C` channels, then conversion to pixel values.
    let model : Model =
        conv_layer (C, 32I, 9I, 1I, H, W) --> dsharp.relu
        --> conv_layer (32I, 64I, 3I, 2I, H/2, W/2) --> dsharp.relu
        --> conv_layer (64I, 128I, 3I, 2I, H/4, W/4) --> dsharp.relu
        --> residual_block (3I, "resid1", H/4, W/4)
        --> residual_block (3I, "resid2", H/4, W/4)
        --> residual_block (3I, "resid3", H/4, W/4)
        --> residual_block (3I, "resid4", H/4, W/4)
        --> residual_block (3I, "resid5", H/4, W/4)
        --> conv_transpose_layer (128I, 64I, 3I, 2I, H/2, W/2) --> dsharp.relu
        --> conv_transpose_layer (64I, 32I, 3I, 2I, H, W) --> dsharp.relu
        --> conv_layer (32I, C, 9I, 1I, H, W)
        --> to_pixel_value
        --> clip 0.0 255.0

    /// Runs `input` through the composed network.
    override _.forward(input) =
        model.forward(input)
The `model { ... }` computation-expression builder used above looks like this:
/// Computation-expression builder that records every parameter and sub-model
/// bound with `let!` / `and!` and hands them to `Model.create` together with
/// the forward function produced by the expression body.
///
/// NOTE(review): `registered` is mutable state owned by the builder instance and
/// is never cleared, and `Run` passes a lazy view over it. If one instance is
/// reused for several `model { ... }` expressions, entries may leak between
/// models -- confirm how the builder value is created.
type ModelBuilder() =
    let registered = ResizeArray<Choice<Parameter, Model>>()

    // Erase the Parameter/Model distinction; `Run` passes plain objects on.
    let asObject (entry: Choice<Parameter, Model>) =
        match entry with
        | Choice1Of2 parameter -> box parameter
        | Choice2Of2 subModel -> box subModel

    /// Finishes the expression: builds a model from everything registered so
    /// far and the forward function `f`.
    member _.Run(f) =
        Model.create (Seq.map asObject registered) f

    /// Registers a parameter; the returned thunk reads its current value.
    member _.Source(p: Parameter) =
        registered.Add(Choice1Of2 p)
        fun () -> p.value

    /// Registers a sub-model; the returned thunk yields its forward function.
    member _.Source(m: Model) =
        registered.Add(Choice2Of2 m)
        fun () -> m.forward

    /// Wraps a raw tensor in a `Parameter` and registers it.
    member this.Source(t: Tensor) =
        this.Source(Parameter(t))

    /// Pairs the two thunks bound by `let! ... and! ...`.
    member _.MergeSources(t1: (unit -> 'T1), t2: (unit -> 'T2)) =
        (t1, t2)

    /// Produces the forward function: on every call both thunks are evaluated
    /// afresh and passed to the body `k` along with the input.
    member _.BindReturn(sources: (unit -> 'T1) * (unit -> 'T2), k: 'T1 * 'T2 -> Tensor -> Tensor) =
        let first, second = sources
        fun input -> k (first (), second ()) input