Skip to content

Instantly share code, notes, and snippets.

@piercelamb
Created December 20, 2022 17:55
Show Gist options
  • Save piercelamb/47deec3616416b6cdabcfe22b6553812 to your computer and use it in GitHub Desktop.
get_loaded_model
def get_loaded_model(config, model_name):
    """Fetch a fine-tuned TorchScript model (downloading it if absent) and load it.

    The S3 prefix ``{config.s3_parent_dir}/run_{config.run_num}/`` is probed for a
    ``tuning`` subfolder first, then ``training``.  If neither exists, the job is
    treated as a comparison job and ``{model_name}_training`` is used instead.
    The artifact ``{model_name}_finetuned.pt`` from that folder is downloaded into
    the current working directory unless a file with that name is already there.

    Args:
        config: Object exposing ``s3_parent_dir`` and ``run_num``.
        model_name: Prefix of the ``{model_name}_finetuned.pt`` artifact.

    Returns:
        The loaded TorchScript module, moved to CUDA when available (otherwise
        CPU) and switched to eval mode.
    """
    run_prefix = f"{config.s3_parent_dir}/run_{config.run_num}/"

    # Probe "tuning" before "training"; the generator is lazy, so "training" is
    # only checked when "tuning" is missing.  Neither present -> comparison job.
    subdir = next(
        (
            candidate
            for candidate in ("tuning", "training")
            if folder_exists(bucket, run_prefix + candidate)
        ),
        f"{model_name}_training",
    )

    artifact = f"{model_name}_finetuned.pt"
    s3_key = f"{run_prefix}{subdir}/{artifact}"
    local_path = os.path.join(os.getcwd(), artifact)

    # NOTE(review): the local cache is keyed only by the artifact's file name,
    # so a copy left in cwd by a different run/folder is reused as-is.
    already_cached = os.path.isfile(local_path)
    if not already_cached:
        print(f"Downloading {artifact}")
        bucket.download_file(s3_key, local_path)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    loaded = torch.jit.load(local_path, map_location=device).to(device)
    loaded.eval()
    return loaded
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment