@keunhong
Created March 23, 2017 04:20
continuous_histogram
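The function below builds a soft histogram by placing a Gaussian kernel on every input value and evaluating it at evenly spaced bin centres. The formula here is only a reading aid. B is `bins` (32 by default), σ is the kernel width (0.01 in the code), and the indices n, c, p (sample, channel, position) are our own labels. They assume a 3-D input of shape (N, C, P), which is what the code's reshaping implies:

$$
b_k = \frac{k}{B-1}, \quad k = 0, \dots, B-1, \qquad
h_{n,c}[k] = \sum_{p=1}^{P} \frac{1}{\sigma\sqrt{2\pi}} \exp\!\left(-\frac{(x_{n,c,p} - b_k)^2}{2\sigma^2}\right)
$$

The result is then divided by a normalizing constant. Every term is a smooth function of $x_{n,c,p}$, so $\partial h_{n,c}[k] / \partial x_{n,c,p}$ is generally nonzero. This lets the histogram sit inside a loss trained by backpropagation, which hard bin counting does not allow.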
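import math

import torch


def normal(x, mu, sigma):
    # Gaussian density N(x; mu, sigma^2), evaluated elementwise.
    return torch.exp(-(x - mu) ** 2 / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))


def continuous_histogram(batch, bins=32, sigma=0.01):
    # Differentiable (soft) histogram of values in [0, 1].
    # batch: (N, C, P) tensor, e.g. N images with C channels of P pixels each.
    # Returns an (N, C, bins) tensor in which each histogram sums to 1.
    binvals = torch.linspace(0, 1, bins, device=batch.device, dtype=batch.dtype)
    # Broadcast (N, C, P, 1) against (bins,) to score every value against every bin at once.
    hist = normal(batch.unsqueeze(-1), binvals, sigma).sum(dim=2)  # (N, C, bins)
    # Normalize each histogram by its own total instead of by the fixed
    # 10000 * bins, which presumably assumed 10,000 values per channel.
    return hist / hist.sum(dim=-1, keepdim=True)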
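The following dependency-free sketch in plain Python (no PyTorch) shows why the Gaussian kernel matters. It compares a soft histogram with conventional nearest-bin counting. The helpers `gaussian`, `soft_histogram` and `hard_histogram` are hypothetical names used only for illustration. The bin centres and kernel width mirror the gist's `bins=32` and `0.01`.

```python
import math


def gaussian(x, mu, sigma):
    # Gaussian density N(x; mu, sigma^2), the same kernel as normal() in the gist.
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))


def soft_histogram(values, bins=32, sigma=0.01):
    # Bin centres k / (bins - 1), evenly spaced over [0, 1] like torch.linspace(0, 1, bins).
    centres = [k / (bins - 1) for k in range(bins)]
    # Every value adds Gaussian weight to every bin; there is no hard assignment.
    hist = [sum(gaussian(v, c, sigma) for v in values) for c in centres]
    total = sum(hist)
    # Normalize so the bins sum to 1.
    return [h / total for h in hist]


def hard_histogram(values, bins=32):
    # Conventional histogram: each value is counted in its nearest bin centre.
    hist = [0] * bins
    for v in values:
        hist[round(v * (bins - 1))] += 1
    return [h / len(values) for h in hist]


values = [0.10, 0.25, 0.25, 0.80]
eps = 1e-4
# Nudge the first value slightly and compare the histograms before and after.
nudged = [values[0] + eps] + values[1:]

soft, soft_nudged = soft_histogram(values), soft_histogram(nudged)
hard, hard_nudged = hard_histogram(values), hard_histogram(nudged)

print(max(abs(a - b) for a, b in zip(soft, soft_nudged)) > 0)  # True: soft bins respond
print(hard == hard_nudged)  # True: hard bins do not move at all
```

Nudging one value by `1e-4` moves mass between neighbouring soft bins, so a finite-difference (or autograd) gradient with respect to that value is nonzero. The hard histogram is unchanged, so its gradient with respect to the inputs is zero almost everywhere.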