

@bnsh
Created March 13, 2021 13:06
I'm trying to compute the mean value over a batch of masked values, and I'd like to do it without loops and dictionaries, as I've done below.
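One loop-free way to get the per-cell sums and means is to flatten each (row, column) coordinate into a single cell index and let `torch.bincount` do the accumulation. This is a sketch, not the author's method. It assumes the same demo data as the script (values 2, 3, 5, 7 landing at (0, 0), (0, 0), (1, 1), (1, 0) of a 2x2 output). The names `rows`, `cols`, and `flat` are illustrative.

```python
import torch

# Output coordinates of the masked entries (the batch index is not needed for a
# per-cell reduction); the values stand in for per-entry losses.
rows = torch.tensor([0, 0, 1, 1])
cols = torch.tensor([0, 0, 1, 0])
values = torch.tensor([2.0, 3.0, 5.0, 7.0])
height, width = 2, 2

# Flatten each (row, column) pair into a single cell index: row * width + column.
flat = rows * width + cols

# With weights, bincount sums the values per cell; without weights, it counts hits.
summed = torch.bincount(flat, weights=values, minlength=height * width).view(height, width)
counts = torch.bincount(flat, minlength=height * width).view(height, width)

# Divide only where something landed; cells with no hits stay 0.
mean = summed / counts.clamp(min=1)

print(summed)
print(mean)
```

`minlength` guarantees the result has one bin per cell even when the highest-numbered cells receive no hits, so the `view` back to 2x2 always succeeds.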
#! /usr/bin/env python3
# vim: expandtab shiftwidth=4 tabstop=4
"""
My application outputs a two-dimensional quantity, but it is batched. I generate masks
for the outputs that I want the model to predict in the absence of that data.

The masking works: I can do input[masks] = 0, I can read output[masks] to get only the
masked outputs, and I can compute
    loss = crit(output[masks], target[masks])
to get the loss for only the values in the mask.

My issue is that I want to sum over the masks. In the demo below, the batch size is 4 and
    for batch 0, I'm setting (0, 0) = 2
    for batch 1, I'm setting (0, 0) = 3
    for batch 2, I'm setting (1, 1) = 5
    for batch 3, I'm setting (1, 0) = 7
(Pretend those values are losses.)

I would like to get
    summed_losses = torch.FloatTensor([
        [5, 0],
        [7, 5]
    ])

Is there a way to accomplish this? (What I really want is the mean, which is what the code
below computes, but I'd like a much less horrendous version, if possible.)
"""
import torch


def main():
    # Masked entries as index tensors: (batch index, row, column).
    masks = (
        torch.LongTensor((0, 1, 2, 3)),
        torch.LongTensor((0, 0, 1, 1)),
        torch.LongTensor((0, 0, 1, 0)),
    )
    # Drop the batch index and turn each (row, column) pair into a plain tuple
    # so it can index the 2x2 accumulators below.
    awful = [tuple(x) for x in torch.stack(masks, dim=1)[:, 1:].cpu().numpy().tolist()]
    values = torch.FloatTensor([2, 3, 5, 7])
    summed = torch.zeros((2, 2))
    denominator = torch.zeros((2, 2))
    for coord, value in zip(awful, values):
        summed[coord] += value
        denominator[coord] += 1
    print("-- Sum --")
    print(summed)
    print("\n-- Denominator --")
    print(denominator)
    print("\n-- Mean --")
    epsilon = 1e-5
    # Add epsilon only where the count is zero, so empty cells come out as 0
    # without perturbing cells whose values merely happen to sum to 0.
    print(summed / (denominator + (denominator == 0) * epsilon))


if __name__ == "__main__":
    main()
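The loop in the script can be replaced by `Tensor.index_put_` with `accumulate=True`, which adds values at repeated coordinates instead of overwriting them. This sketch keeps the script's index-tuple layout (`masks[0]` is the batch index, `masks[1:]` are the output coordinates); the function name `masked_mean` is hypothetical.

```python
import torch


def masked_mean(masks, values, shape, epsilon=1e-5):
    """Loop-free sum, count, and mean over index-tuple masks.

    Assumes masks[0] holds batch indices and masks[1:] hold the output
    coordinates, matching the layout used in the script.
    """
    coords = masks[1:]  # the batch index is not needed for a per-cell reduction
    summed = torch.zeros(shape, dtype=values.dtype, device=values.device)
    denominator = torch.zeros(shape, dtype=values.dtype, device=values.device)
    # accumulate=True adds values at repeated coordinates instead of overwriting them.
    summed.index_put_(coords, values, accumulate=True)
    denominator.index_put_(coords, torch.ones_like(values), accumulate=True)
    # Add epsilon only where nothing landed, so those cells come out as 0.
    mean = summed / (denominator + (denominator == 0) * epsilon)
    return summed, denominator, mean


masks = (
    torch.LongTensor((0, 1, 2, 3)),
    torch.LongTensor((0, 0, 1, 1)),
    torch.LongTensor((0, 0, 1, 0)),
)
values = torch.FloatTensor([2, 3, 5, 7])
summed, denominator, mean = masked_mean(masks, values, (2, 2))
print(summed)
print(mean)
```

Because `masks[1:]` is already a tuple of index tensors, it can be passed straight to `index_put_` with no conversion to Python tuples or NumPy, so everything stays on the tensors' original device.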
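If the mask is stored as a dense boolean tensor shaped like the output (batch, rows, columns), rather than as the index tuple the script uses, the per-cell mean of the losses is just a masked sum over the batch dimension divided by a masked count. This is a sketch under that assumption: the criterion is unspecified in the question, so `F.l1_loss` with `reduction="none"` stands in for `crit`, and `masked_batch_mean` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def masked_batch_mean(output, target, mask):
    """Per-cell mean of an elementwise loss over the batch, masked cells only.

    Assumes `mask` is a dense boolean tensor shaped like `output`
    (batch, rows, columns). L1 loss stands in for the unspecified `crit`.
    """
    # reduction="none" keeps one loss value per element instead of a scalar.
    loss = F.l1_loss(output, target, reduction="none")
    # torch.where (rather than multiplying by the mask) avoids 0 * inf = nan.
    masked = torch.where(mask, loss, torch.zeros_like(loss))
    summed = masked.sum(dim=0)           # sum over the batch dimension
    counts = mask.sum(dim=0)             # number of masked batch items per cell
    mean = summed / counts.clamp(min=1)  # cells with no hits stay 0
    return summed, mean


# Rebuild the demo: a dense mask from the (batch, row, column) indices, and a
# target whose L1 error against a zero output equals 2, 3, 5, 7 at those cells.
idx = (torch.tensor([0, 1, 2, 3]), torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 0]))
mask = torch.zeros(4, 2, 2, dtype=torch.bool)
mask[idx] = True
output = torch.zeros(4, 2, 2)
target = torch.zeros(4, 2, 2)
target[idx] = torch.tensor([2.0, 3.0, 5.0, 7.0])
summed, mean = masked_batch_mean(output, target, mask)
print(summed)
print(mean)
```

Keeping the loss unreduced until after masking means the same function works for any elementwise criterion that accepts `reduction="none"`.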