This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sanity check + micro-benchmark: fancy indexing `X[idxs]` versus the
# equivalent index_select / view round-trip on a 1-D tensor.
# (Reconstructed from a scrape-corrupted snippet: stray table artifacts
# removed; code otherwise as in the original.)
import torch

if __name__ == '__main__':
    # 1-D source tensor and a 2-D index grid drawn from its valid range.
    X = torch.randn(100)
    out_shape = (100, 100)
    idxs = torch.randint(high=100, size=out_shape).long()
    # Both paths gather the same elements from X, so the results should
    # agree; the original used a loose float tolerance for the comparison.
    assert torch.abs(X[idxs] - X.index_select(0, idxs.view(-1)).view(*out_shape)).max() < 1e-3
    # Local import is deliberate: only needed for the benchmark branch.
    from timeit import timeit
    # NOTE: timeit runs in a fresh namespace, so the setup string rebuilds
    # the tensors; keep it in sync with the construction above.
    setup = 'import torch; X = torch.randn(100); out_shape=(100,100); idxs = torch.randint(high=100, size=out_shape).long()'
    print("X[idxs]: ", timeit("_ = X[idxs]", setup=setup, number=100))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# implementation of https://arxiv.org/abs/1504.04788 | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import xxhash | |
class HashFunction(object): | |
"""Hash function as described in the paper, maps a key (i,j) to a natural number | |
in {1,...,K_L}""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
import numpy as np | |
from scipy import optimize | |
from obj import PyTorchObjective |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from functools import reduce | |
import math | |
def is_power(n): | |
# https://stackoverflow.com/a/29480710/6937913 | |
n = n/2 | |
if n == 2: | |
return True | |
elif n > 2: |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from io import StringIO | |
def write_bitstream(fname, bits): | |
# bits are a string of ones and zeros, based on this | |
# stackoverflow answer: https://stackoverflow.com/a/16888829/6938913 | |
# was broken due to utf-8 encoding using up to 4 bytes: https://stackoverflow.com/a/33349765/6937913 | |
sio = StringIO(bits) | |
with open(fname, 'wb') as f: | |
while 1: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tqdm import tqdm | |
class TrainingProgress(tqdm): | |
"""Make the progress bar store progress as it displays it when the log method is called.""" | |
def log(self, *args, **kwargs): | |
if 'trace' not in self.__dict__.keys(): | |
self.trace = {} | |
# store anything with an integer or float datatype in the trace dictionary | |
# with global index as the key | |
for k in kwargs: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import numpy as np | |
from scipy.fftpack import dct | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim |