
shawwn / libtpujesus.c
Created November 2, 2021 06:41
An example of building a custom "stub" libtpu.so library, with the ultimate goal of implementing your own "TPU" device for JAX.
/* libtpujesus.c
Copyright 2021 Shawn Presser
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
finetunej / to_hf_weights.py
Created August 17, 2021 12:24
For converting trained GPT-J checkpoints into the PyTorch Hugging Face format.
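Converting a checkpoint trained with model parallelism usually means merging each parameter's per-device shards back into one tensor, or re-splitting it into a different shard count, which is what the `shards_in`/`shards_out` arguments of this script's `apply_reshard` suggest. A minimal NumPy sketch, assuming every parameter was split evenly along a single known axis; `reshard` is a hypothetical helper, not the script's implementation. A parameter copied identically to every shard would instead just keep one copy.

```python
import numpy as np


def reshard(shards, shards_out, axis):
    """Merge the per-shard pieces of one parameter, then re-split them.

    Hypothetical helper: assumes the parameter was split evenly along
    `axis`, so concatenating the pieces in order restores the full array.
    """
    full = np.concatenate(shards, axis=axis)
    if full.shape[axis] % shards_out != 0:
        raise ValueError(
            f"axis {axis} has size {full.shape[axis]}, "
            f"not divisible into {shards_out} shards"
        )
    # np.split returns shards_out equally sized sub-arrays along `axis`.
    return np.split(full, shards_out, axis=axis)


# Example: a 4x6 weight saved as 2 shards split along axis 1.
w = np.arange(24, dtype=np.float32).reshape(4, 6)
pieces = np.split(w, 2, axis=1)
merged = reshard(pieces, shards_out=1, axis=1)[0]
print(merged.shape)  # (4, 6)
```

Passing `shards_out=1` yields the single unsharded tensor that a PyTorch state dict stores per parameter.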
####
# run with 'help' arg for usage.
####
"""
python3.8 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
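import numpy as np
import jax.numpy as jnp

def apply_reshard(pytree_params_in, pytree_params_out, shards_in, shards_out):
    def override_dtype(x):
        # Arrays whose dtype came back as raw 2-byte void ('V2'), e.g. bfloat16
        # data loaded with plain NumPy, are reinterpreted as bfloat16 in place.
        if x.dtype == np.dtype('V2'):
            x.dtype = jnp.bfloat16
        return x

    def is_leaf(x):
        # Only plain NumPy arrays (not subclasses) count as pytree leaves.
        return type(x) is np.ndarray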
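The `override_dtype` helper relies on a byte-level fact: bfloat16 keeps the sign bit, the 8-bit exponent, and the top 7 mantissa bits of a float32, so each value is exactly the upper 16 bits of a float32. Reassigning the dtype of a 'V2' array therefore only reinterprets bytes that already hold bfloat16 values. The stdlib-only sketch decodes the same bytes by hand; `bf16_bytes_to_floats` and `float_to_bf16_bytes` are hypothetical helpers (not part of the script), and little-endian storage is assumed.

```python
import struct
from typing import List


def bf16_bytes_to_floats(raw: bytes) -> List[float]:
    """Decode little-endian bfloat16 values from raw 2-byte records."""
    if len(raw) % 2:
        raise ValueError("bfloat16 data must have an even number of bytes")
    count = len(raw) // 2
    halves = struct.unpack(f"<{count}H", raw)  # 16-bit bit patterns
    # Shifting each pattern into the high half rebuilds the float32 bits.
    widened = struct.pack(f"<{count}I", *(h << 16 for h in halves))
    return list(struct.unpack(f"<{count}f", widened))


def float_to_bf16_bytes(value: float) -> bytes:
    """Encode one float as little-endian bfloat16, rounding to nearest even."""
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    # Round on the 16 bits being dropped; NaN inputs are not special-cased.
    rounding = 0x7FFF + ((bits >> 16) & 1)
    return struct.pack("<H", ((bits + rounding) >> 16) & 0xFFFF)


values = [1.0, -2.5, 3.140625]
raw = b"".join(float_to_bf16_bytes(v) for v in values)
print(bf16_bytes_to_floats(raw))  # [1.0, -2.5, 3.140625]
```

The three sample values are exactly representable in bfloat16, so they survive the round trip unchanged; a value such as 1.00390625 (halfway between two bfloat16 neighbours) rounds to 1.0.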