import torch
import torchio as tio
# Load the diffusion MRI image from torchio's FPG example dataset and
# resample it with tio.Resample() using its default arguments.
image4d = tio.Resample()(tio.datasets.FPG().dmri)
# Wrap the single image in a DataLoader so it is collated into a batch.
# NOTE(review): DataLoader's default batch_size is 1, so the batch should
# contain only this image — confirm if the loader arguments change.
collater = torch.utils.data.DataLoader([image4d])
batch = next(iter(collater))
# Pull the image data out of the collated batch and cast it to float
# before applying dropout.
tensor = batch[tio.DATA].float()
# Apply Dropout3d with its default settings.
# NOTE(review): presumably `tensor` is (batch, channels, x, y, z), torchio's
# batched layout, so Dropout3d zeroes whole channels — TODO confirm shape.
dropped = torch.nn.Dropout3d()(tensor)
# Print the sum of every channel of the first sample in the batch; a zero
# sum shows that the entire channel was zeroed by dropout.
for i, channel in enumerate(dropped[0]):
    print(i, channel.sum())
The output was (channel index, then that channel's sum after dropout):
0 tensor(1.5286e+09)
1 tensor(0.)
2 tensor(0.)
3 tensor(1.5318e+09)
4 tensor(1.0110e+09)
5 tensor(1.0110e+09)
6 tensor(1.0075e+09)
7 tensor(0.)
8 tensor(0.)
9 tensor(0.)
10 tensor(0.)
11 tensor(1.0131e+09)
12 tensor(1.0117e+09)
13 tensor(0.)
14 tensor(0.)
15 tensor(0.)
16 tensor(1.0101e+09)
17 tensor(0.)
18 tensor(1.0141e+09)
19 tensor(0.)