Skip to content

Instantly share code, notes, and snippets.

@dsyme
Last active February 15, 2022 15:22
Show Gist options
  • Save dsyme/ccb5d1e4ec4ba32405709fe9b2cda152 to your computer and use it in GitHub Desktop.
Save dsyme/ccb5d1e4ec4ba32405709fe9b2cda152 to your computer and use it in GitHub Desktop.
#!/usr/bin/env -S dotnet fsi
#r "nuget: DiffSharp-cuda, 1.0.6"
#r @"..\bin\Debug\net6.0\fwdsgd.dll"
#load "argparse.fsx"
open System.IO
open DiffSharp
open DiffSharp.Model
open DiffSharp.Util
open DiffSharp.Data
open Helpers
// Command-line interface of the script. Options are registered in a fixed
// order and every default is supplied as a string.
let parser = ArgumentParser()
parser.add_argument("--dir")

// Options that carry nothing but a default value.
for (flag, value) in [ ("--lr", "0.0002"); ("--lr_decay", "0.0001"); ("--threshold", "0.00001"); ("--n", "25000")
                       ("--runs", "1"); ("--batch_size", "128"); ("--num_workers", "8"); ("--valid_every", "500") ] do
    parser.add_argument(flag, dflt=value)

// --gc_every has no default here; it is read back later as an optional int.
parser.add_argument(
    "--gc_every",
    help="GC every N iterations. Default is determined by model. Note DiffSharp needs regular GC for large models.")

// Options restricted to a fixed set of choices.
parser.add_argument(
    "--device",
    choices=[ "cpu"; "cuda" ],
    dflt="cpu")
parser.add_argument(
    "--model",
    choices=[ "logreg"; "mlp"; "cnn"; "cnn2"; "cnn4"; "cnn4b"; "vgg16"; "resnet18"; "resnet50" ],
    dflt="logreg")
parser.add_argument(
    "--optimizer",
    choices=[ "sgd"; "sgdn"; "adam" ],
    dflt="sgd")
parser.add_argument("--momentum", dflt="0.2")

// Options registered via add_argument0 (read back later with resultBool).
parser.add_argument0("--skiprev")
parser.add_argument0("--skipfwd")

parser.parse_args()
// Settings read back from the parsed command line.
// Every integer option goes through parser.resultInt so that all of them
// share one conversion path; --n and --runs used parser.result + int before.

// String-valued options.
let optimizer = parser.result("--optimizer")
let modelArg = parser.result("--model")
let dirArg = parser.result("--dir") // NOTE(review): --dir is registered without a default; confirm what result returns when it is omitted
let device = parser.result("--device")

// Float-valued options.
let lr = parser.resultFloat("--lr")
let lr_decay = parser.resultFloat("--lr_decay")
let momentum = parser.resultFloat("--momentum")
let threshold = parser.resultFloat("--threshold")

// Integer-valued options.
let n = parser.resultInt("--n")
let runs = parser.resultInt("--runs")
let batch_size = parser.resultInt("--batch_size")
let valid_every = parser.resultInt("--valid_every")

// Options registered with add_argument0.
let skiprev = parser.resultBool("--skiprev")
let skipfwd = parser.resultBool("--skipfwd")

// --gc_every has no registered default; presumably None when not supplied (confirm in argparse.fsx).
let gc_every_arg = parser.resultIntOption("--gc_every")
// Configure DiffSharp globally: Torch backend, CPU when --device is "cpu"
// and GPU for any other value, then seed the random generator with 0.
dsharp.config(backend=Backend.Torch, device=(match device with "cpu" -> Device.CPU | _ -> Device.GPU))
dsharp.seed(0)
/// Creates the parameters of the logistic-regression model:
/// "w1" from Weight.kaiming(28*28, 10) and "b1" from Weight.bias(10),
/// wrapped by `parameters`.
let logreg_ps() =
    let outputs = 10
    parameters
        [ ("w1", Weight.kaiming(28*28, outputs))
          ("b1", Weight.bias(outputs)) ]
/// Forward pass of the logistic-regression model.
/// Reshapes `x` to rows of 28*28 values, multiplies by ps["w1"] and adds ps["b1"].
/// A second layer ("w2"/"b2") existed only as commented-out code and is not applied.
let logreg_eval(ps: ParameterDict, x: Tensor) =
    let flat = x.view([-1; 28*28])
    flat.matmul(ps["w1"]) + ps["b1"]
/// Evaluates the logistic-regression model on `x` and scores it against `target`.
/// Returns a pair: the cross-entropy loss and the number of rows whose
/// argmax along dimension 1 equals the target.
let logreg_loss(ps, x, target) =
    let logits = logreg_eval(ps, x)
    // TODO: check this is the same as cross_entropy in utils.py
    let loss = dsharp.crossEntropyLoss(logits, target)
    let hits = dsharp.eq(logits.argmax(dim=1), target)
    let num_correct = hits.sum().toInt32()
    (loss, num_correct)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment