Skip to content

Instantly share code, notes, and snippets.

@enijkamp
enijkamp / nan.py
Created July 12, 2022 21:46
nan.py
# Build a function that returns both the value of `train_apply_f` and its
# gradient. `has_aux=False`: the function's output is treated as the value
# itself, not a (value, aux) pair. `allow_int=True`: integer-valued inputs
# may be differentiated with respect to.
value_and_grad_f = jax.value_and_grad(
    train_apply_f,
    has_aux=False,
    allow_int=True,
)

# Zero-filled bfloat16 tree with the same structure as `params_bf16`.
# `jax.tree_util.tree_map` is used instead of the deprecated top-level
# `jax.tree_map` alias, which newer JAX releases no longer provide.
grad_init = jax.tree_util.tree_map(
    lambda leaf: jnp.zeros_like(leaf).astype(jnp.bfloat16),
    params_bf16,
)
def scan(f, init, xs):
    """Run ``f`` over a pair of sequences with an explicit Python loop.

    NOTE(review): presumably a step-by-step debugging stand-in for
    ``jax.lax.scan`` -- confirm against the caller.

    Args:
        f: Callable ``f(carry, (x0, x1)) -> (new_carry, y)``.
        init: Initial carry value passed to the first call of ``f``.
        xs: Indexable holding two sequences, ``xs[0]`` and ``xs[1]``. The
            number of steps is ``xs[0].shape[0]``; both are indexed with the
            same step index.

    Returns:
        A ``(carry, ys)`` tuple: the carry after the last step and the
        per-step outputs combined with ``jnp.stack``.
    """
    first, second = xs[0], xs[1]
    carry = init
    ys = []
    # The step count is taken from the first sequence only.
    for i in range(first.shape[0]):
        carry, y = f(carry, (first[i], second[i]))
        ys.append(y)
    return carry, jnp.stack(ys)
@enijkamp
enijkamp / leave_names.py
Created December 12, 2021 19:04
leave_names.py
def tree_flatten_with_names(pytree, is_leaf, path="", to_id=id):
id_to_name = {}
if getattr(pytree, "items", None):
for k, v in pytree.items():
k_path = f"{path}/{k}"
if is_leaf(v):
id_to_name[to_id(v)] = k_path
else:
id_to_name = {
**id_to_name,
@enijkamp
enijkamp / resharding.py
Last active August 2, 2021 00:25
resharding.py
'''
python3.8 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
pip install --upgrade jax==0.2.12 jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
'''
import os
# xla
def apply_reshard(pytree_params_in, pytree_params_out, shards_in, shards_out):
def override_dtype(x):
if x.dtype == np.dtype('V2'):
x.dtype = jnp.bfloat16
return x
def is_leaf(x):
return type(x) == np.ndarray
@enijkamp
enijkamp / bpe_ratio.py
Created July 14, 2021 21:09
bpe_ratio.py
import os
import io
import tempfile
import tensorflow as tf
import transformers
def write_to_file(writer, data):
    """Serialize ``data`` as one ``tf.train.Example`` and write it out.

    The example carries a single feature, keyed ``'text'``, that stores
    ``data`` as an int64 list.

    Args:
        writer: Object exposing ``write(bytes)``; presumably a TFRecord
            writer -- confirm against the caller.
        data: Iterable of integers to store; presumably token ids --
            confirm against the caller.
    """
    int64_values = tf.train.Int64List(value=data)
    feature = {'text': tf.train.Feature(int64_list=int64_values)}
    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(tf_example.SerializeToString())
@enijkamp
enijkamp / create_tf_records.py
Created July 13, 2021 07:25
create_tf_records.py
import sys
import os
import gzip
import json
import io
import argparse
import concurrent.futures
from tokenizers import Tokenizer
@enijkamp
enijkamp / tfrecordresumableloader.py
Created June 28, 2021 02:45
tfrecordresumableloader.py
import argparse
import numpy as np
import tensorflow as tf
class TFRecordResumableLoader:
def __init__(self, files, batch_size, batch_prefetch, parse_fn, map_fn=lambda x: x):
self.files = files
self.batch_size = batch_size
@enijkamp
enijkamp / clean.sh
Created April 26, 2020 09:26
How to clean up unused LaTeX images
vim /usr/local/texlive/2016/texmf.cnf
max_print_line=2000
error_line=254
half_error_line=238
#!/bin/bash
for image_file in $(find Part_1/Chpt_9_Generator_and_Descriptor/figures -type f)
{"input": [{"bbox": [206, 642, 206, 654], "text": "necessary", "id": 49}, {"bbox": [641, 175, 641, 185], "text": "AP", "id": 9}, {"bbox": [194, 554, 194, 564], "text": "suggest", "id": 37}, {"bbox": [10, 396, 10, 410], "text": "Drug", "id": 22}, {"bbox": [606, 106, 606, 117], "text": "Matthew", "id": 2}, {"bbox": [246, 551, 246, 562], "text": "factor.", "id": 38}, {"bbox": [73, 343, 73, 356], "text": "Arthurfort,", "id": 19}, {"bbox": [31, 275, 31, 286], "text": "Walter", "id": 11}, {"bbox": [29, 551, 29, 562], "text": "93.77", "id": 35}, {"bbox": [638, 141, 638, 151], "text": "4831", "id": 5}, {"bbox": [153, 507, 153, 518], "text": "Character", "id": 29}, {"bbox": [338, 639, 338, 651], "text": "break.", "id": 51}, {"bbox": [278, 595, 278, 607], "text": "staff.", "id": 43}, {"bbox": [106, 309, 106, 320], "text": "Knolls", "id": 15}, {"bbox": [46, 395, 46, 406], "text": "health", "id": 23}, {"bbox": [30, 639, 31, 652], "text": "720.8", "id": 46}, {"bbox": [617, 639, 617, 650], "text": "9685", "id": 47}, {"bbox
{"input": [{"x": 71, "y": 10, "width": 5, "id": 191}, {"x": 6, "y": 0, "width": 5, "id": 127}, {"x": 38, "y": 12, "width": 4, "id": 145}, {"x": 51, "y": 2, "width": 3, "id": 152}, {"x": 20, "y": 4, "width": 2, "id": 6}, {"x": 10, "y": 12, "width": 4, "id": 90}, {"x": 19, "y": 8, "width": 4, "id": 222}, {"x": 0, "y": 10, "width": 5, "id": 107}, {"x": 20, "y": 12, "width": 3, "id": 20}, {"x": 45, "y": 8, "width": 4, "id": 29}, {"x": 15, "y": 12, "width": 2, "id": 86}, {"x": 56, "y": 10, "width": 4, "id": 252}, {"x": 61, "y": 4, "width": 3, "id": 174}, {"x": 25, "y": 2, "width": 3, "id": 69}, {"x": 35, "y": 8, "width": 3, "id": 193}, {"x": 12, "y": 0, "width": 2, "id": 221}, {"x": 18, "y": 0, "width": 2, "id": 239}, {"x": 38, "y": 10, "width": 2, "id": 172}, {"x": 6, "y": 10, "width": 4, "id": 189}, {"x": 46, "y": 12, "width": 5, "id": 183}, {"x": 11, "y": 8, "width": 3, "id": 196}, {"x": 15, "y": 8, "width": 3, "id": 134}, {"x": 57, "y": 2, "width": 2, "id": 45}, {"x": 35, "y": 10, "width": 2, "id": 223}, {"x":