Skip to content

Instantly share code, notes, and snippets.

@edraizen
Last active March 13, 2020 19:46
Show Gist options
  • Save edraizen/8fd362e759132e154c4f54efe1709aae to your computer and use it in GitHub Desktop.
Save edraizen/8fd362e759132e154c4f54efe1709aae to your computer and use it in GitHub Desktop.
import torch
import numpy as np
import sparseconvnet as scn
from more_itertools import pairwise
from torch.nn.parallel._functions import Scatter
class Batch(object):
    """A sparse mini-batch that can split itself across several GPUs.

    Attributes:
        data: Two-element list ``[indices, features]``.  ``indices`` has
            ``dim + 1`` columns; its last column is the sample number of each
            row.  ``features`` has one row per row of ``indices``.
        truth: Target tensor handed to the constructor.
        chunk_sizes: Number of rows sent to each device by the most recent
            ``scatter`` call, or whatever was passed to the constructor.
        dim: Number of coordinate columns in ``indices``.
    """

    def __init__(self, indices, data, truth, chunk_sizes=None, dim=3):
        """Validate and store the tensors making up the batch.

        Args:
            indices: Tensor with ``dim + 1`` columns, sample number last.
            data: Feature tensor with as many rows as ``indices``.
            truth: Target tensor.
            chunk_sizes: Optional per-device row counts.
            dim: Number of coordinate columns in ``indices``.

        Raises:
            AssertionError: If an argument is not a tensor or the shapes of
                ``indices`` and ``data`` are inconsistent.
        """
        assert isinstance(indices, torch.Tensor), "indices must be tensor"
        assert isinstance(data, torch.Tensor), "data must be tensor"
        assert isinstance(truth, torch.Tensor), "truth must be tensor"
        assert indices.size()[0] == data.size()[0], "indices and data must have same length"
        assert indices.size()[1] == dim + 1, "indices must have sample number in last col"
        self.data = [indices, data]
        self.truth = truth
        self.chunk_sizes = chunk_sizes
        # Remembered so that scatter() can build child batches with the same
        # coordinate dimensionality instead of confusing it with the scatter axis.
        self.dim = dim

    @staticmethod
    def _partition(sample_ids, n_chunks):
        """Split consecutive samples into at most ``n_chunks`` balanced groups.

        Each run of equal values in ``sample_ids`` is treated as one sample,
        so the rows of a sample must be stored contiguously.

        Args:
            sample_ids: 1-D tensor holding the sample number of every row.
            n_chunks: Maximum number of groups to create (must be >= 1).

        Returns:
            A tuple ``(sample_bounds, point_bounds, sample_rank)``:
            ``sample_bounds[k]:sample_bounds[k + 1]`` is the range of samples
            in group ``k``, ``point_bounds[k]:point_bounds[k + 1]`` is the
            matching row range, and ``sample_rank`` gives, for every row, the
            0-based position of its sample within the whole batch.

        Raises:
            ValueError: If ``n_chunks`` is smaller than 1.
        """
        if n_chunks < 1:
            raise ValueError("at least one target device is required")
        if sample_ids.numel() == 0:
            return [0], [0], torch.zeros(0, dtype=torch.long)

        # unique_consecutive reports counts in row order, which is what the
        # contiguous slicing below relies on.
        _, sample_rank, pts_per_sample = torch.unique_consecutive(
            sample_ids, return_inverse=True, return_counts=True)
        n_samples = pts_per_sample.numel()

        # Never create more groups than samples; spread the remainder over
        # the first groups so that sizes differ by at most one sample.
        n_chunks = min(n_chunks, n_samples)
        base, extra = divmod(n_samples, n_chunks)
        sample_bounds = [0]
        for k in range(n_chunks):
            sample_bounds.append(sample_bounds[-1] + base + (1 if k < extra else 0))

        point_offsets = [0] + pts_per_sample.cumsum(0).tolist()
        point_bounds = [point_offsets[b] for b in sample_bounds]
        return sample_bounds, point_bounds, sample_rank

    def scatter(self, target_gpus, dim=0):
        """Split the batch into one ``Batch`` per device, keeping samples whole.

        Features are moved with ``Scatter.apply``; indices and truth are
        sliced and stay on their current device.  In every returned chunk the
        sample column is renumbered to start at 0.  ``self.chunk_sizes`` is
        updated with the number of rows sent to each device.

        Args:
            target_gpus: Devices to distribute to.  If there are fewer
                samples than devices only the first devices are used.
            dim: Axis of the feature tensor to split along.  Chunk sizes are
                row counts, so this is expected to be 0 (the row axis).

        Returns:
            List of ``Batch`` objects, one per device actually used (empty if
            the batch has no rows).

        Raises:
            ValueError: If ``target_gpus`` is empty.
        """
        indices, features = self.data
        sample_bounds, point_bounds, sample_rank = self._partition(
            indices[:, -1], len(target_gpus))
        n_chunks = len(sample_bounds) - 1
        if n_chunks == 0:
            self.chunk_sizes = torch.zeros(0, dtype=torch.long)
            return []

        # Scatter needs plain integer sizes; the tensor copy is kept on the
        # instance for callers that read ``chunk_sizes``.
        sizes = [stop - start
                 for start, stop in zip(point_bounds, point_bounds[1:])]
        self.chunk_sizes = torch.tensor(sizes, dtype=torch.long)
        feature_chunks = Scatter.apply(
            list(target_gpus)[:n_chunks], sizes, dim, features)

        # NOTE(review): truth is sliced per row like the features unless its
        # length matches the number of samples instead, in which case it is
        # sliced per sample -- confirm which layout callers use.
        n_points = indices.size(0)
        n_samples = sample_bounds[-1]
        truth_per_sample = (self.truth.dim() > 0
                            and self.truth.size(0) == n_samples
                            and n_samples != n_points)

        chunks = []
        for k, feats in enumerate(feature_chunks):
            p_start, p_stop = point_bounds[k], point_bounds[k + 1]
            s_start, s_stop = sample_bounds[k], sample_bounds[k + 1]
            local = indices[p_start:p_stop].clone()
            # Renumber samples so each chunk counts from 0.
            local[:, -1] = sample_rank[p_start:p_stop] - s_start
            if truth_per_sample:
                truth = self.truth[s_start:s_stop]
            else:
                truth = self.truth[p_start:p_stop]
            chunks.append(Batch(local, feats, truth, self.chunk_sizes, self.dim))
        return chunks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment