Skip to content

Instantly share code, notes, and snippets.

View jekbradbury's full-sized avatar

James Bradbury jekbradbury

View GitHub Profile
@pfitzseb
pfitzseb / keyboardmacro.jl
Created June 16, 2019 07:46
keyboard macro prototype
using REPL
using REPL.LineEdit
"""
    @keyboard

Expand to a call `debugprompt(@__MODULE__, Base.@locals)` at the point of use,
followed by `println()` once that call returns. `Base.@locals` collects the local
variables visible where the macro is used.

NOTE(review): `debugprompt` is not defined in this snippet and is not part of the
`REPL` imports shown; it must be defined elsewhere in the same module.
"""
macro keyboard()
    # Build the expression separately so the returned value is explicit.
    expr = quote
        debugprompt(@__MODULE__, Base.@locals)
        println()
    end
    return expr
end
@willtebbutt
willtebbutt / toy_chainrules_rmad.jl
Last active April 16, 2019 20:46
Toy tape-based reverse-mode AD with minimal Cassette usage.
#
# This uses the Nabla.jl-style interception mechanism whereby
# we wrap things that are to be differentiated w.r.t. in a
# thin wrapper. There are lots of things that you can't
# propagate derivative information through with this kind of
# approach without quite a lot of extra machinery, but the
# examples at the bottom do work.
#
using ChainRules, Cassette
# A convolution layer paired with a cache that can be replaced in place.
# The struct itself is immutable; the cache lives behind a `Ref` so its contents
# can be swapped via `m.cache[] = ...`.
# NOTE(review): `Conv` is presumably Flux's convolution layer (this file uses
# `Flux.@treelike`). The call method compares `size(m.cache[][2], 4)` to the
# input's 4th dimension (batch size) and clears the cache when they differ, so
# the tuple's 2nd element is expected to be a 4-D array. Confirm the layout there.
struct CachedConv
    conv::Conv          # the wrapped convolution layer
    cache::Ref{Tuple}   # mutable slot holding cached values; starts as an empty tuple
end
# Convenience constructor: wrap a bare `Conv` with an initially empty cache.
# The empty tuple is boxed explicitly in a `Ref{Tuple}`, which is the field's
# declared type, rather than relying on the default constructor's implicit convert.
CachedConv(c::Conv) = CachedConv(c, Ref{Tuple}(()))
# Register the type with Flux so its fields are traversed as layer children.
Flux.@treelike CachedConv
function (m::CachedConv)(x::AbstractArray)
# Has the user changed batch size on us? If so, clear our cache and re-up!
if !isempty(m.cache[]) && size(m.cache[][2], 4) != size(x, 4)
@staticfloat
staticfloat / zygote_batch_norm.jl
Last active April 11, 2019 04:01
Zygote BatchNorm implementation
using Zygote, Statistics, Flux
# We modify (the implementation of) batchnorm to be more amenable to CPUs pretending to be TPUs.
struct ZygoteBatchNorm{F,V,W}
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ::W # moving std
ϵ::Float32
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@staticfloat
staticfloat / fix_nvidia
Created March 8, 2019 23:47
Fixer script for NVIDIA driver installs on DKMS-compatible systems
#!/usr/bin/env bash
# Print all arguments on one line in red, then reset terminal attributes.
# The arguments are joined with single spaces, as "$*" does.
#
# printf is used instead of echo so that a message that happens to be an echo
# option (e.g. "-n" or "-e") is printed literally instead of being swallowed.
# NOTE(review): if this script runs under `set -e`, a failing `tput` (e.g. TERM
# unset) would abort before the message is printed. Confirm that is acceptable.
function red_echo()
{
    tput setaf 1
    printf '%s\n' "$*"
    tput sgr0
}
if [[ "$*" == *--help* ]]; then
@rntz
rntz / Runtime.hs
Created February 14, 2019 21:38
A seminaïve, mildly optimizing compiler from modal Datafun to Haskell, in OCaml.
-- The Datafun runtime.
module Runtime where
import qualified Data.Set as Set
import Data.Set (Set)
-- | Types equipped with an ordering relation '(<:)', with 'Eq' as a superclass.
-- NOTE(review): the name suggests a preorder (reflexive and transitive), but no
-- laws are stated or checked here; confirm what instances must satisfy.
class Eq a => Preord a where
  -- | @x <: y@ holds when @x@ is below (or equal to) @y@ in the ordering.
  (<:) :: a -> a -> Bool
class Preord a => Semilat a where
from functools import partial
import numpy.random as npr
import jax.numpy as np
from jax import lax
from jax import grad, pjit, papply
### set up some synthetic data
@maxbennedich
maxbennedich / sparse_show.jl
Last active January 16, 2019 14:28
Braille plot sparse show
# Code points (as `Int`) of the braille cells used by the plot. `BRAILLE[1]` is
# the fill value for an empty cell. Entries 2:9 are the single-dot cells
# U+2801, U+2802, U+2804, U+2840, U+2808, U+2810, U+2820, U+2880.
const BRAILLE = [Int(ch) for ch in "⠀⠁⠂⠄⡀⠈⠐⠠⢀"]
function show_any_nonzero(S::SparseMatrixCSC; maxw = displaysize(stdout)[2], maxh = displaysize(stdout)[1]-3)
h,w = size(S)
h > 4maxh && (w = max(1, (w*4maxh+h÷2)÷h); h = 4maxh)
w > 2maxw && (h = max(1, (h*2maxw+w÷2)÷w); w = 2maxw)
P = fill(BRAILLE[1], (w+3)÷2, (h+3)÷4)
P[end, :].=10
@inbounds for c = 0:w-1, r = 0:h-1
_anynz(S, r*S.m÷h+1, c*S.n÷w+1, (r+1)*S.m÷h, (c+1)*S.n÷w) &&
@pfitzseb
pfitzseb / TraceCalls.jl
Created December 17, 2018 15:10
TraceCalls.jl with Cassette
module TraceCalls
using Cassette
# Mutable tracing state holding two integer counters.
# NOTE(review): the use of these fields is not visible in this snippet.
# Presumably `level` is the current call-nesting depth and `cutoff` the maximum
# depth to record; confirm against the code that reads them.
mutable struct Trace
    level::Int    # presumably current nesting depth (TODO confirm)
    cutoff::Int   # presumably depth limit for tracing (TODO confirm)
end
# Define a Cassette context type named `TraceCtx`. Behavior is attached to it by
# overloading Cassette hooks (e.g. overdub/prehook/posthook) for this context.
Cassette.@context TraceCtx