This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Input, Subtract, Dense, Lambda | |
from keras.models import Model | |
import keras.backend as K | |
def build_siamese_network(encoder, input_shape): | |
input_1 = Input(input_shape) | |
input_2 = Input(input_shape) | |
# `encoder` is any predefined network that maps a single sample | |
# into an embedding space. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NShotTaskSampler(Sampler): | |
def __init__(self, | |
dataset: torch.utils.data.Dataset, | |
episodes_per_epoch: int = None, | |
n: int = None, | |
k: int = None, | |
q: int = None, | |
num_tasks: int = 1, | |
fixed_tasks: List[Iterable[int]] = None): | |
"""PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nn.utils import clip_grad_norm_ | |
def matching_net_episode(model: Module, | |
optimiser: Optimizer, | |
loss_fn: Loss, | |
x: torch.Tensor, | |
y: torch.Tensor, | |
n_shot: int, | |
k_way: int, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def proto_net_episode(model: Module, | |
optimiser: Optimizer, | |
loss_fn: Callable, | |
x: torch.Tensor, | |
y: torch.Tensor, | |
n_shot: int, | |
k_way: int, | |
q_queries: int, | |
distance: str, | |
train: bool): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
def replace_grad(parameter_gradients, parameter_name): | |
"""Creates a backward hook function that replaces the calculated gradient | |
with a precomputed value when .backward() is called. | |
See | |
https://pytorch.org/docs/stable/autograd.html?highlight=hook#torch.Tensor.register_hook | |
for more info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def projected_gradient_descent(model, x, y, loss_fn, num_steps, step_size, step_norm, eps, eps_norm, | |
clamp=(0,1), y_target=None): | |
"""Performs the projected gradient descent attack on a batch of images.""" | |
x_adv = x.clone().detach().requires_grad_(True).to(x.device) | |
targeted = y_target is not None | |
num_channels = x.shape[1] | |
for i in range(num_steps): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchvision import transforms, datasets | |
from torch import nn, optim | |
from torch.utils.data import DataLoader | |
import torch.nn.functional as F | |
class Net(nn.Module): | |
def __init__(self): | |
super(Net, self).__init__() | |
self.conv1 = nn.Conv2d(1, 32, kernel_size=5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
movement_filters = torch.Tensor([ | |
[ | |
[0, 1, 0], | |
[0, 0, 0], | |
[0, 0, 0], | |
], | |
[ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
bodies = torch.zeros((2, 1, 7, 7)) | |
heads = torch.zeros((2, 1, 7, 7)) | |
num_envs = bodies.size(0) | |
# Initialise body as shown in diagram | |
bodies[:, :, 3, 2] = 1 | |
bodies[:, :, 3, 3] = 2 | |
bodies[:, :, 2, 3] = 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os.path | |
import psutil | |
import pyarrow as pa | |
import numpy as np | |
from pyarrow import parquet as pq | |
import time | |
WINDOW_LENGTH = 1000 | |
N = 1000000 |
Older Newer