Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.optim.lr_scheduler import _LRScheduler | |
import numpy as np | |
class InterpolatingScheduler(_LRScheduler): | |
def __init__(self, optimizer, steps, lrs, scale='log', last_epoch=-1): | |
"""A scheduler that interpolates given values | |
Args: | |
- optimizer: pytorch optimizer | |
- steps: list or array with the x coordinates of the interpolated values |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.nn.modules.utils import _pair | |
class CausalConv2d(nn.Conv2d): | |
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1, bias=True): | |
kernel_size = _pair(kernel_size) | |
stride = _pair(stride) | |
dilation = _pair(dilation) | |
if padding is None: | |
padding = [int((kernel_size[i] -1) * dilation[i]) for i in range(len(kernel_size))] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Print a timestamp followed by the cuDNN / CUDA / NCCL versions PyTorch
# reports and the properties of every visible CUDA device.
import datetime

# UTC timestamp. `datetime.utcnow()` is deprecated since Python 3.12; a
# timezone-aware "now" renders identically with this format string.
print(datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%d_%H-%M-%S'))

import torch

# NOTE(review): the `torch._C._*` calls below are private bindings; they are
# presumably only present on CUDA/NCCL-enabled builds -- confirm before
# running this on a CPU-only install.
print(torch._C._cudnn_version(), 'cudnn')
print(torch._C._cuda_getDriverVersion(), 'cuda driver')
print(torch._C._cuda_getCompiledVersion(), 'cuda compiled version')
print(torch._C._nccl_version(), 'nccl')

# One line per CUDA device that torch can see.
for i in range(torch.cuda.device_count()):
    print('device %s:' % i, torch.cuda.get_device_properties(i))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data.sampler import Sampler | |
import itertools | |
class SequentialRandomSampler(Sampler): | |
"""Samples elements sequentially, starting from a random location. | |
For when you want to sequentially sample a random subset | |
Usage: | |
loader = torch.utils.data.DataLoader( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def isfinite(x):
    """
    Quick pytorch test that there are no nan's or infs.

    Returns the result of combining two comparisons with ``&``: true where
    ``x`` is neither NaN nor +/-infinity. For array-like inputs that support
    elementwise comparison the result is an elementwise mask; for a plain
    Python float it is a bool.

    note: torch now has torch.isnan
    url: https://gist.github.com/wassname/df8bc03e60f81ff081e1895aabe1f519
    """
    inf = float('inf')
    # Compare against infinity directly. The previous `(x + 1) != x` check
    # also rejected large finite values (e.g. 1e20), because adding 1 to them
    # is lost to floating-point rounding.
    not_inf = (x != inf) & (x != -inf)
    # NaN is the only value that does not compare equal to itself.
    not_nan = (x == x)
    return not_inf & not_nan
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# code for question on reddit https://www.reddit.com/r/MachineLearning/comments/8poc3z/r_blog_post_on_world_models_for_sonic/e0cwb5v/ | |
# from this | |
def forward(self, x): | |
self.lstm.flatten_parameters() | |
x = F.relu(self.fc1(x)) | |
z, self.hidden = self.lstm(x, self.hidden) | |
sequence = x.size()[1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Pytorch sampler that samples ordered indices from unordered sequences. | |
Good for use with dask and RNN's, because | |
1. Dask will slow down if sampling between chunks, so we must do one chunk at a time | |
2. RNN's need sequences so we must have sequences e.g. 1,2,3 | |
3. But RNN's train better with batches that are uncorrelated so we want each batch to be sequence from a different part of a chunk. | |
For example, given each chunk is `range(12)`. Our seq_len is 3. We might end up with these indices: | |
- [[1,2,3],[9,10,11],[4,5,6]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.utils.data | |
class NumpyDataset(torch.utils.data.Dataset): | |
"""Dataset wrapping arrays. | |
Each sample will be retrieved by indexing array along the first dimension. | |
Arguments: | |
*arrays (numpy.array): arrays that have the same size of the first dimension. |