@daskol
daskol / sync-multiple.py
Last active April 19, 2022 16:00
Benchmark parallel reading of large files in JAX/FLAX
#!/usr/bin/env python
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from hashlib import sha256
from os import system
from time import monotonic
from tqdm import tqdm
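The preview above stops at the imports. As a minimal sketch of the idea (not the author's benchmark), one can read several files concurrently with `ThreadPoolExecutor`, hash each one with `sha256` so that every byte is actually consumed, and time the batch with `monotonic`. Threads help here because, in CPython, file reads and large `sha256.update()` calls release the GIL. The chunk size, worker count, and the helper names `hash_file` and `hash_files` are assumptions for illustration.

```python
from concurrent.futures import ThreadPoolExecutor
from hashlib import sha256
from time import monotonic
from typing import List, Tuple

CHUNK_SIZE = 16 * 1024 ** 2  # Assumed read size: 16 MiB per read() call.


def hash_file(path: str, chunk_size: int = CHUNK_SIZE) -> str:
    """Read a file sequentially in chunks and return its SHA-256 hex digest."""
    digest = sha256()
    with open(path, 'rb') as fin:
        # read() returns b'' at end of file, which ends the loop.
        while chunk := fin.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()


def hash_files(paths: List[str], num_workers: int = 4) -> Tuple[List[str], float]:
    """Hash files concurrently; return digests in input order and elapsed seconds."""
    started = monotonic()
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # pool.map() yields results in the same order as `paths`.
        digests = list(pool.map(hash_file, paths))
    return digests, monotonic() - started
```

Repeated runs may be served from the OS page cache, so measuring cold reads requires dropping the cache between runs.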
daskol / jax-flax-common.py
Created April 14, 2022 15:15
Common routines for deep learning with JAX/FLAX.
"""Module common defines common routines and training utilities for the JAX/Flax
environment.
"""
import logging
import flax
import jax
import jax.numpy as jnp
daskol / bench-entropy.py
Created February 20, 2022 21:21
A naive implementation of cross-entropy beats the library one (flax/optax)
"""This script benchmarks the default implementation of cross-entropy
(in flax/optax) against a naive one in plain JAX. Run the script with the
commands below.
$ mv bench-entropy.{py,ipy}
$ ipython bench-entropy.ipy
naive: 63.6 µs ± 4.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
optax: 67.3 µs ± 3.98 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
The naive implementation is slightly faster on an Nvidia V100 as well as on a consumer CPU.
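For reference, a "naive" cross-entropy is a numerically stable log-softmax followed by a weighted sum with the one-hot (or soft) labels. The sketch below is written in NumPy rather than JAX so it runs without an accelerator stack; porting it amounts to replacing `np` with `jnp`. The function name `cross_entropy_naive` is hypothetical and this is not the gist's code.

```python
import numpy as np


def cross_entropy_naive(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Per-example cross-entropy between logits and label distributions.

    Both arrays have shape (batch, classes); labels are one-hot or soft.
    """
    # Subtract the row-wise max before exponentiating to avoid overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    # log_softmax(x) = x - log(sum(exp(x))), computed on the shifted logits.
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(labels * log_probs).sum(axis=-1)
```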
daskol / git-credential-pass
Last active January 22, 2022 22:07
Simple git credential helper in Python
#!/usr/bin/env python3
"""This simple script is a git credential helper that extracts passwords from
the password store (pass).
"""
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from subprocess import check_output
from sys import stdin, stdout
from typing import Optional, TextIO
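Git talks to a credential helper over a simple text protocol: it runs the helper with an action (`get`, `store` or `erase`) and writes `key=value` lines to its stdin, ending with a blank line or EOF; for `get`, the helper answers with lines such as `password=...`. A minimal sketch of such a helper follows, assuming a hypothetical store layout with one `pass` entry per host (e.g. `github.com`) and relying on the pass convention that the first line of an entry is the password. The names `read_request`, `pass_show` and `handle` are illustrative, not the gist's.

```python
import subprocess
from typing import Callable, Dict, TextIO


def read_request(fin: TextIO) -> Dict[str, str]:
    """Parse git's key=value lines up to a blank line or EOF."""
    attrs = {}
    for line in fin:
        line = line.rstrip('\n')
        if not line:
            break
        # Split on the first '=' only: values may themselves contain '='.
        key, _, value = line.partition('=')
        attrs[key] = value
    return attrs


def pass_show(entry: str) -> str:
    """Return the first line of a pass entry (the password, by convention)."""
    output = subprocess.check_output(['pass', 'show', entry], text=True)
    return output.split('\n', 1)[0]


def handle(action: str, fin: TextIO, fout: TextIO,
           lookup: Callable[[str], str] = pass_show) -> None:
    # Always consume the request so git can finish writing to our stdin.
    attrs = read_request(fin)
    if action != 'get':
        return  # Read-only sketch: 'store' and 'erase' are ignored.
    host = attrs.get('host')
    if not host:
        return  # Nothing to look up; git falls back to other helpers/prompts.
    # Hypothetical layout: the pass entry name equals the host name.
    fout.write(f'password={lookup(host)}\n')
```

Installed as an executable named `git-credential-pass` on `PATH` whose entry point calls `handle(sys.argv[1], sys.stdin, sys.stdout)`, it would be picked up with `git config credential.helper pass`, since git prefixes plain helper names with `git credential-`.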
daskol / tensorboard2pandas.py
Created January 9, 2022 10:12
Load TensorBoard log files into Pandas.
#!/usr/bin/env python3
"""Simple script to extract metrics calculated by HuggingFace transformers from
TensorBoard logs.
"""
import pandas as pd
import tensorboard as tb
import tensorboard.data_compat
import tensorflow as tf
daskol / tree_util.py
Last active January 1, 2022 19:11
JAX-like routines for module transformations in PyTorch (see jax.tree_util package).
"""Module tree_util implements routines for in-place transformations on PyTorch
modules as trees. It provides a JAX-like API (see jax.tree_util package).
>>> from transformers import RobertaModel
>>> model = RobertaModel.from_pretrained('roberta-base')
>>> converted = map_module(module=model,
...                        func=lambda x, _: convert_linear(x, MyFancyLinear))
"""
import re
daskol / visualise-backward-pass.py
Last active December 3, 2021 20:36
Visualization of backward pass graph for simple model in PyTorch
#!/usr/bin/env python3
# Run this script and then the command below.
#
# dot -Tpng -ograph.png graph.dot
#
import torch as T
from torchviz import make_dot
model = T.nn.Sequential(
daskol / pytorch-graph-summary.py
Created December 3, 2021 15:49
Add PyTorch graph summary to TensorBoard.
import torch as T
model = T.nn.Sequential()
model.add_module('W0', T.nn.Linear(128, 10))
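model.add_module('tanh0', T.nn.Tanh())
model.add_module('W1', T.nn.Linear(10, 5))
# Module names must be unique: reusing 'tanh' would replace the first
# activation in place instead of appending a second one.
model.add_module('tanh1', T.nn.Tanh())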
model.add_module('W2', T.nn.Linear(5, 1))
input = T.randn(16, 128)
daskol / jax-free.py
Created August 12, 2021 14:26
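Free JAX/XLA buffers whose size exceeds a threshold
import jax


def collect(threshold=256 * 1024 ** 2):
    """Delete live device buffers of at least `threshold` bytes (256 MiB by default)."""
    backend = jax.lib.xla_bridge.get_backend()
    freed = 0
    for buf in backend.live_buffers():
        if buf.nbytes >= threshold:
            # Count the bytes before the buffer is released.
            freed += buf.nbytes
            buf.delete()
    return freed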
daskol / wait4pid.py
Last active June 17, 2021 15:47
Event multiplexing on process identifiers (PIDs)
#!/usr/bin/env python3
"""This script demonstrates event multiplexing on process identifiers (PIDs).
More specifically, we obtain a file descriptor (FD) for a PID and then wait
for events on that descriptor. Polling such events is possible because the
pidfd_open() system call returns an FD that refers to the process. This system
call was introduced in Linux 5.3 (September 2019). A usage example is below.
$ ./wait4pid.py 1337 &
$ kill -9 1337
process with pid 1337 exited