Skip to content

Instantly share code, notes, and snippets.

@litagin02
Last active January 18, 2024 01:51
Show Gist options
  • Save litagin02/c6ab8a35c2b2b779c632ca820b805267 to your computer and use it in GitHub Desktop.
Save litagin02/c6ab8a35c2b2b779c632ca820b805267 to your computer and use it in GitHub Desktop.
学習したpthファイルから事前学習モデルsafetensorsを作るやつ
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"jvs_root = \"Data/jvs/models\"\n",
"step = 50000\n",
"\n",
"g_model = torch.load(f\"{jvs_root}/G_{step}.pth\", map_location=\"cpu\")\n",
"d_model = torch.load(f\"{jvs_root}/D_{step}.pth\", map_location=\"cpu\")\n",
"dur_model = torch.load(f\"{jvs_root}/DUR_{step}.pth\", map_location=\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"emb_g.weight\n"
]
}
],
"source": [
"g_dict = {}\n",
"for key in g_model[\"model\"].keys():\n",
" if key.startswith(\"emb_g\"):\n",
" print(key)\n",
" else:\n",
" g_dict[key] = g_model[\"model\"][key]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"d_dict = {}\n",
"for key in d_model[\"model\"].keys():\n",
" d_dict[key] = d_model[\"model\"][key]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"dur_dict = {}\n",
"for key in dur_model[\"model\"].keys():\n",
" dur_dict[key] = dur_model[\"model\"][key]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from safetensors.torch import save_file\n",
"\n",
"\n",
"save_file(g_dict, f\"G_0.safetensors\")\n",
"save_file (d_dict, f\"D_0.safetensors\")\n",
"save_file (dur_dict, f\"DUR_0.safetensors\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment