This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# batch, time, bands, height, width
B, T, C, H, W = x.shape
C_o = 32  # hyperparam


def F(i, o, g=1):
    """Build a 3x3 ``nn.Conv2d`` mapping ``i`` input channels to ``o`` outputs.

    Args:
        i: Number of input channels.
        o: Number of output channels.
        g: Number of convolution groups (default 1, i.e. a dense convolution).

    Returns:
        A freshly constructed ``nn.Conv2d`` with ``kernel_size=3`` and
        ``padding='same'``, so its output keeps the input's height and width.
    """
    # A named function instead of a multi-line lambda: the lambda body on the
    # line after ``lambda ...:`` was a syntax error, and named defs give a
    # proper __name__/docstring (PEP 8 E731).
    return nn.Conv2d(
        in_channels=i,
        out_channels=o,
        kernel_size=3,
        padding='same',
        groups=g,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Freddie Kalaitzis | |
# License: MIT | |
# Source: https://gist.github.com/alkalait/c99213c164df691b5e37cd96d5ab9ab2 | |
import functools | |
import itertools | |
import os | |
import re | |
import warnings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import Tensor | |
def is_broadcastable(x: Tensor, y: Tensor) -> bool:
    """Return True if the shapes of ``x`` and ``y`` are compatible for broadcasting.

    Applies the NumPy/PyTorch broadcasting rule. Dimensions are compared
    pairwise starting from the trailing end, and each pair must be equal or
    contain a 1. A tensor with fewer dimensions is treated as if its shape were
    padded with leading 1s, so a difference in ``ndim`` alone does not rule
    out broadcasting.

    Args:
        x: First tensor. Only its ``shape`` is inspected.
        y: Second tensor. Only its ``shape`` is inspected.

    Returns:
        True if ``x`` and ``y`` can be broadcast together, else False.
    """
    # zip() stops at the shorter shape. The unmatched leading dims of the
    # longer shape pair with implicit 1s, which are always compatible.
    return all(
        a == b or a == 1 or b == 1
        for a, b in zip(reversed(x.shape), reversed(y.shape))
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Freddie Kalaitzis | |
# License: MIT | |
# Source: https://gist.github.com/alkalait/1497032fb601997efd9be4b90dddc63b | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import seaborn as sns | |
import torch | |
import xarray as xr | |
# Set seaborn's global plot style to 'white' for every figure drawn afterwards.
sns.set_style(style='white')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy | |
def shape(name: str) -> None:
    """Print ``<name> : <shape>`` for an array-like expression from the caller's scope.

    ``name`` is evaluated as a Python expression in the *caller's* globals and
    locals. It can therefore name a local variable of the calling function, or
    an attribute/index expression such as ``'model.weight'`` or ``'x[0]'``.

    Example::

        yaarr = torch.ones(25, 1, 1, 5, 5)  # or np.ones
        shape('yaarr')
        >> yaarr : (25, 1, 1, 5, 5)

    Args:
        name: Source text of an expression whose value has a ``.shape``
            attribute (e.g. a ``torch.Tensor`` or ``numpy.ndarray``).

    Raises:
        NameError: If ``name`` refers to something not visible to the caller.
        AttributeError: If the evaluated value has no ``.shape`` attribute.

    Note:
        ``name`` is passed to ``eval``. Only call this with trusted,
        developer-written strings, never with external input.
    """
    import inspect  # local import: only needed to reach the caller's frame

    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        if caller is None:
            # Frame introspection is unavailable (non-CPython), so evaluate in
            # this function's own scope instead.
            val = eval(name)
        else:
            val = eval(name, caller.f_globals, caller.f_locals)
    finally:
        # Release frame references promptly to avoid reference cycles.
        del frame, caller
    # tuple() so a torch.Size prints as (25, ...) rather than torch.Size([...]).
    print(f'{name} : {tuple(val.shape)}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch # 1.3.1 | |
import torch.nn as nn | |
import torchvision # 0.4.2 | |
# Example backbone: a ResNet-50 built without pretrained weights.
resnet = torchvision.models.resnet50(pretrained=False)
# Chain every child module except the last one (its FC layer) into a new model.
# NOTE(review): nn.Sequential runs only the child modules, so any functional
# step the original forward() performs between them (presumably a flatten
# before the FC layer) is not reproduced. Confirm the output shape before
# attaching a new head.
resnet_noFC = nn.Sequential(*tuple(resnet.children())[:-1])
# Suppose you want to use it for binary classification.