This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is taken from https://github.com/yunjey/pytorch-tutorial with just a few changes. | |
# Please see there for copyright and license information and use that copy. | |
import torch | |
import torch.nn as nn | |
import torchvision.datasets as dsets | |
import torchvision.transforms as transforms | |
from torch.autograd import Variable | |
import gc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import torch | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import gc | |
# helper function to get rss size, see stat(5) under statm. This is in pages (4k on my linux) | |
def memory_usage(): | |
return int(open('/proc/self/statm').read().split()[1]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.utils.data | |
from torchvision import datasets, transforms | |
class PartialDataset(torch.utils.data.Dataset): | |
def __init__(self, parent_ds, offset, length): | |
self.parent_ds = parent_ds | |
self.offset = offset | |
self.length = length | |
assert len(parent_ds)>=offset+length, Exception("Parent Dataset not long enough") | |
super(PartialDataset, self).__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from timeit import default_timer as time | |
import numpy as np | |
from numba import cuda | |
import os | |
os.environ['NUMBAPRO_LIBDEVICE']='/usr/lib/nvidia-cuda-toolkit/libdevice/' | |
os.environ['NUMBAPRO_NVVM']='/usr/lib/x86_64-linux-gnu/libnvvm.so.3.1.0' | |
import numpy | |
import torch | |
import ctypes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.autograd import Variable | |
def linear_with_sumsq(inp, weight, bias=None): | |
def provide_sumsq(inp,w,b): | |
def _h(i): | |
if not hasattr(w, 'grad_sumsq'): | |
w.grad_sumsq = 0 | |
w.grad_sumsq += ((i**2).t().matmul(inp**2))*i.size(0) | |
if b is not None: |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
indices = torch.LongTensor([[1,1,1], | |
[2,1,1]]) # must be two dimensional with one row per dimension | |
values = torch.arange(1,4) | |
size = torch.Size((3,3)) | |
a = torch.sparse.FloatTensor(indices, values, size) | |
b = torch.eye(3) | |
b += a | |
print (b) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from PIL import Image | |
from torch.utils.data import DataLoader | |
import torchvision | |
from torchvision import transforms, datasets | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import torch.optim | |
import torch.backends.cudnn as cudnn; cudnn.benchmark = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload | |
import builtins | |
import math | |
import pickle | |
class dtype: ... | |
_dtype = dtype |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
csrc = """ | |
#include <torch/extension.h> | |
#include <THC/THCDeviceUtils.cuh> | |
#include <THC/THCGeneral.h> | |
#include "ATen/ATen.h" | |
#include "ATen/AccumulateType.h" | |
#include "ATen/cuda/CUDAContext.h" | |
using namespace at; | |
OlderNewer