@niklasschmitz
niklasschmitz / install-CUDA-docker-nvidia-docker.sh
Last active December 2, 2018 12:31 — forked from dte/install-CUDA-docker-nvidia-docker.sh
Install CUDA, Docker, and Nvidia Docker on a new Paperspace GPU machine with Ubuntu 18.04
#!/bin/bash
# 1. Install CUDA
echo "Installing CUDA..."
# Only install if CUDA is not already installed.
if ! sudo dpkg-query -W cuda; then
  # The 18.04 installer
  sudo curl -O http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-repo-ubuntu1804_10.0.130-1_amd64.deb
  sudo dpkg -i ./cuda-repo-ubuntu1804_10.0.130-1_amd64.deb
  sudo apt-get update
  sudo apt-get install cuda -y
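The `if ! dpkg-query -W cuda` guard makes the script idempotent: the install steps run only when dpkg does not know the `cuda` package. A minimal Python sketch of the same check follows, assuming a Debian/Ubuntu host with `dpkg-query` on the PATH (elsewhere the package is simply reported as missing). It additionally inspects the status field, so a removed-but-not-purged package counts as missing. The helper name `deb_installed` is hypothetical.

```python
import subprocess


def deb_installed(package: str) -> bool:
    """Return True if dpkg reports `package` as fully installed.

    Mirrors the script's `dpkg-query -W cuda` guard, but also checks the
    status field, so packages that were removed (config files left behind)
    are treated as not installed.
    """
    try:
        result = subprocess.run(
            ["dpkg-query", "-W", "--showformat=${Status}", package],
            capture_output=True,
            text=True,
            check=False,
        )
    except FileNotFoundError:
        # Not a dpkg-based system: treat the package as not installed.
        return False
    return result.returncode == 0 and result.stdout.strip() == "install ok installed"


if __name__ == "__main__":
    if deb_installed("cuda"):
        print("CUDA already installed, skipping.")
    else:
        print("CUDA not found; the install steps would run here.")
```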
niklasschmitz / jaxpr_graph.py
Last active June 24, 2024 17:53 — forked from mattjj/grad_graph.py
visualizing jaxprs
import jax
from jax import core
from graphviz import Digraph
import itertools
styles = {
    'const': dict(style='filled', color='goldenrod1'),
    'invar': dict(color='mediumspringgreen', style='filled'),
    'outvar': dict(style='filled,dashed', fillcolor='indianred1', color='black'),
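The gist colors constants, inputs and outputs with these `styles` and draws the graph with `graphviz.Digraph`. As a dependency-free sketch of the same idea, the snippet below emits Graphviz DOT text directly from a hypothetical toy list of equations `(primitive, inputs, outputs)`. With JAX installed, comparable data lives in `jax.make_jaxpr(f)(x).jaxpr` (its `invars`, `outvars` and `eqns`). The `to_dot` helper, the extra `op`/`intermediate` styles and the toy graph are illustrative assumptions, not the gist's code.

```python
# Node styles: the first three copy the gist; "op" and "intermediate" are assumed.
STYLES = {
    "const": dict(style="filled", color="goldenrod1"),
    "invar": dict(color="mediumspringgreen", style="filled"),
    "outvar": dict(style="filled,dashed", fillcolor="indianred1", color="black"),
    "op": dict(shape="box", color="lightskyblue", style="filled"),
    "intermediate": dict(style="filled", color="cornflowerblue"),
}


def _attrs(kind):
    return ", ".join(f'{k}="{v}"' for k, v in STYLES[kind].items())


def to_dot(invars, outvars, eqns, consts=()):
    """Render a jaxpr-like program as Graphviz DOT source.

    invars/outvars/consts: lists of variable names.
    eqns: list of (primitive_name, input_names, output_names).
    """
    lines = ["digraph jaxpr {", "  rankdir=LR;"]
    declared = set()

    def var_node(name):
        # Declare each variable node once, styled by its role.
        if name in declared:
            return
        declared.add(name)
        if name in outvars:
            kind = "outvar"
        elif name in invars:
            kind = "invar"
        elif name in consts:
            kind = "const"
        else:
            kind = "intermediate"
        lines.append(f'  "{name}" [label="{name}", {_attrs(kind)}];')

    for i, (prim, ins, outs) in enumerate(eqns):
        op_id = f"op{i}"
        lines.append(f'  "{op_id}" [label="{prim}", {_attrs("op")}];')
        for name in ins:
            var_node(name)
            lines.append(f'  "{name}" -> "{op_id}";')
        for name in outs:
            var_node(name)
            lines.append(f'  "{op_id}" -> "{name}";')
    for name in list(invars) + list(outvars):
        var_node(name)  # unused inputs/outputs still get a node
    lines.append("}")
    return "\n".join(lines)


# Toy program c = sin(a) * k, with k treated as a captured constant.
dot_src = to_dot(
    invars=["a"],
    outvars=["c"],
    consts=["k"],
    eqns=[("sin", ["a"], ["b"]), ("mul", ["b", "k"], ["c"])],
)
```

The resulting string can be saved as a `.dot` file and rendered with the `dot` command-line tool.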
niklasschmitz / kernelregression.dx
Last active August 24, 2021 17:52
Kernel regression in dex-lang
import plot
-- Conjugate gradients solver
def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float =
  x0 = for i:m. 0.0
  ax = mat **. x0
  r0 = b - ax
  (xOut, _, _) = fold (x0, r0, r0) $
    \s:m (x, r, p).
      ap = mat **. p
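The Dex `solve` above is a conjugate-gradient loop: it starts from `x0 = 0`, sets the residual `r0 = b - A x0`, and folds over the index set `m` while carrying the iterate `x`, residual `r` and search direction `p`. For comparison, here is a minimal pure-Python sketch of the standard CG iteration for a symmetric positive-definite matrix stored as dense lists. It follows the same structure (zero start, at most `n` steps) but is not a transcription of the truncated gist; the helper names are hypothetical.

```python
def matvec(A, v):
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cg_solve(A, b, iters=None, tol=1e-12):
    """Conjugate gradients for symmetric positive-definite A (dense lists)."""
    n = len(b)
    iters = n if iters is None else iters  # n steps suffice in exact arithmetic
    x = [0.0] * n
    r = [bi - axi for bi, axi in zip(b, matvec(A, x))]  # r0 = b - A x0
    p = list(r)
    rr = dot(r, r)
    for _ in range(iters):
        if rr <= tol * tol:
            break  # converged; also avoids dividing by zero below
        Ap = matvec(A, p)
        alpha = rr / dot(p, Ap)  # step length along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rr_new = dot(r, r)
        beta = rr_new / rr  # makes the next direction A-conjugate to p
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rr = rr_new
    return x


A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x = cg_solve(A, b)  # exact solution is [1/11, 7/11]
```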
niklasschmitz / nlsolve_rrule_gmres.jl
Last active February 9, 2022 07:32
NLsolve ChainRules implicit differentiation
using NLsolve
using Zygote
using ChainRulesCore
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra
using BenchmarkTools
using Random
Random.seed!(1234)
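The idea behind this gist is implicit differentiation: if `x*(p)` solves `f(x, p) = 0`, the implicit function theorem gives `dx*/dp = -(∂f/∂x)⁻¹ ∂f/∂p` at the solution. The derivative therefore needs one linear solve at `x*` instead of differentiating through the solver's iterations; judging by the file name, the gist performs that solve with GMRES. Below is a minimal scalar sketch in Python. It assumes Newton's method as the forward solver and uses the toy residual `f(x, p) = x**2 - p`, chosen here for illustration because its root `sqrt(p)` has a known derivative.

```python
def newton_solve(f, dfdx, p, x0=1.0, tol=1e-12, max_iter=100):
    """Solve f(x, p) = 0 for x with Newton's method (scalar case)."""
    x = x0
    for _ in range(max_iter):
        step = f(x, p) / dfdx(x, p)
        x -= step
        if abs(step) < tol:
            return x
    raise RuntimeError("Newton did not converge")


def implicit_grad(dfdx, dfdp, x_star, p):
    """dx*/dp = -(df/dx)^{-1} df/dp, evaluated at the solution x*.

    In the vector case the division becomes a linear solve with the
    Jacobian df/dx (for example by GMRES); here everything is scalar.
    """
    return -dfdp(x_star, p) / dfdx(x_star, p)


# Toy problem: f(x, p) = x**2 - p, whose positive root is x* = sqrt(p).
def f(x, p):
    return x**2 - p


def dfdx(x, p):
    return 2.0 * x


def dfdp(x, p):
    return -1.0


p = 4.0
x_star = newton_solve(f, dfdx, p)
g = implicit_grad(dfdx, dfdp, x_star, p)  # analytically 1 / (2 * sqrt(4)) = 0.25
```

The gradient never looks at the Newton iterates, which is what makes this approach cheap in memory compared with unrolling the solver.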
niklasschmitz / torch2jax_dlpack_bug.py
Last active September 30, 2021 13:28
JAX - PyTorch dlpack conversion bug
import torch
import jax
import jax.dlpack
import torch.utils.dlpack
from jax.config import config
config.update("jax_enable_x64", True)
def jax2torch(x):
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
niklasschmitz / rrule_from_frule.jl
Last active October 2, 2021 12:56
Getting an rrule from an frule in ChainRules.jl
# rrule from frule (transposition)
using Zygote
using ChainRulesCore
using LinearAlgebra
function f(x)
    a = sin.(x)
    b = sum(a)
    c = b * a
    return c
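The gist's point is that a reverse rule can be obtained from a forward rule by transposition. The pushforward `ẋ ↦ J ẋ` is linear, and the pullback is its transpose `ȳ ↦ Jᵀ ȳ`. The brute-force pure-Python sketch below applies this to the `f` above, which computes `c = sum(sin.(x)) * sin.(x)`. It hand-writes the forward rule, assembles `J` column by column by pushing forward basis vectors, and then transposes. The cost is one forward-rule call per input, so this only illustrates the linear-algebra identity and is not the gist's ChainRules.jl implementation.

```python
import math


def f(x):
    a = [math.sin(xi) for xi in x]
    b = sum(a)
    return [b * ai for ai in a]


def frule(x, xdot):
    """Forward rule: return (f(x), J(x) @ xdot) by propagating tangents."""
    a = [math.sin(xi) for xi in x]
    adot = [math.cos(xi) * dxi for xi, dxi in zip(x, xdot)]
    b = sum(a)
    bdot = sum(adot)
    c = [b * ai for ai in a]
    cdot = [bdot * ai + b * dai for ai, dai in zip(a, adot)]  # product rule
    return c, cdot


def rrule_from_frule(x):
    """Build a pullback by transposing the Jacobian assembled from the frule."""
    n = len(x)
    y = f(x)
    columns = []
    for j in range(n):
        e_j = [1.0 if i == j else 0.0 for i in range(n)]
        _, col = frule(x, e_j)  # col = J @ e_j, the j-th Jacobian column
        columns.append(col)

    def pullback(ybar):
        # (J^T ybar)_j is the inner product of column j with ybar.
        return [sum(c * yb for c, yb in zip(col, ybar)) for col in columns]

    return y, pullback
```

A quick consistency check is the adjoint identity `⟨ȳ, J ẋ⟩ = ⟨Jᵀ ȳ, ẋ⟩` for arbitrary `ẋ` and `ȳ`.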
niklasschmitz / dither.dx
Created August 22, 2022 16:56
Dither in dex-lang feat. @adrhill
import plot
import png
key = new_key 1234
Height = Fin 100
Width = Fin 300
img_ = for i:Height. for j:Width. n_to_f ((ordinal i) + (ordinal j))
img = img_ / (n_to_f (size Height + size Width))
:html matshow img
:t img
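The preview stops before the dithering step itself. It builds a diagonal gradient `img[i, j] = (i + j) / (H + W)` on a 100 × 300 grid and creates a random key, which suggests random (noise) dithering, but that is an assumption. The pure-Python sketch below applies random threshold dithering to the same gradient: each pixel becomes 1 with probability equal to its intensity, so local averages of the binary image approximate the original gray levels. The function names and seeding via `random.Random(1234)` are illustrative choices, not the gist's code.

```python
import random

HEIGHT, WIDTH = 100, 300


def gradient_image(height, width):
    """Diagonal ramp matching the Dex preview: (i + j) / (height + width)."""
    return [[(i + j) / (height + width) for j in range(width)] for i in range(height)]


def random_dither(img, seed=1234):
    """Threshold each pixel against fresh uniform noise (random dithering).

    A pixel with intensity v in [0, 1] becomes 1 with probability v.
    """
    rng = random.Random(seed)
    return [[1 if v > rng.random() else 0 for v in row] for row in img]


def mean(img):
    return sum(map(sum, img)) / (len(img) * len(img[0]))


img = gradient_image(HEIGHT, WIDTH)
bw = random_dither(img)
```

Random thresholds are the simplest choice; error-diffusion methods such as Floyd–Steinberg produce less grainy results at the cost of a sequential scan.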
niklasschmitz / camera.jl
Last active September 10, 2022 10:32
Test image camera.ppm
using Images
# CC0 license: https://commons.wikimedia.org/wiki/File:Wikicommons_kamera_rollei35_rrt877.jpg
file = download("https://upload.wikimedia.org/wikipedia/commons/thumb/3/38/Wikicommons_kamera_rollei35_rrt877.jpg/640px-Wikicommons_kamera_rollei35_rrt877.jpg")
img = load(file)
open("camera.ppm", "w+") do io
    println(io, "P6")
    println(io, size(img, 2), " ", size(img, 1))
    println(io, 255)
    for row in eachrow(img)
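The Julia snippet writes a binary PPM by hand. The header consists of the magic number `P6`, then the width and height (`size(img, 2)` is the width, `size(img, 1)` the height), then the maximum channel value 255. The P6 format then expects raw RGB bytes, row by row. A minimal Python sketch of the same layout follows, assuming the image is a list of rows of `(r, g, b)` tuples with integer channels in 0..255. `write_ppm` and `read_ppm_header` are hypothetical helpers, and the reader only handles the comment-free header written here.

```python
import io


def write_ppm(stream, pixels):
    """Write RGB pixels (rows of (r, g, b) ints in 0..255) as a binary P6 PPM."""
    height = len(pixels)
    width = len(pixels[0]) if height else 0
    stream.write(f"P6\n{width} {height}\n255\n".encode("ascii"))
    for row in pixels:  # rows top to bottom, pixels left to right
        if len(row) != width:
            raise ValueError("all rows must have the same length")
        # bytes() raises ValueError for channel values outside 0..255
        stream.write(bytes(channel for rgb in row for channel in rgb))


def read_ppm_header(data):
    """Parse the simple header above; return (width, height, maxval, pixel_bytes)."""
    magic, dims, maxval, rest = data.split(b"\n", 3)  # pixel bytes stay intact
    if magic != b"P6":
        raise ValueError("not a binary PPM")
    width, height = map(int, dims.split())
    return width, height, int(maxval), rest


pixels = [[(255, 0, 0), (0, 255, 0)], [(0, 0, 255), (255, 255, 255)]]
buf = io.BytesIO()
write_ppm(buf, pixels)  # for a file: with open("camera.ppm", "wb") as fh: write_ppm(fh, pixels)
w, h, maxval, raw = read_ppm_header(buf.getvalue())
```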
niklasschmitz / heatmap.dx
Last active February 8, 2023 09:39
Code snippet for colormaps by linear interpolation and 2D heatmaps in dex-lang
import png
' ## Colormaps
Here we model colormaps as sequences of RGB colors, which are then
piecewise linearly interpolated over the interval [0,1] upon evaluation.
RGB = (Fin 3)=>Float
cmap_viridis = [[253.0, 231.0, 37.0]
,[ 94.0, 201.0, 98.0]
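The Dex snippet represents a colormap as a list of RGB stops that are interpolated piecewise linearly over [0, 1]. Below is a minimal Python sketch of that evaluation plus a simple heatmap. It assumes equally spaced stops (stop `k` sits at `t = k / (n - 1)`), which the preview does not state explicitly, and min-max normalization of the data, which is one common choice assumed here. The usage line relies only on the two `cmap_viridis` stops visible in the preview, because the rest of the list is truncated. `apply_cmap` and `heatmap` are hypothetical names.

```python
def apply_cmap(cmap, t):
    """Piecewise-linear interpolation of RGB stops at t in [0, 1].

    Assumes the n stops are equally spaced; t outside [0, 1] is clamped.
    """
    if not cmap:
        raise ValueError("colormap needs at least one stop")
    if len(cmap) == 1:
        return list(cmap[0])
    t = min(max(t, 0.0), 1.0)
    pos = t * (len(cmap) - 1)         # fractional index into the stop list
    k = min(int(pos), len(cmap) - 2)  # left stop of the active segment
    w = pos - k                       # weight of the right stop
    return [(1 - w) * left + w * right for left, right in zip(cmap[k], cmap[k + 1])]


def heatmap(values, cmap):
    """Map a 2D grid of scalars to RGB: min-max normalize, then apply cmap."""
    flat = [v for row in values for v in row]
    vmin, vmax = min(flat), max(flat)
    span = (vmax - vmin) or 1.0  # avoid division by zero on constant input
    return [[apply_cmap(cmap, (v - vmin) / span) for v in row] for row in values]


# Only the first two stops of cmap_viridis are visible in the preview.
two_stops = [[253.0, 231.0, 37.0], [94.0, 201.0, 98.0]]
mid = apply_cmap(two_stops, 0.5)  # halfway between the two stops
```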