Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
/// Feed-forward style-transfer network: a strided convolutional encoder,
/// five residual blocks and a transposed-convolution decoder, ending in a
/// tanh rescaled to pixel values and clamped to [0, 255].
/// C is the number of image channels; H and W are the input height and width.
[<ShapeCheck>]
type NeuralStyles(C: Int, H: Int, W: Int) =
    inherit Model()

    /// Instance normalisation with learned shift and scale.
    /// Mean and variance are taken over dims 2 and 3 of the input
    /// (presumably the spatial axes of an NCHW tensor — TODO confirm), and are
    /// assumed to keep those dims so they broadcast against `input`.
    /// NOTE(review): shift/scale have shape [channels; h; w] (one weight per
    /// position) rather than the more common per-channel [channels] — confirm
    /// this is intended. The meaning of Weight.uniform's second argument
    /// (0.0 / 1.0) is defined elsewhere.
    let instance_norm (channels: Int, h, w) = 
        model { 
            let! shift = Weight.uniform (Shape [| channels; h; w |], 0.0)
            and! scale = Weight.uniform (Shape [| channels; h; w |], 1.0)
            return (fun input  ->
                let mu, sigma_sq = input.moments(dims= [2;3]) 
                // Keeps the division finite when a channel has zero variance.
                let epsilon = 0.001
                let normalized =  (input - mu) / sqrt (sigma_sq + epsilon)
                scale * normalized + shift) 
        }

    /// conv2d with padding filter_size/2, followed by instance normalisation.
    /// out_height/out_width are the spatial size expected after the
    /// convolution; they size the normalisation weights.
    let conv_layer (in_channels: Int, out_channels: Int, filter_size: Int, stride: Int, out_height, out_width) = 
        model {
            let! filters = Weight.uniform (Shape [| out_channels; in_channels; filter_size; filter_size|], 0.1)
            and! inorm = instance_norm (out_channels, out_height, out_width) 
            return (fun input  ->
                dsharp.conv2d (input, filters, stride=stride, padding=filter_size/2)
                |> inorm)
        }

    /// Transposed convolution followed by instance normalisation.
    /// For the 3x3, stride-2 layers used below, padding = 3/2 = 1 and
    /// outputPadding = 3 % 2 = 1. Under the PyTorch-style size formula
    /// (in-1)*stride - 2*padding + filter + outputPadding, that gives exactly
    /// 2x upsampling.
    /// NOTE(review): other filter/stride combinations are not verified to
    /// upsample by exactly `stride`.
    let conv_transpose_layer (in_channels: Int, out_channels:Int, filter_size:Int, stride, out_height, out_width) =
        model { 
            let! filters = Weight.uniform (Shape [| in_channels; out_channels; filter_size; filter_size|], 0.1)
            and! inorm = instance_norm (out_channels, out_height, out_width) 
            return (fun input -> 
                dsharp.convTranspose2d(input, filters, stride=stride, padding=filter_size/stride, outputPadding=filter_size % stride) 
                |> inorm)
        }

    /// Residual block computing input + conv2 (relu (conv1 input)).
    /// Both convolutions map 128 -> 128 channels at stride 1.
    /// `name` is currently unused.
    let residual_block (filter_size, name, height, width) = 
        model { 
            let! conv1 = conv_layer (128I, 128I, filter_size, 1I, height, width)
            and! conv2 = conv_layer (128I, 128I, filter_size, 1I, height, width) 
            return (fun input  ->
                // `+` binds tighter than `|>`, so the branch must be parenthesised.
                // Without the parentheses the skip connection is summed *before*
                // relu/conv2, and the block is no longer residual.
                input + (input |> conv1 |> dsharp.relu |> conv2))
        }

    /// Maps tanh output (-1, 1) to (-22.5, 277.5), centred on 127.5.
    /// The range is slightly wider than [0, 255], hence the clip at the end of
    /// the network.
    let to_pixel_value (input: Tensor) = 
        dsharp.tanh input * 150.0 + (255.0 / 2.0)
        
    /// Clamps every element of `input` to [min, max].
    let clip min max input = 
        dsharp.clamp(input, min, max)

    // Encoder: H x W -> H/2 -> H/4 spatially, channels C -> 32 -> 64 -> 128.
    // Five residual blocks at H/4 x W/4.
    // Decoder: back up to H x W, channels 128 -> 64 -> 32 -> C.
    let model : Model =
        conv_layer (C, 32I, 9I, 1I, H, W) --> dsharp.relu
        --> conv_layer (32I, 64I, 3I, 2I, H/2, W/2) --> dsharp.relu
        --> conv_layer (64I, 128I, 3I, 2I, H/4, W/4) --> dsharp.relu
        --> residual_block (3I, "resid1", H/4, W/4)
        --> residual_block (3I, "resid2", H/4, W/4)
        --> residual_block (3I, "resid3", H/4, W/4)
        --> residual_block (3I, "resid4", H/4, W/4)
        --> residual_block (3I, "resid5", H/4, W/4)
        --> conv_transpose_layer (128I, 64I, 3I, 2I, H/2, W/2) --> dsharp.relu
        --> conv_transpose_layer (64I, 32I, 3I, 2I, H, W) --> dsharp.relu
        --> conv_layer (32I, C, 9I, 1I, H, W)
        --> to_pixel_value 
        --> clip 0.0 255.0

    /// Runs the full style-transfer network on `input`.
    override _.forward(input) = 
        model.forward(input) 
        

The `model { ... }` computation expression used above is implemented by this builder:

/// Applicative computation-expression builder behind `model { ... }`.
/// Each `let!` / `and!` source (a Parameter, a Tensor wrapped as a Parameter,
/// or a sub-Model) is registered with the Model that `Run` creates. Its current
/// value is read on every forward call.
///
/// The registered sources travel through the computation as an explicit list,
/// in source order, rather than in mutable builder state. As a result, each
/// `model { ... }` registers only its own sources, even when one builder
/// instance is shared by every block, including blocks nested inside another
/// block's sources.
type ModelBuilder() =
    /// Builds the Model from the collected sources and the forward function.
    member _.Run((ps: obj list, f: Tensor -> Tensor)) =
        Model.create (Seq.ofList ps) f

    /// A parameter source; its value is re-read on each forward call.
    member _.Source(p: Parameter) = ([ box p ], fun () -> p.value)

    /// A sub-model source; binds to its forward function.
    member _.Source(m: Model) = ([ box m ], fun () -> m.forward)

    /// A raw tensor is wrapped as a Parameter.
    member x.Source(t: Tensor) = x.Source(Parameter(t))

    /// Combines two sources (`and!`), concatenating their registered sources.
    /// Three or more `and!` sources nest through repeated calls.
    member _.MergeSources((ps1, t1): obj list * (unit -> 'T1), (ps2, t2): obj list * (unit -> 'T2)) =
        (ps1 @ ps2, fun () -> (t1 (), t2 ()))

    /// Builds the forward function. Bound values are fetched afresh on every
    /// call so that parameter updates are seen.
    member _.BindReturn((ps, v): obj list * (unit -> 'T), k: 'T -> Tensor -> Tensor) =
        (ps, fun input -> k (v ()) input)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.