Skip to content

Instantly share code, notes, and snippets.

@thomasweng15
Last active December 11, 2022 22:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thomasweng15/f0de0a3a0e91e5d50c867274dae0d821 to your computer and use it in GitHub Desktop.
Save thomasweng15/f0de0a3a0e91e5d50c867274dae0d821 to your computer and use it in GitHub Desktop.
Load PyTorch networks saved under Python 3 into Python 2 (e.g. for ROS deployment) via a plain-text weights file
import argparse
import ast
import sys

import torch
def load_model_txt(model, path):
    """Load weights written by ``save_model_txt`` into ``model`` in place.

    The text file alternates lines: a state-dict key, then that entry's
    value written as a Python literal (a nested list for tensors, or a bare
    number for scalars).  Reading stops at end of file or at the first
    blank line.

    This function is meant to run under Python 2, so it deliberately avoids
    f-strings and type annotations.

    Args:
        model: Object with a ``state_dict()`` method (e.g. a
            ``torch.nn.Module``) whose tensors are overwritten in place.
        path: Path to the text weights file.

    Raises:
        SystemExit: with status 1 if a loaded value cannot be copied into
            the matching model tensor (e.g. a shape mismatch).
        ValueError / SyntaxError: if a value line is not a Python literal.
    """
    print('Loading...')
    data_dict = {}
    with open(path, 'r') as fin:
        key = None
        expecting_key = True  # lines alternate: key, value, key, value, ...
        for raw in fin:
            s = raw.strip()
            if not s:
                # A blank line terminates the data.
                break
            if expecting_key:
                key = s
            else:
                # literal_eval accepts only Python literals, so a crafted
                # weights file cannot execute code the way eval() could.
                # Parse once and reuse the result.
                val = ast.literal_eval(s)
                if isinstance(val, list):
                    data_dict[key] = torch.FloatTensor(val)
                else:
                    # Scalar: wrap in a list, then index to get a 0-dim tensor.
                    data_dict[key] = torch.FloatTensor([val])[0]
            expecting_key = not expecting_key

    # Replace existing values with loaded ones.
    own_state = model.state_dict()
    print('Items:', len(own_state))
    for k, v in data_dict.items():
        if k not in own_state:
            print('Parameter', k, 'not found in own_state!!!')
            continue
        try:
            own_state[k].copy_(v)
        except RuntimeError:
            # copy_ raises RuntimeError when v cannot be broadcast to the
            # target shape.
            print('Key:', k)
            print('Old:', own_state[k])
            print('New:', v)
            # Exit non-zero so callers and shell scripts see the failure.
            sys.exit(1)
    print('Model loaded')
if __name__ == '__main__':
    # Run this script in python2 or call it from your ROS node
    parser = argparse.ArgumentParser(
        description='Load a text weights file into a PyTorch model.')
    parser.add_argument('--txt_path', required=True,
                        help='path to the .txt weights file')
    args = parser.parse_args()
    txt_path = args.txt_path
    # TODO: Initialize model with the same architecture that was saved,
    # e.g. ``model = MyNet()``.
    model = None
    if model is None:
        # Fail with an explicit message rather than a NameError on `model`.
        raise NotImplementedError(
            'Initialize `model` (see TODO) before calling load_model_txt')
    load_model_txt(model, txt_path)
import sys
import argparse
import torch
def save_model_txt(model, path):
    """Write a state dict to ``path`` as plain text that Python 2 can parse.

    For every entry, one line holds the key and the next line holds the
    value.  Tensors are written as nested Python lists (``tensor.tolist()``).
    Any other value is written with ``str()``.  ``load_model_txt`` reads
    this format back.

    Args:
        model: Either a mapping of names to values (e.g. the state dict
            returned by ``torch.load``) or an object with a ``state_dict()``
            method, such as a ``torch.nn.Module``.
        path: Destination text file; overwritten if it exists.
    """
    # Accept a module directly as well as an already-extracted state dict.
    if callable(getattr(model, 'state_dict', None)):
        state = model.state_dict()
    else:
        state = model
    with open(path, 'w') as fout:
        for key, value in state.items():
            fout.write(str(key) + '\n')
            if isinstance(value, torch.Tensor):
                # tolist() yields plain Python numbers, so the reader
                # needs no torch pickle compatibility.
                fout.write(str(value.tolist()) + '\n')
            else:
                fout.write(str(value) + '\n')
if __name__ == '__main__':
    # Run this script in python3
    import os

    parser = argparse.ArgumentParser(
        description='Convert a saved torch state dict into a text file.')
    parser.add_argument('--weight_path', required=True,
                        help='path to the saved state dict (e.g. model.pt)')
    args = parser.parse_args()
    weight_path = args.weight_path
    # map_location='cpu' lets GPU-saved checkpoints convert on CPU-only hosts.
    model = torch.load(weight_path, map_location='cpu')
    # Swap only the final extension.  str.replace('.pt', '.txt') would turn
    # '.pth' into '.txth', rewrite directory names, and leave paths without
    # '.pt' unchanged -- overwriting the input weights file.
    txt_path = os.path.splitext(weight_path)[0] + '.txt'
    save_model_txt(model, txt_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment