This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.utils.data | |
class NumpyDataset(torch.utils.data.Dataset): | |
"""Dataset wrapping arrays. | |
Each sample will be retrieved by indexing array along the first dimension. | |
Arguments: | |
*arrays (numpy.array): arrays that have the same size of the first dimension. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dask.callbacks import Callback | |
from tqdm.auto import tqdm | |
class TQDMDaskProgressBar(Callback, object): | |
""" | |
A tqdm progress bar for dask. | |
Usage: | |
``` | |
with TQDMDaskProgressBar(): |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
In jupyter notebook simple logging to console | |
""" | |
import logging | |
import sys | |
logging.basicConfig(stream=sys.stdout, level=logging.INFO) | |
# Test | |
logger = logging.getLogger('LOGGER_NAME') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AdamStepLR(torch.optim.Adam):
    """Adam optimiser with a built-in ``lr_scheduler.StepLR`` schedule.

    Bundles the optimiser and its scheduler so training code can treat the
    pair as a normal optimiser: every call to :meth:`step` applies one Adam
    update and then advances the learning-rate schedule by one step.

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimise.
        lr: initial learning rate, forwarded to ``torch.optim.Adam``.
        betas: Adam running-average coefficients, forwarded unchanged.
        eps: Adam denominator epsilon, forwarded unchanged.
        weight_decay: L2 penalty, forwarded unchanged.
        step_size: number of ``step()`` calls between learning-rate decays.
        gamma: multiplicative factor applied to the learning rate at each
            decay.

    Attributes:
        scheduler: the ``StepLR`` instance that drives this optimiser's
            learning rate.

    Note:
        This class does not override ``state_dict``; if the schedule position
        must survive checkpointing, save ``self.scheduler.state_dict()``
        alongside the optimiser state.
    """

    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-08,
                 weight_decay=0, step_size=50000, gamma=0.5):
        # Keyword forwarding keeps this correct even if the positional order
        # of the parent signatures ever changes.
        super().__init__(params, lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self, step_size=step_size, gamma=gamma
        )

    def step(self, closure=None):
        """Apply one Adam update, then advance the learning-rate schedule.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; passed straight through to ``Adam.step``.

        Returns:
            Whatever ``Adam.step`` returns (the closure's loss, or ``None``
            when no closure is given).
        """
        # The parameter update must come first: stepping the scheduler before
        # the optimiser shifts the whole schedule one update early, so only
        # ``step_size - 1`` updates would run at the initial learning rate.
        loss = super().step(closure)
        self.scheduler.step()
        return loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def window_stack(x, window=4, pad=True): | |
""" | |
Stack along a moving window of a pytorch timeseries | |
Inputs: | |
tensor of dims (batches/time, channels) | |
pad: if true the left side will be padded to let the output match | |
Outputs: | |
if pad=True: a tensor of size (batches, channels, window) | |
else: tensor of size (batches-window, channels, window) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
class LayerNormConv2d(nn.Module): | |
""" | |
Layer norm the just works on the channel axis for a Conv2d | |
Ref: | |
- code modified from https://github.com/Scitator/Run-Skeleton-Run/blob/master/common/modules/LayerNorm.py | |
- paper: https://arxiv.org/abs/1607.06450 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.autograd as autograd | |
from torch.nn.parameter import Parameter | |
from torch.autograd import Variable | |
def r_d_max_func(itr): | |
"Default max r and d provider as recommended in paper." |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!doctype html> | |
<html lang="en"> | |
<head> | |
<meta charset="utf-8"> | |
<title>Profile report</title> | |
<meta name="description" content="Profile report generated by pandas-profiling. See GitHub."> | |
<meta name="author" content="pandas-profiling"> | |
<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.11.3/jquery.min.js"></script> |