Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tam17aki/491e676d82b5259bc0fd3bcda57bbf3b to your computer and use it in GitHub Desktop.
Save tam17aki/491e676d82b5259bc0fd3bcda57bbf3b to your computer and use it in GitHub Desktop.
diffusion_model_book_2_2_score_based_model.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tam17aki/491e676d82b5259bc0fd3bcda57bbf3b/diffusion_model_book_2_2_score_based_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ybGLrN9MeBRz"
},
"outputs": [],
"source": [
"from tqdm import tqdm_notebook as tqdm\n",
"import torch\n",
"\n",
"device = \"cpu\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zJNm-Rep08EY"
},
"outputs": [],
"source": [
"# Fix the global RNG seed so the synthetic dataset is identical on every\n",
"# Restart & Run All.\n",
"torch.manual_seed(0)\n",
"\n",
"n_samples = int(1e6)\n",
"# NOTE: `sigma` multiplies the covariance matrix below, so the per-axis standard\n",
"# deviation of each mode is sqrt(sigma), not sigma.\n",
"sigma = 0.1\n",
"\n",
"# Two-mode Gaussian mixture: half of the points are drawn around (-2, -2) ...\n",
"dist0 = torch.distributions.MultivariateNormal(\n",
"    torch.tensor([-2, -2], dtype=torch.float, device=device),\n",
"    sigma * torch.eye(2, dtype=torch.float, device=device),\n",
")\n",
"samples0 = dist0.sample(torch.Size([n_samples // 2]))\n",
"\n",
"# ... and the other half around (2, 2).\n",
"dist1 = torch.distributions.MultivariateNormal(\n",
"    torch.tensor([2, 2], dtype=torch.float, device=device),\n",
"    sigma * torch.eye(2, dtype=torch.float, device=device),\n",
")\n",
"samples1 = dist1.sample(torch.Size([n_samples // 2]))\n",
"samples = torch.vstack((samples0, samples1))\n",
"\n",
"# Standardise each coordinate to zero mean and unit standard deviation.\n",
"mean = torch.mean(samples, dim=0)\n",
"std = torch.std(samples, dim=0)\n",
"\n",
"normalized_samples = (samples - mean[None, :]) / std[None, :]\n",
"\n",
"dataset = torch.utils.data.TensorDataset(normalized_samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "f8wwV52v4Cqr",
"outputId": "1e39d255-5461-4bbc-d736-753c54e03f19"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plot_samples = normalized_samples.cpu().numpy()\n",
"\n",
"# 2-D histogram of the normalised training data, restricted to [-2, 2] x [-2, 2].\n",
"fig, ax = plt.subplots()\n",
"ax.hist2d(\n",
"    plot_samples[:, 0],\n",
"    plot_samples[:, 1],\n",
"    bins=100,\n",
"    range=((-2, 2), (-2, 2)),\n",
"    density=True,\n",
"    cmap='viridis',\n",
"    rasterized=False,\n",
")\n",
"ax.set_aspect('equal', adjustable='box')\n",
"ax.set_xlim([-2, 2])\n",
"ax.set_ylim([-2, 2])\n",
"ax.set_title('Sample Density')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0gDBxSZdq8c"
},
"outputs": [],
"source": [
"sigma_begin, sigma_end = 0.001, 1.0\n",
"T = 200\n",
"\n",
"# T noise levels spaced uniformly in log-space from sigma_begin up to sigma_end\n",
"# (a geometric progression), stored as a float32 tensor on `device`.\n",
"log_sigmas = np.linspace(np.log(sigma_begin), np.log(sigma_end), T)\n",
"sigmas = torch.tensor(np.exp(log_sigmas)).float().to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "maXGG_Mdefn_"
},
"outputs": [],
"source": [
"def dsm_loss(score_model, samples, sigmas):\n",
" t = torch.randint(0, len(sigmas), (samples.shape[0],), device=sigmas.device)\n",
" used_sigmas = sigmas[t].view(samples.shape[0], *([1] * len(samples.shape[1:])))\n",
" noise = torch.randn_like(samples) * used_sigmas\n",
" perturbed_samples = samples + noise\n",
" target = - 1 / (used_sigmas ** 2) * noise\n",
" scores = score_model(perturbed_samples, used_sigmas)\n",
" target = target.view(target.shape[0], -1)\n",
" scores = scores.view(scores.shape[0], -1)\n",
" w = used_sigmas.squeeze(-1) ** 2\n",
" loss = ((scores - target) ** 2).sum(dim=-1) * w\n",
" return loss.mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-Lwg4kMVM03Z"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class ScoreModel(nn.Module):\n",
" def __init__(self, n_channels=2):\n",
" super(ScoreModel, self).__init__()\n",
"\n",
" self.model = nn.Sequential(\n",
" nn.Linear(n_channels + 1, 2*n_channels),\n",
" nn.ELU(),\n",
" nn.Linear(2*n_channels, 16*n_channels),\n",
" nn.ELU(),\n",
" nn.Linear(16*n_channels, 2*n_channels),\n",
" nn.ELU(),\n",
" nn.Linear(2*n_channels, n_channels),\n",
" )\n",
"\n",
" def forward(self, x, sigma):\n",
" x = torch.cat((x, sigma), dim=1)\n",
" y = self.model(x)\n",
" return y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "P9QWytgpa0pn",
"outputId": "d6a8f58d-c1c5-41f4-b5e4-173175143409"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 steps loss:2.10343599319458\n",
"1000 steps loss:2.138179063796997\n",
"2000 steps loss:1.8900189399719238\n",
"3000 steps loss:1.9138953685760498\n",
"4000 steps loss:1.9245710372924805\n",
"5000 steps loss:1.914085865020752\n",
"6000 steps loss:1.8996541500091553\n",
"7000 steps loss:1.831539273262024\n",
"8000 steps loss:1.8835861682891846\n",
"9000 steps loss:1.809353232383728\n",
"10000 steps loss:1.6475526094436646\n",
"11000 steps loss:1.6848489046096802\n",
"12000 steps loss:1.5485844612121582\n",
"13000 steps loss:1.7400761842727661\n",
"14000 steps loss:1.601453423500061\n",
"15000 steps loss:1.6501575708389282\n",
"16000 steps loss:1.5780878067016602\n",
"17000 steps loss:1.6547218561172485\n",
"18000 steps loss:1.5332990884780884\n",
"19000 steps loss:1.5731066465377808\n",
"20000 steps loss:1.6006646156311035\n",
"21000 steps loss:1.5818557739257812\n",
"22000 steps loss:1.4299019575119019\n",
"23000 steps loss:1.6383044719696045\n",
"24000 steps loss:1.6676623821258545\n",
"25000 steps loss:1.5119620561599731\n",
"26000 steps loss:1.522580623626709\n",
"27000 steps loss:1.462673306465149\n",
"28000 steps loss:1.533888578414917\n",
"29000 steps loss:1.4086380004882812\n",
"30000 steps loss:1.5557557344436646\n",
"31000 steps loss:1.4251868724822998\n",
"32000 steps loss:1.4696316719055176\n",
"33000 steps loss:1.461082100868225\n",
"34000 steps loss:1.487654209136963\n",
"35000 steps loss:1.4957597255706787\n",
"36000 steps loss:1.5456204414367676\n",
"37000 steps loss:1.7313734292984009\n",
"38000 steps loss:1.5402085781097412\n",
"39000 steps loss:1.450195550918579\n",
"40000 steps loss:1.4396100044250488\n",
"41000 steps loss:1.516808271408081\n",
"42000 steps loss:1.5489423274993896\n",
"43000 steps loss:1.4515758752822876\n",
"44000 steps loss:1.541172981262207\n",
"45000 steps loss:1.467519998550415\n",
"46000 steps loss:1.4105401039123535\n",
"47000 steps loss:1.4251112937927246\n",
"48000 steps loss:1.3971335887908936\n",
"49000 steps loss:1.5265257358551025\n",
"50000 steps loss:1.4675631523132324\n",
"51000 steps loss:1.520257830619812\n",
"52000 steps loss:1.3204165697097778\n",
"53000 steps loss:1.5094389915466309\n",
"54000 steps loss:1.4009146690368652\n",
"55000 steps loss:1.502292275428772\n",
"56000 steps loss:1.6551986932754517\n",
"57000 steps loss:1.566804051399231\n",
"58000 steps loss:1.7013672590255737\n",
"59000 steps loss:1.5082180500030518\n",
"60000 steps loss:1.4296739101409912\n",
"61000 steps loss:1.5015251636505127\n",
"62000 steps loss:1.5544426441192627\n",
"63000 steps loss:1.4355849027633667\n",
"64000 steps loss:1.4423151016235352\n",
"65000 steps loss:1.5445194244384766\n",
"66000 steps loss:1.4904658794403076\n",
"67000 steps loss:1.4084433317184448\n",
"68000 steps loss:1.502392053604126\n",
"69000 steps loss:1.458216905593872\n",
"70000 steps loss:1.5076521635055542\n",
"71000 steps loss:1.5410008430480957\n",
"72000 steps loss:1.5718374252319336\n",
"73000 steps loss:1.526256799697876\n",
"74000 steps loss:1.5020813941955566\n",
"75000 steps loss:1.5269545316696167\n",
"76000 steps loss:1.387231707572937\n",
"77000 steps loss:1.4536584615707397\n",
"78000 steps loss:1.4981789588928223\n",
"79000 steps loss:1.421187162399292\n",
"80000 steps loss:1.4645973443984985\n",
"81000 steps loss:1.4400806427001953\n",
"82000 steps loss:1.4696435928344727\n",
"83000 steps loss:1.6755788326263428\n",
"84000 steps loss:1.4436883926391602\n",
"85000 steps loss:1.4473791122436523\n",
"86000 steps loss:1.4851280450820923\n",
"87000 steps loss:1.5375834703445435\n",
"88000 steps loss:1.4770243167877197\n",
"89000 steps loss:1.5834132432937622\n",
"90000 steps loss:1.5765262842178345\n",
"91000 steps loss:1.5022810697555542\n",
"92000 steps loss:1.548335075378418\n",
"93000 steps loss:1.641385555267334\n",
"94000 steps loss:1.4702154397964478\n",
"95000 steps loss:1.5113106966018677\n",
"96000 steps loss:1.4681692123413086\n",
"97000 steps loss:1.568758487701416\n",
"98000 steps loss:1.4290732145309448\n",
"99000 steps loss:1.5507277250289917\n"
]
}
],
"source": [
"import torch\n",
"\n",
"batch_size = 512\n",
"n_steps = 100000\n",
"\n",
"# Pass the `batch_size` constant (not a second literal) so that editing it above\n",
"# actually changes the mini-batch size.\n",
"dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
"dataloader_iter = iter(dataloader)\n",
"\n",
"score_model = ScoreModel().to(device)\n",
"\n",
"optimizer = torch.optim.Adam(score_model.parameters())\n",
"# One-cycle learning-rate schedule spanning exactly n_steps optimiser steps.\n",
"lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=n_steps)\n",
"\n",
"for i in range(n_steps):\n",
"    # Training is counted in steps, not epochs: restart the loader whenever it\n",
"    # is exhausted.\n",
"    try:\n",
"        x = next(dataloader_iter)[0]\n",
"    except StopIteration:\n",
"        dataloader_iter = iter(dataloader)\n",
"        x = next(dataloader_iter)[0]\n",
"    x = x.to(device)\n",
"\n",
"    optimizer.zero_grad()\n",
"    loss = dsm_loss(score_model, x, sigmas)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    lr_scheduler.step()\n",
"    if (i % 1000) == 0:\n",
"        # .item() logs the plain Python float instead of a graph-attached tensor.\n",
"        print(f\"{i} steps loss:{loss.item()}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9P_sAS2Fpa-o"
},
"outputs": [],
"source": [
"def sbm_sample(n_samples, score_model, sigmas, alpha=0.1):\n",
" sigma_T = sigmas[-1]\n",
" x_0 = torch.randn(n_samples, 2) * sigma_T\n",
" x_tk = x_0\n",
" K = 200\n",
" for t in range(len(sigmas) -1, -1, -1):\n",
" sigma_t = sigmas[t]\n",
" alpha_t = alpha*(sigma_t**2)/(sigma_T**2)\n",
" print(f\"t:{t}, sigma_t:{sigma_t}, alpha_t:{alpha_t}\")\n",
" for k in range(K+1):\n",
" u_k = torch.randn(n_samples, 2)\n",
" if (k == K) and t == 0:\n",
" u_k[:, :] = 0.0\n",
" with torch.no_grad():\n",
" sigma_t_dup = torch.ones((n_samples, 1)) * sigma_t\n",
" score = score_model(x_tk, sigma_t_dup)\n",
" x_tk = x_tk + alpha_t * score + np.sqrt(2 * alpha_t) * u_k\n",
" return x_tk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ixng3Ojiqp_J",
"outputId": "bfc4347c-b13f-4194-e205-1b598924362b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"t:199, sigma_t:1.0, alpha_t:0.10000000149011612\n",
"t:198, sigma_t:0.965883195400238, alpha_t:0.09329303354024887\n",
"t:197, sigma_t:0.9329304099082947, alpha_t:0.08703591674566269\n",
"t:196, sigma_t:0.90110182762146, alpha_t:0.08119844645261765\n",
"t:195, sigma_t:0.8703591227531433, alpha_t:0.07575250416994095\n",
"t:194, sigma_t:0.8406652808189392, alpha_t:0.07067181169986725\n",
"t:193, sigma_t:0.8119844794273376, alpha_t:0.06593187898397446\n",
"t:192, sigma_t:0.7842822074890137, alpha_t:0.061509858816862106\n",
"t:191, sigma_t:0.7575250267982483, alpha_t:0.057384420186281204\n",
"t:190, sigma_t:0.731680691242218, alpha_t:0.053535666316747665\n",
"t:189, sigma_t:0.7067181468009949, alpha_t:0.049945052713155746\n",
"t:188, sigma_t:0.6826071739196777, alpha_t:0.0465952567756176\n",
"t:187, sigma_t:0.6593188047409058, alpha_t:0.04347012937068939\n",
"t:186, sigma_t:0.6368249654769897, alpha_t:0.040554605424404144\n",
"t:185, sigma_t:0.6150985956192017, alpha_t:0.03783462941646576\n",
"t:184, sigma_t:0.5941134095191956, alpha_t:0.035297077149152756\n",
"t:183, sigma_t:0.5738441944122314, alpha_t:0.032929714769124985\n",
"t:182, sigma_t:0.5542664527893066, alpha_t:0.030721131712198257\n",
"t:181, sigma_t:0.5353566408157349, alpha_t:0.028660673648118973\n",
"t:180, sigma_t:0.5170920491218567, alpha_t:0.026738420128822327\n",
"t:179, sigma_t:0.4994505047798157, alpha_t:0.024945080280303955\n",
"t:178, sigma_t:0.48241087794303894, alpha_t:0.02327202633023262\n",
"t:177, sigma_t:0.4659525752067566, alpha_t:0.021711179986596107\n",
"t:176, sigma_t:0.4500557780265808, alpha_t:0.020255019888281822\n",
"t:175, sigma_t:0.4347013235092163, alpha_t:0.018896525725722313\n",
"t:174, sigma_t:0.4198707044124603, alpha_t:0.017629140987992287\n",
"t:173, sigma_t:0.40554606914520264, alpha_t:0.016446761786937714\n",
"t:172, sigma_t:0.39171016216278076, alpha_t:0.015343685634434223\n",
"t:171, sigma_t:0.3783462643623352, alpha_t:0.014314589090645313\n",
"t:170, sigma_t:0.365438312292099, alpha_t:0.013354516588151455\n",
"t:169, sigma_t:0.3529707193374634, alpha_t:0.012458832934498787\n",
"t:168, sigma_t:0.34092849493026733, alpha_t:0.01162322424352169\n",
"t:167, sigma_t:0.32929712533950806, alpha_t:0.01084365975111723\n",
"t:166, sigma_t:0.31806257367134094, alpha_t:0.010116379708051682\n",
"t:165, sigma_t:0.307211309671402, alpha_t:0.009437879547476768\n",
"t:164, sigma_t:0.29673025012016296, alpha_t:0.008804883807897568\n",
"t:163, sigma_t:0.2866067588329315, alpha_t:0.008214343339204788\n",
"t:162, sigma_t:0.27682867646217346, alpha_t:0.007663411553949118\n",
"t:161, sigma_t:0.2673841714859009, alpha_t:0.007149429526180029\n",
"t:160, sigma_t:0.25826188921928406, alpha_t:0.006669920869171619\n",
"t:159, sigma_t:0.24945081770420074, alpha_t:0.00622257124632597\n",
"t:158, sigma_t:0.24094036221504211, alpha_t:0.0058052255772054195\n",
"t:157, sigma_t:0.2327202409505844, alpha_t:0.005415871273726225\n",
"t:156, sigma_t:0.2247805893421173, alpha_t:0.005052631255239248\n",
"t:155, sigma_t:0.21711179614067078, alpha_t:0.004713753238320351\n",
"t:154, sigma_t:0.20970463752746582, alpha_t:0.0043976036831736565\n",
"t:153, sigma_t:0.20255018770694733, alpha_t:0.00410265801474452\n",
"t:152, sigma_t:0.19563983380794525, alpha_t:0.0038274943362921476\n",
"t:151, sigma_t:0.18896523118019104, alpha_t:0.0035707857459783554\n",
"t:150, sigma_t:0.1825183480978012, alpha_t:0.003331294748932123\n",
"t:149, sigma_t:0.17629140615463257, alpha_t:0.0031078660394996405\n",
"t:148, sigma_t:0.17027691006660461, alpha_t:0.002899422775954008\n",
"t:147, sigma_t:0.16446761786937714, alpha_t:0.0027049598284065723\n",
"t:146, sigma_t:0.15885651111602783, alpha_t:0.0025235391221940517\n",
"t:145, sigma_t:0.1534368395805359, alpha_t:0.002354286378249526\n",
"t:144, sigma_t:0.1482020765542984, alpha_t:0.0021963855251669884\n",
"t:143, sigma_t:0.14314588904380798, alpha_t:0.0020490745082497597\n",
"t:142, sigma_t:0.13826221227645874, alpha_t:0.0019116438925266266\n",
"t:141, sigma_t:0.1335451602935791, alpha_t:0.0017834309255704284\n",
"t:140, sigma_t:0.12898902595043182, alpha_t:0.0016638169763609767\n",
"t:139, sigma_t:0.12458833307027817, alpha_t:0.001552225323393941\n",
"t:138, sigma_t:0.12033778429031372, alpha_t:0.0014481182442978024\n",
"t:137, sigma_t:0.1162322461605072, alpha_t:0.0013509935233741999\n",
"t:136, sigma_t:0.11226677894592285, alpha_t:0.0012603829381987453\n",
"t:135, sigma_t:0.10843659937381744, alpha_t:0.0011758495820686221\n",
"t:134, sigma_t:0.10473708808422089, alpha_t:0.0010969857685267925\n",
"t:133, sigma_t:0.10116379708051682, alpha_t:0.0010234114015474916\n",
"t:132, sigma_t:0.09771241247653961, alpha_t:0.0009547715890221298\n",
"t:131, sigma_t:0.09437878429889679, alpha_t:0.000890735536813736\n",
"t:130, sigma_t:0.09115888178348541, alpha_t:0.0008309941622428596\n",
"t:129, sigma_t:0.08804883807897568, alpha_t:0.0007752598030492663\n",
"t:128, sigma_t:0.08504489064216614, alpha_t:0.0007232633652165532\n",
"t:127, sigma_t:0.08214343339204788, alpha_t:0.0006747543811798096\n",
"t:126, sigma_t:0.07934096455574036, alpha_t:0.0006294988561421633\n",
"t:125, sigma_t:0.07663410902023315, alpha_t:0.0005872786859981716\n",
"t:124, sigma_t:0.07401960343122482, alpha_t:0.0005478902021422982\n",
"t:123, sigma_t:0.07149428874254227, alpha_t:0.0005111432983539999\n",
"t:122, sigma_t:0.06905513256788254, alpha_t:0.00047686113975942135\n",
"t:121, sigma_t:0.06669919937849045, alpha_t:0.0004448783292900771\n",
"t:120, sigma_t:0.06442363560199738, alpha_t:0.0004150404711253941\n",
"t:119, sigma_t:0.06222570687532425, alpha_t:0.00038720385055057704\n",
"t:118, sigma_t:0.06010276824235916, alpha_t:0.0003612342698033899\n",
"t:117, sigma_t:0.05805225670337677, alpha_t:0.0003370064659975469\n",
"t:116, sigma_t:0.056071698665618896, alpha_t:0.00031440352904610336\n",
"t:115, sigma_t:0.054158713668584824, alpha_t:0.0002933166397269815\n",
"t:114, sigma_t:0.052310992032289505, alpha_t:0.00027364399284124374\n",
"t:113, sigma_t:0.05052630975842476, alpha_t:0.00025529079721309245\n",
"t:112, sigma_t:0.04880251735448837, alpha_t:0.00023816856264602393\n",
"t:111, sigma_t:0.047137532383203506, alpha_t:0.0002221947070211172\n",
"t:110, sigma_t:0.04552935063838959, alpha_t:0.0002072921779472381\n",
"t:109, sigma_t:0.043976034969091415, alpha_t:0.0001933891762746498\n",
"t:108, sigma_t:0.04247571527957916, alpha_t:0.00018041864677798003\n",
"t:107, sigma_t:0.041026581078767776, alpha_t:0.00016831803077366203\n",
"t:106, sigma_t:0.03962688520550728, alpha_t:0.0001570290041854605\n",
"t:105, sigma_t:0.038274943828582764, alpha_t:0.00014649714285042137\n",
"t:104, sigma_t:0.036969125270843506, alpha_t:0.00013667163148056716\n",
"t:103, sigma_t:0.03570786118507385, alpha_t:0.00012750514724757522\n",
"t:102, sigma_t:0.034489624202251434, alpha_t:0.00011895342322532088\n",
"t:101, sigma_t:0.03331294655799866, alpha_t:0.00011097524838987738\n",
"t:100, sigma_t:0.032176416367292404, alpha_t:0.00010353217658121139\n",
"t:99, sigma_t:0.03107866272330284, alpha_t:9.658833005232736e-05\n",
"t:98, sigma_t:0.030018357560038567, alpha_t:9.011018119053915e-05\n",
"t:97, sigma_t:0.028994228690862656, alpha_t:8.406653068959713e-05\n",
"t:96, sigma_t:0.02800503931939602, alpha_t:7.842822378734127e-05\n",
"t:95, sigma_t:0.02704959735274315, alpha_t:7.316807023016736e-05\n",
"t:94, sigma_t:0.02612675167620182, alpha_t:6.826072058174759e-05\n",
"t:93, sigma_t:0.025235392153263092, alpha_t:6.368249887600541e-05\n",
"t:92, sigma_t:0.02437444217503071, alpha_t:5.9411344409454614e-05\n",
"t:91, sigma_t:0.023542864248156548, alpha_t:5.542664439417422e-05\n",
"t:90, sigma_t:0.022739658132195473, alpha_t:5.1709208491956815e-05\n",
"t:89, sigma_t:0.021963853389024734, alpha_t:4.8241086915368214e-05\n",
"t:88, sigma_t:0.021214518696069717, alpha_t:4.5005581341683865e-05\n",
"t:87, sigma_t:0.020490746945142746, alpha_t:4.198707392788492e-05\n",
"t:86, sigma_t:0.01979166828095913, alpha_t:3.9171016396721825e-05\n",
"t:85, sigma_t:0.019116440787911415, alpha_t:3.654383181128651e-05\n",
"t:84, sigma_t:0.01846424862742424, alpha_t:3.409284909139387e-05\n",
"t:83, sigma_t:0.017834309488534927, alpha_t:3.1806262995814905e-05\n",
"t:82, sigma_t:0.017225859686732292, alpha_t:2.9673023163923062e-05\n",
"t:81, sigma_t:0.016638169065117836, alpha_t:2.7682868676492944e-05\n",
"t:80, sigma_t:0.016070527955889702, alpha_t:2.582618617452681e-05\n",
"t:79, sigma_t:0.015522253699600697, alpha_t:2.4094035325106233e-05\n",
"t:78, sigma_t:0.014992684125900269, alpha_t:2.247805787192192e-05\n",
"t:77, sigma_t:0.01448118221014738, alpha_t:2.0970464902347885e-05\n",
"t:76, sigma_t:0.013987131416797638, alpha_t:1.9563985915738158e-05\n",
"t:75, sigma_t:0.013509934768080711, alpha_t:1.8251834262628108e-05\n",
"t:74, sigma_t:0.013049019500613213, alpha_t:1.7027690773829818e-05\n",
"t:73, sigma_t:0.012603829614818096, alpha_t:1.588565282872878e-05\n",
"t:72, sigma_t:0.012173827737569809, alpha_t:1.4820208889432251e-05\n",
"t:71, sigma_t:0.01175849512219429, alpha_t:1.3826221220369916e-05\n",
"t:70, sigma_t:0.01135733351111412, alpha_t:1.2898902241431642e-05\n",
"t:69, sigma_t:0.010969857685267925, alpha_t:1.2033778148179408e-05\n",
"t:68, sigma_t:0.010595601983368397, alpha_t:1.1226677997910883e-05\n",
"t:67, sigma_t:0.010234113782644272, alpha_t:1.0473708243807778e-05\n",
"t:66, sigma_t:0.00988495908677578, alpha_t:9.771241820999421e-06\n",
"t:65, sigma_t:0.00954771600663662, alpha_t:9.115888133237604e-06\n",
"t:64, sigma_t:0.009221978485584259, alpha_t:8.504488505423069e-06\n",
"t:63, sigma_t:0.008907354436814785, alpha_t:7.934096174722072e-06\n",
"t:62, sigma_t:0.008603464812040329, alpha_t:7.4019603744091e-06\n",
"t:61, sigma_t:0.008309941738843918, alpha_t:6.905513146193698e-06\n",
"t:60, sigma_t:0.008026433177292347, alpha_t:6.442362973757554e-06\n",
"t:59, sigma_t:0.007752597332000732, alpha_t:6.010276592860464e-06\n",
"t:58, sigma_t:0.007488104049116373, alpha_t:5.607170351140667e-06\n",
"t:57, sigma_t:0.007232633884996176, alpha_t:5.231099294178421e-06\n",
"t:56, sigma_t:0.006985879968851805, alpha_t:4.880251708527794e-06\n",
"t:55, sigma_t:0.006747544277459383, alpha_t:4.5529354792961385e-06\n",
"t:54, sigma_t:0.006517339497804642, alpha_t:4.247571268933825e-06\n",
"t:53, sigma_t:0.00629498902708292, alpha_t:3.962688879255438e-06\n",
"t:52, sigma_t:0.0060802241787314415, alpha_t:3.696912699524546e-06\n",
"t:51, sigma_t:0.00587278651073575, alpha_t:3.4489621612010524e-06\n",
"t:50, sigma_t:0.005672425962984562, alpha_t:3.2176417334994767e-06\n",
"t:49, sigma_t:0.005478901322931051, alpha_t:3.00183614854177e-06\n",
"t:48, sigma_t:0.00529197882860899, alpha_t:2.800504034894402e-06\n",
"t:47, sigma_t:0.005111433565616608, alpha_t:2.612675189084257e-06\n",
"t:46, sigma_t:0.004937048070132732, alpha_t:2.437444436509395e-06\n",
"t:45, sigma_t:0.004768611863255501, alpha_t:2.2739659470971674e-06\n",
"t:44, sigma_t:0.004605921916663647, alpha_t:2.121451643688488e-06\n",
"t:43, sigma_t:0.004448782652616501, alpha_t:1.979166654564324e-06\n",
"t:42, sigma_t:0.004297004546970129, alpha_t:1.8464248796590255e-06\n",
"t:41, sigma_t:0.004150404594838619, alpha_t:1.7225859210157068e-06\n",
"t:40, sigma_t:0.004008806310594082, alpha_t:1.6070528090494918e-06\n",
"t:39, sigma_t:0.003872038796544075, alpha_t:1.4992684782555443e-06\n",
"t:38, sigma_t:0.0037399372085928917, alpha_t:1.3987130387249636e-06\n",
"t:37, sigma_t:0.00361234275624156, alpha_t:1.30490195715538e-06\n",
"t:36, sigma_t:0.003489101305603981, alpha_t:1.2173827599326614e-06\n",
"t:35, sigma_t:0.0033700643107295036, alpha_t:1.1357333278283477e-06\n",
"t:34, sigma_t:0.0032550885807722807, alpha_t:1.0595601906970842e-06\n",
"t:33, sigma_t:0.003144035581499338, alpha_t:9.884960263661924e-07\n",
"t:32, sigma_t:0.0030367712024599314, alpha_t:9.221980121765228e-07\n",
"t:31, sigma_t:0.0029331662226468325, alpha_t:8.603464607404021e-07\n",
"t:30, sigma_t:0.0028330960776656866, alpha_t:8.026433420127432e-07\n",
"t:29, sigma_t:0.0027364399284124374, alpha_t:7.488103506148036e-07\n",
"t:28, sigma_t:0.0026430815923959017, alpha_t:6.985880531829025e-07\n",
"t:27, sigma_t:0.002552908146753907, alpha_t:6.517340125355986e-07\n",
"t:26, sigma_t:0.0024658110924065113, alpha_t:6.080224466131767e-07\n",
"t:25, sigma_t:0.0023816856555640697, alpha_t:5.672426937053388e-07\n",
"t:24, sigma_t:0.002300430089235306, alpha_t:5.291978482091508e-07\n",
"t:23, sigma_t:0.0022219468373805285, alpha_t:4.937048174724623e-07\n",
"t:22, sigma_t:0.0021461411379277706, alpha_t:4.605921901656984e-07\n",
"t:21, sigma_t:0.00207292172126472, alpha_t:4.297004636555357e-07\n",
"t:20, sigma_t:0.0020022003445774317, alpha_t:4.008806229194306e-07\n",
"t:19, sigma_t:0.0019338917918503284, alpha_t:3.739937426416873e-07\n",
"t:18, sigma_t:0.0018679136410355568, alpha_t:3.489101345621748e-07\n",
"t:17, sigma_t:0.001804186380468309, alpha_t:3.255088643072668e-07\n",
"t:16, sigma_t:0.0017426334088668227, alpha_t:3.0367712611223396e-07\n",
"t:15, sigma_t:0.0016831803368404508, alpha_t:2.833096175436367e-07\n",
"t:14, sigma_t:0.0016257556853815913, alpha_t:2.643081700171024e-07\n",
"t:13, sigma_t:0.0015702900709584355, alpha_t:2.465810950980085e-07\n",
"t:12, sigma_t:0.0015167169040068984, alpha_t:2.300430281820809e-07\n",
"t:11, sigma_t:0.0014649713411927223, alpha_t:2.1461410426582006e-07\n",
"t:10, sigma_t:0.0014149913331493735, alpha_t:2.002200574224844e-07\n",
"t:9, sigma_t:0.001366716343909502, alpha_t:1.867913539399524e-07\n",
"t:8, sigma_t:0.001320088398642838, alpha_t:1.742633344292699e-07\n",
"t:7, sigma_t:0.001275051268748939, alpha_t:1.6257557433618786e-07\n",
"t:6, sigma_t:0.001231550588272512, alpha_t:1.5167169920005108e-07\n",
"t:5, sigma_t:0.0011895340867340565, alpha_t:1.4149912885841331e-07\n",
"t:4, sigma_t:0.001148951007053256, alpha_t:1.3200885007336183e-07\n",
"t:3, sigma_t:0.0011097524547949433, alpha_t:1.2315506126014952e-07\n",
"t:2, sigma_t:0.0010718912817537785, alpha_t:1.1489509432749401e-07\n",
"t:1, sigma_t:0.0010353218531236053, alpha_t:1.0718913756591064e-07\n",
"t:0, sigma_t:0.0010000000474974513, alpha_t:1.0000001537946446e-07\n"
]
}
],
"source": [
"# Draw 100000 points with `score_model` and convert them to a NumPy array.\n",
"samples_pred = sbm_sample(\n",
"    n_samples=100000, score_model=score_model, sigmas=sigmas\n",
").cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "dtTR7L35w4DK",
"outputId": "23907bcb-80c4-4009-c3e3-1dfdb2621b87"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# 2-D histogram of the generated points, on the same [-2, 2] x [-2, 2] window\n",
"# and binning as the training-data plot.\n",
"fig, ax = plt.subplots()\n",
"ax.hist2d(\n",
"    samples_pred[:, 0],\n",
"    samples_pred[:, 1],\n",
"    bins=100,\n",
"    range=((-2, 2), (-2, 2)),\n",
"    density=True,\n",
"    cmap='viridis',\n",
"    rasterized=False,\n",
")\n",
"ax.set_aspect('equal', adjustable='box')\n",
"ax.set_xlim([-2, 2])\n",
"ax.set_ylim([-2, 2])\n",
"ax.set_title('Predicted Sample Density')\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment