@rejuvyesh
rejuvyesh / lru.py
Created April 13, 2023 02:45 — forked from Ryu1845/lru.py
Linear Recurrent Unit (LRU) from the paper ["Resurrecting Recurrent Neural Networks for Long Sequences"](https://arxiv.org/abs/2303.06349)
"""
Simplified Implementation of the Linear Recurrent Unit
------------------------------------------------------
We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU).
The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$
according to the following formula (and efficiently parallelized using an associative scan):
$x_{k} = \Lambda x_{k-1} +\exp(\gamma^{\log})\odot (B u_{k})$,
and the output is computed at each time step $k$ as follows: $y_k = C x_k + D u_k$.
In our code, $B,C$ follow Glorot initialization, with $B$ additionally scaled by a factor of 2
to account for the halving of the state variance caused by taking the real part of the output projection.
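
The recurrence described above can be illustrated with a minimal sketch. It is written in plain NumPy with a sequential loop, not the JAX associative scan that the gist refers to, and it is not the gist's own code. Several things here are assumptions made for illustration: the function name `lru_forward`, treating $\Lambda$ as a diagonal matrix stored as the vector `lam`, using a dense matrix for `D`, the particular choice of `gamma_log`, and the simplified random scaling of `B` and `C` (which does not reproduce the Glorot scheme mentioned above).

```python
import numpy as np


def lru_forward(lam, gamma_log, B, C, D, u):
    """Run the LRU recurrence step by step over an input u of shape (L, H)."""
    x = np.zeros(lam.shape[0], dtype=np.complex128)  # assumed initial state x_0 = 0
    gamma = np.exp(gamma_log)
    ys = []
    for u_k in u:
        # x_k = Lambda x_{k-1} + exp(gamma_log) * (B u_k), with Lambda diagonal
        x = lam * x + gamma * (B @ u_k)
        # y_k = Re(C x_k) + D u_k (real part of the output projection)
        ys.append((C @ x).real + D @ u_k)
    return np.stack(ys)


rng = np.random.default_rng(0)
N, H, L = 4, 3, 5  # state size, model width, sequence length (hypothetical)

# Hypothetical parameters, chosen only so that |lambda| < 1 (stable recurrence).
lam = 0.9 * np.exp(1j * rng.uniform(0, np.pi, N))
gamma_log = np.log(np.sqrt(1 - np.abs(lam) ** 2))  # assumed normalization
B = (rng.normal(size=(N, H)) + 1j * rng.normal(size=(N, H))) / np.sqrt(2 * H)
C = (rng.normal(size=(H, N)) + 1j * rng.normal(size=(H, N))) / np.sqrt(N)
D = rng.normal(size=(H, H))
u = rng.normal(size=(L, H))

y = lru_forward(lam, gamma_log, B, C, D, u)
print(y.shape)  # (5, 3)
```

Because each step is linear in the state, the loop can in principle be replaced by a parallel scan over the sequence; the sequential form is used here only because it mirrors the formula term by term.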
using SimpleChains
function f(x)
    N = Base.isqrt(length(x))
    A = reshape(view(x, 1:N*N), (N,N))
    expA = exp(A)
    vec(expA)
end
T = Float32;
D = 2 # 2x2 matrices
# install tinycudann via
# pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
import torch
import tinycudann as tcnn
import time
class TCNNMatrixExponentEstimator1(torch.nn.Module):
    def __init__(self, hidden=16) -> None:
        super().__init__()
rejuvyesh / stresstest_dlpack.jl
Created February 7, 2022 22:12
Reproduce DLPack segfault
using PyCall
using DLPack
using Test
using Zygote
using ChainRulesCore
torch = pyimport("torch")
functorch = pyimport("functorch")
dlpack = pyimport("torch.utils.dlpack")
py"""
rejuvyesh / stresstest_jax_dlpack.jl
Last active February 9, 2022 17:32
Reproduce DLPack segfault on CUDA + JAX
using PyCall
using CUDA
using DLPack
using Test
#using Zygote
#using ChainRulesCore
@show DLPack.PYCALL_NOOP_DELETER
jax = pyimport("jax")
rejuvyesh / gcc-phat.m
Last active November 26, 2021 15:17
GCC-PHAT algorithm
for i = 1:3
    tau(i) = gcc_phat(x{i}(1,:), x{i}(2,:))
end
function tau = gcc_phat(sig1, sig2)
    % Find FFT for the signals
    fft1 = fft(sig1, fftSize(sig1));
    fft2 = fft(sig2, fftSize(sig2));
    % Find R(\tau)
    G12 = fft1.*conj(fft2);
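
The preview above stops right after forming the cross-spectrum `G12`. As a rough sketch of how a GCC-PHAT delay estimate is commonly completed from that point, here is a NumPy version (Python, not MATLAB, and not the gist's own code). The zero-padding length, the small constant guarding the division, and the sign convention (a positive result means `sig1` lags `sig2`) are assumptions made for this illustration.

```python
import numpy as np


def gcc_phat(sig1, sig2):
    """Estimate the delay of sig1 relative to sig2, in samples."""
    n = len(sig1) + len(sig2)  # zero-pad to avoid circular wrap-around
    fft1 = np.fft.rfft(sig1, n=n)
    fft2 = np.fft.rfft(sig2, n=n)
    g12 = fft1 * np.conj(fft2)  # cross-spectrum
    g12 /= np.abs(g12) + 1e-12  # PHAT weighting: keep phase, drop magnitude
    cc = np.fft.irfft(g12, n=n)  # generalized cross-correlation
    max_shift = n // 2
    # Reorder so the output covers lags -max_shift .. +max_shift
    cc = np.concatenate((cc[-max_shift:], cc[: max_shift + 1]))
    return int(np.argmax(cc)) - max_shift


rng = np.random.default_rng(0)
sig2 = rng.normal(size=256)
delay = 5  # hypothetical true delay in samples
sig1 = np.concatenate((np.zeros(delay), sig2[:-delay]))  # sig2 delayed by 5
tau = gcc_phat(sig1, sig2)
print(tau)
```

The PHAT weighting discards the magnitude of the cross-spectrum, so the correlation peak depends only on the phase difference between the two signals, which tends to sharpen the peak for broadband sources.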
using Flux
qdim = 2
nn = Chain(Dense(qdim, 32, tanh), Dense(32, 2));
q = rand(2, 5);
function jac(x)
    o = nn(x)
    return reduce(hcat, [o[:, i] for i in 1:size(x)[end]])
end
#!/usr/bin/env python3
#
# File: test_dircol.py
#
import numpy as np
import torch
from optimalcontrol.dircolproblem import DIRCOLProblem
from mechamodlearn import utils
using MuJoCo
modelfile = "test/humanoid.xml"
pm = mj_loadXML(modelfile) # Raw C pointer to mjModel
pd = mj_makeData(pm) # Raw C pointer to mjData
m, d = mj.mapmujoco(pm, pd) # wrap with our jlModel, jlData types
# we can manipulate data in the raw C structs now
nq = mj.get(m, :nq)
rejuvyesh / amc.py
Created December 18, 2015 22:10
CMU MoCap Handling
#!/usr/bin/env python
#
# File: amc.py
#
# Created: Friday, December 18 2015 by rejuvyesh <mail@rejuvyesh.com>
# License: GNU GPL 3 <http://www.gnu.org/copyleft/gpl.html>
#
from __future__ import print_function