@bchess
Last active April 25, 2024 21:00
Example of serializing a model with tensorizer in a subprocess to avoid GIL contention
import torch
import torch.multiprocessing as mp
from tensorizer import TensorSerializer
from transformers import AutoModelForCausalLM


def do_serialize(uri, model):
    # Runs in the child process: stream the module's tensors to `uri`.
    serializer = TensorSerializer(uri)
    serializer.write_module(model)
    serializer.close()


def main():
    model_ref = "EleutherAI/gpt-j-6B"
    dest = "gpt-j-6B.tensors"

    model = AutoModelForCausalLM.from_pretrained(
        model_ref,
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )

    # Serialize in a spawned subprocess so the writer does not contend
    # with the parent process's GIL; torch.multiprocessing shares the
    # model's tensors with the child instead of pickling a full copy.
    mp.set_start_method('spawn')
    p = mp.Process(target=do_serialize, args=(dest, model))
    p.start()
    p.join()


if __name__ == '__main__':
    main()
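
For reference, a minimal sketch of loading the serialized file back with tensorizer's TensorDeserializer. This is not part of the original gist; it assumes the same model_ref and output path as above, and uses the usual tensorizer pattern of building an uninitialized module skeleton before streaming the tensors into it.

from tensorizer import TensorDeserializer
from tensorizer.utils import no_init_or_tensor
from transformers import AutoConfig, AutoModelForCausalLM


def load_model(uri, model_ref="EleutherAI/gpt-j-6B"):
    # Build an empty (uninitialized) module skeleton, then stream the
    # saved tensors into it -- a sketch of a typical tensorizer load path.
    config = AutoConfig.from_pretrained(model_ref)
    with no_init_or_tensor():
        model = AutoModelForCausalLM.from_config(config)
    deserializer = TensorDeserializer(uri)
    deserializer.load_into_module(model)
    deserializer.close()
    return model


model = load_model("gpt-j-6B.tensors")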