Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
#!/bin/bash
# Convert a PyTorch checkpoint pickled under Python 2 into a file that Python 3
# can load. The result is written next to the input as "<input>.py3"; inputs
# that already have a ".py3" sibling are skipped.
#
# Requires python2 and python3 with the virtualenv module; the first run creates
# two virtualenvs (below) with torch==0.4.0 installed.
#
# Example usage (parallel processing w/ 10 jobs):
#   find . -name '*.pth' -print0 | xargs -0 -P 10 -n 1 ./py2_to_py3.sh
set -eux -o pipefail

py2=~/venv/py2
py3=~/venv/py3

# ensure_venv INTERPRETER DIR
# Create a virtualenv in DIR for INTERPRETER (with system site-packages) and
# install torch==0.4.0 into it. The marker file is written only after pip
# succeeds, so a half-finished setup is redone on the next run instead of
# being mistaken for a complete one (bin/python exists long before torch does).
ensure_venv() {
  local python=$1 venv=$2
  local marker="${venv}/.torch-0.4.0-installed"
  if [[ -f "$marker" ]]; then
    return 0
  fi
  "$python" -m virtualenv -p "$python" --system-site-packages "$venv"
  "${venv}/bin/pip" install torch==0.4.0
  touch -- "$marker"
}

# Parallel invocations (xargs -P) would otherwise race to build the same
# virtualenvs; hold an exclusive lock while setting up. Locking is skipped
# (best effort) when flock(1) is not installed.
venv_root=${py2%/*}
mkdir -p -- "$venv_root"
exec 9>"${venv_root}/.setup.lock"
if command -v flock >/dev/null; then
  flock 9
fi
ensure_venv python2 "$py2"
ensure_venv python3 "$py3"
exec 9>&- # release the lock before the conversion itself
# The checkpoint to convert is the single positional argument; fail with a
# usage message if it is missing or empty.
in_file=${1:?"usage: $0 CHECKPOINT"}
# Already converted by an earlier run: nothing to do.
if [[ -f "${in_file}.py3" ]]; then
  exit 0
fi
# Report a missing input clearly instead of via a Python traceback.
if [[ ! -f "$in_file" ]]; then
  printf 'error: no such file: %s\n' "$in_file" >&2
  exit 1
fi
# Stage 1 (Python 2): load the checkpoint with the Python 2 torch and re-pickle
# it into a scratch file that Python 3 can unpickle.
tmp_dir=$(mktemp -d)
# Remove the scratch copy however the script exits; with `xargs -P` over many
# checkpoints these copies would otherwise pile up in $TMPDIR.
trap 'rm -rf -- "$tmp_dir"' EXIT
tmp_file=${tmp_dir}/tmp.pth
# Paths are passed as arguments (quoted heredoc, no shell interpolation into
# the Python source) so file names with quotes or backslashes cannot break or
# inject into the program.
"${py2}/bin/python" - "$in_file" "$tmp_file" <<'EOF'
import sys
import cPickle as pickle
import torch

in_path, out_path = sys.argv[1], sys.argv[2]
print("Load...")
with open(in_path, "rb") as f:
    weights = torch.load(f)
print("Save...")
with open(out_path, "wb") as f:
    # cPickle defaults to text protocol 0, which escapes every byte of the
    # weight data; protocol 2 is binary and still readable by Python 3.
    pickle.dump(weights, f, pickle.HIGHEST_PROTOCOL)
EOF
# Stage 2 (Python 3): unpickle the scratch copy, convert Python 2 byte strings
# to str, and save with the Python 3 torch as "<in_file>.py3". Paths are passed
# as arguments rather than spliced into the Python source.
"${py3}/bin/python" - "$tmp_file" "${in_file}.py3" <<'EOF'
import collections
import os
import pickle
import sys

import torch


def decode(c):
    """Recursively convert bytes (Python 2 ``str``) into ``str``.

    Dict keys and values, list elements and plain-tuple elements are converted.
    An OrderedDict stays an OrderedDict (instance attributes carried over);
    other dicts become plain dicts. Anything else, including tensors, is
    returned unchanged.
    """
    # Decode object in Python3 that was saved from Python2.
    # https://www.dlology.com/blog/how-to-load-python-2-pytorch-checkpoint-in-python-3-1/
    # Test concrete types only: len() on a 0-d tensor raises TypeError.
    if isinstance(c, bytes):
        return c.decode()
    if isinstance(c, dict):
        pairs = [(decode(k), decode(v)) for k, v in c.items()]
        if not isinstance(c, collections.OrderedDict):
            return dict(pairs)
        out = collections.OrderedDict(pairs)
        # Attribute names restored from a Python 2 pickle may be bytes too.
        for name, value in vars(c).items():
            setattr(out, decode(name), decode(value))
        return out
    if isinstance(c, list):
        # Recurse into every element: lists may hold dicts, not just bytes.
        return [decode(x) for x in c]
    if type(c) is tuple:
        return tuple(decode(x) for x in c)
    return c


in_path, out_path = sys.argv[1], sys.argv[2]
print("Load...")
with open(in_path, "rb") as f:
    # encoding="bytes" loads every Python 2 str as bytes; decode() converts
    # them to str below.
    weights = pickle.load(f, encoding="bytes")
print("Decode")
weights = decode(weights)
print("Save...")
# Write a sibling temp file and rename it into place, so an interrupted run
# never leaves a truncated .py3 that later runs would skip as converted.
partial_path = out_path + ".partial"
torch.save(weights, partial_path)
os.replace(partial_path, out_path)
EOF
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment