Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active August 5, 2023 23:59
Show Gist options
  • Save vadimkantorov/de45ac8a37ccd9a9007d960b0f04ab14 to your computer and use it in GitHub Desktop.
Save vadimkantorov/de45ac8a37ccd9a9007d960b0f04ab14 to your computer and use it in GitHub Desktop.
Approximate a histogram compactly as `hist * mult` (uint8 bins times a per-row integer multiplier) in PyTorch
import torch
def dist(histA, histB):
    """Return the min-sum (histogram intersection) similarity of two histogram batches.

    Each histogram (bins along the last dimension) is normalized to sum to 1.
    The bin-wise minimum of the two normalized histograms is then summed.
    For non-negative inputs the result lies in [0, 1]: 1 for identical
    normalized histograms, 0 for histograms with disjoint support. Despite the
    name, larger values mean *more* similar.

    This min-sum histogram measure is the one used in Selective Search
    (Uijlings et al., IJCV 2013):
    https://ivi.fnwi.uva.nl/isis/publications/2013/UijlingsIJCV2013

    Every row is rescaled by its own sum, so a multiplier that is constant
    along the last dimension cancels out. The uint8 histogram returned by
    ``merge`` can therefore be compared directly, without its multiplier.

    Args:
        histA, histB: non-negative histograms with bins on the last
            dimension; leading dimensions must broadcast against each other.

    Returns:
        Floating-point tensor with the last dimension reduced away. An
        all-zero histogram is treated as empty and yields 0 instead of NaN.
    """
    def _normalize(hist):
        # Row sums of integer tensors are integer; true division then gives floats.
        total = hist.sum(dim = -1, keepdim = True)
        # Replace zero sums with 1 so empty histograms stay all-zero instead of 0/0 = NaN.
        return hist / total.masked_fill(total == 0, 1)

    return torch.min(_normalize(histA), _normalize(histB)).sum(dim = -1)
def merge(histA, multA, histB, multB):
    """Merge two approximate histograms, each represented as ``hist * mult``.

    The exact merged histogram ``histA * multA + histB * multB`` is computed
    in integer arithmetic. It is then re-quantized into uint8 bins plus a new
    multiplier, one per row (the maximum is taken over the last dimension).

    Args:
        histA, histB: integer histograms (uint8 in the demo) with bins on the
            last dimension.
        multA, multB: integer multipliers broadcastable against the
            histograms, e.g. shape ``(rows, 1)`` int32 tensors or plain ints.

    Returns:
        ``(hist_uint8, mult)``. ``mult`` is the smallest per-row multiplier
        (at least 1) for which the row maximum fits in 0..255. ``hist_uint8`` is
        ``floor(merged / mult)``, so ``hist_uint8 * mult`` underestimates each
        merged bin by less than ``mult``.
    """
    # Promote before scaling: with a plain Python int multiplier, `uint8 * int`
    # stays uint8 and the sum would silently wrap around at 256.
    hist_int32 = histA.to(torch.int32) * multA + histB.to(torch.int32) * multB
    # Round the multiplier UP: ceil(max / 255) is the smallest value that brings
    # the row maximum into 0..255. The previous floor(max / 255) + 1 overshot
    # whenever max was a multiple of 255 (max == 255 gave 2 instead of 1), and
    # the top bin could never reach 255. clamp_min(1) keeps all-zero rows from
    # dividing by zero.
    row_max = hist_int32.amax(dim = -1, keepdim = True)
    mult_int32 = row_max.add(254).div(255, rounding_mode = 'floor').clamp_min(1)
    # Values are now <= 255, so the cast to uint8 cannot wrap.
    hist_uint8 = hist_int32.div(mult_int32, rounding_mode = 'floor').to(torch.uint8)
    return hist_uint8, mult_int32
if __name__ == '__main__':
    # Demo: two random batches of 4 histograms with 192 uint8 bins, each with multiplier 1.
    num_rows, num_bins = 4, 192
    # Draw histA first, then histB (torch.ones uses no RNG, so the draw order matches).
    histA, histB = (torch.randint(0, 256, size = (num_rows, num_bins), dtype = torch.uint8) for _ in range(2))
    multA, multB = (torch.ones(num_rows, 1, dtype = torch.int32) for _ in range(2))
    similarity = dist(histA, histB)
    print(similarity)
    merged = merge(histA, multA, histB, multB)
    print(merged)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment