This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def sinusoidal(positions, features=16, periods=10000): | |
"""Encode `positions` using sinusoidal positional encoding | |
Args: | |
positions: tensor of positions | |
features: half the number of features per position | |
periods: used frequencies for the sinusoidal functions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
__all__ = ['softmax_mask'] | |
class SoftmaxMask(torch.autograd.Function): | |
"""Differentiable mask for logits before a softmax operation""" | |
@staticmethod | |
def forward(ctx, *args, **kwargs): | |
inputs, = args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def randg(*args, like=None, **kwargs): | |
"""Sample from Gumbel(location=0, scale=1)""" | |
generator = kwargs.pop('generator', None) | |
requires_grad = kwargs.pop('requires_grad', False) | |
if like is None: | |
samples = torch.empty(*args, **kwargs) | |
else: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tempfile | |
import urllib.request | |
import importlib.util | |
from pathlib import Path | |
def import_from_url(url): | |
"""Import a module from a given URL""" | |
with tempfile.TemporaryDirectory() as path: | |
path = Path(path) / Path(url).name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gc | |
import math | |
import time | |
import datetime | |
from contextlib import contextmanager | |
import torch | |
class Monitor: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ssl | |
import urllib | |
from pathlib import Path | |
import torch | |
from torch.utils.data import Dataset | |
from torchvision.datasets.utils import extract_archive, check_integrity | |
import h5py | |
import pandas as pd |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
def chunk_dim(tensor, chunks, dim=0):
    """Split a dimension of a tensor into two dimensions.

    The dimension at index `dim`, of size ``n``, is replaced by two adjacent
    dimensions of sizes ``chunks`` and ``n // chunks`` (in that order).

    Args:
        tensor: object exposing `.shape` and `.view(shape)`
        chunks: size of the new leading dimension inserted at `dim`
        dim: index of the dimension to split

    Returns:
        The result of `tensor.view` with the new shape; no extra check is
        made here that `chunks` evenly divides the split dimension, so any
        mismatch is reported by `view` itself.
    """
    shape = list(tensor.shape)
    # Shrink the split dimension first, then insert the chunk count in
    # front of it, so both new dimensions end up adjacent at `dim`.
    shape[dim] //= chunks
    shape.insert(dim, chunks)
    return tensor.view(shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""YOLOv3 object detector.""" | |
import math | |
from pathlib import Path | |
from urllib.request import urlopen | |
from PIL import Image | |
from PIL import ImageColor, ImageOps | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
from typing import Tuple, Optional | |
from contextlib import contextmanager | |
import torch | |
from torch.utils import benchmark | |
# @torch.jit.script | |
def nearest_neighbors( |