Skip to content

Instantly share code, notes, and snippets.

@data-panda
Created April 29, 2021 18:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save data-panda/c8b64b4d7a2605d2a37c0a799a473f7d to your computer and use it in GitHub Desktop.
Save data-panda/c8b64b4d7a2605d2a37c0a799a473f7d to your computer and use it in GitHub Desktop.
Dither Algo with NVTX Markers
def dither(self, img, dither_mask, pixel_loc):
    """Binarise ``img`` with error-diffusion dithering, wrapped in NVTX ranges.

    Pixels are visited in raster order (rows, then columns). Each pixel is
    replaced by its sign, and the resulting quantization error
    (value before thresholding minus value after) is spread over the
    neighbouring pixels, weighted by ``dither_mask``.

    Args:
        img: 4-D tensor laid out as (mini-batch, channels, height, width).
            It is not modified; the work happens on a padded copy.
        dither_mask: 2-D tensor of diffusion weights. Row 0 is aligned with
            the row of the current pixel. It is not modified.
        pixel_loc: Column index inside ``dither_mask`` that is aligned with
            the current pixel.

    Returns:
        Tensor with the same shape as ``img`` holding the dithered values
        (results of ``sign()``).

    ``self`` is not used by this method.
    """
    nvtx.range_push("Dithering")
    # Same padding on every side; more than strictly needed, but simple.
    # NOTE(review): the diffusion window spans rows i..i+m-1 and columns
    # j-pixel_loc..j-pixel_loc+n-1, so this padding is presumably assumed to
    # be large enough for the mask in use -- confirm for large masks.
    pad = math.ceil(max(dither_mask.shape) / 2)
    out = Func.pad(img, (pad, pad, pad, pad), "replicate")
    # mini-batch size, channels, height, width (of the padded tensor)
    batch, channels, height, width = out.size()
    mask_rows, mask_cols = dither_mask.shape
    # Add a leading broadcast axis WITHOUT touching the caller's tensor
    # (an in-place unsqueeze_ would break any later call that reuses the mask).
    mask = dither_mask.unsqueeze(0)
    # Fold the mini-batch into the channel axis: (batch * channels, H, W).
    out = out.view(-1, height, width)

    nvtx.range_push("Outer for loop")
    for i in range(pad, height - pad):
        for j in range(pad, width - pad):
            # Keep the pre-threshold value; the view below is overwritten.
            before = out[:, i, j].clone()

            # Threshold step (sign is what the network expects).
            nvtx.range_push("Sign")
            out[:, i, j] = before.sign()
            nvtx.range_pop()

            # Quantization error per image plane, shaped (batch*channels, 1, 1)
            # so it broadcasts against the (1, m, n) mask.
            err = (before - out[:, i, j]).view(-1, 1, 1)

            # Diffuse the weighted error into the neighbourhood window.
            nvtx.range_push("Multiplication and addition")
            left = j - pixel_loc
            window = out[:, i:i + mask_rows, left:left + mask_cols]
            out[:, i:i + mask_rows, left:left + mask_cols] = torch.add(
                window, torch.mul(mask, err)
            )
            nvtx.range_pop()
    nvtx.range_pop()  # "Outer for loop"
    nvtx.range_pop()  # "Dithering"

    # Drop the padding and restore the (batch, channels, H, W) layout.
    return out[:, pad:height - pad, pad:width - pad].reshape(
        batch, channels, height - 2 * pad, width - 2 * pad
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment