
@piercelamb
Created December 19, 2022 23:01
post_training_files
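import os

# end_training() only matters when experiment tracking is enabled
# (e.g. Accelerator(log_with=...)): it stops the trackers on the main process.
accelerator.end_training()
# Write artifacts from a single process so workers don't race on the same files
if accelerator.is_main_process:
    run_config_local_path = write_run_config(hyperparams['train_batch_size'], hyperparams['eval_batch_size'],
                                             hyperparams['learning_rate'], hyperparams['epochs'])
    # Save only the best model's weights (state_dict) as <model_name>_finetuned.pth
    local_best_model = os.path.join(os.getcwd(), f"{hyperparams['model_name']}_finetuned.pth")
    accelerator.save(best_model.state_dict(), local_best_model)
    # Export a TorchScript version of the same model as <model_name>_finetuned.pt
    local_torchscript_model_path = os.path.join(os.getcwd(), f"{hyperparams['model_name']}_finetuned.pt")
    convert_to_torchscript(local_torchscript_model_path, best_model, train_data, hyperparams)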
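
`write_run_config` is not defined in this gist. As a hypothetical sketch, assuming it simply serializes the four hyperparameters it receives to JSON and returns the file's local path, it could look like the following. The file name `run_config.json` and the `out_dir` parameter are assumptions made for illustration, not part of the original code.

```python
import json
import os
import tempfile


def write_run_config(train_batch_size, eval_batch_size, learning_rate, epochs,
                     out_dir=None, filename="run_config.json"):
    """Hypothetical sketch: persist the run's hyperparameters as JSON and
    return the path of the written file."""
    # Default to the current working directory, like the model paths in the gist
    out_dir = out_dir or os.getcwd()
    config = {
        "train_batch_size": train_batch_size,
        "eval_batch_size": eval_batch_size,
        "learning_rate": learning_rate,
        "epochs": epochs,
    }
    path = os.path.join(out_dir, filename)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    return path


if __name__ == "__main__":
    # Example values only; the real hyperparams dict comes from the training script
    hyperparams = {"train_batch_size": 16, "eval_batch_size": 32,
                   "learning_rate": 2e-5, "epochs": 3}
    with tempfile.TemporaryDirectory() as tmp:
        run_config_local_path = write_run_config(
            hyperparams["train_batch_size"], hyperparams["eval_batch_size"],
            hyperparams["learning_rate"], hyperparams["epochs"], out_dir=tmp)
        with open(run_config_local_path, encoding="utf-8") as f:
            print(json.load(f))
```

Returning the path rather than the config dict matches how the snippet above uses the result (`run_config_local_path`). What the script does with that file afterwards is not shown in this gist.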