
@vadimkantorov
Last active July 23, 2024 14:50
Tensor-backed immutable string array and list-of-dicts to be used in PyTorch Dataset classes, to work around shared memory pages being copied (copy-on-read caused by refcount updates) when using Python lists of strings: https://github.com/pytorch/pytorch/issues/13246
import math
import typing
import torch

class StringArray:
    def __init__(self, strings : typing.List[str], encoding : typing.Literal['ascii', 'utf_16_le', 'utf_32_le'] = 'utf_16_le'):
        strings = list(strings)
        self.encoding = encoding
        self.multiplier = dict(ascii = 1, utf_16_le = 2, utf_32_le = 4)[encoding]
        self.data = torch.ByteTensor(torch.ByteStorage.from_buffer(''.join(strings).encode(encoding)))
        self.cumlen = torch.LongTensor(list(map(len, strings))).cumsum(dim = 0).mul_(self.multiplier)
        assert int(self.cumlen[-1]) == len(self.data), f'[{encoding}] is not enough to hold characters, use a larger character class'

    def tolist(self):
        data_bytes, cumlen = bytes(self.data), self.cumlen.tolist()
        return [data_bytes[0:cumlen[0]].decode(self.encoding)] + [data_bytes[start:end].decode(self.encoding) for start, end in zip(cumlen[:-1], cumlen[1:])]

    def __getitem__(self, i):
        return bytes(self.data[(self.cumlen[i - 1] if i >= 1 else 0) : self.cumlen[i]]).decode(self.encoding)

    def __len__(self):
        return len(self.cumlen)

class DictArray:
    def __init__(self, dicts : typing.List[dict], types : typing.Dict[str, typing.ClassVar] = {}, *, batch_size : int = 1024, string_encoding : typing.Literal['ascii', 'utf_16_le', 'utf_32_le'] = 'utf_16_le', ints_dtype = torch.int64):
        dicts = list(dicts)
        numel = len(dicts)
        assert numel > 0
        self.tensors = {k : t(numel) for k, t in types.items() if t != StringArray and t != IntsArray}
        string_lists = {k : [None] * numel for k, t in types.items() if t == StringArray}
        ints_lists = {k : [None] * numel for k, t in types.items() if t == IntsArray}
        temp_lists = {k : [None] * batch_size for k in self.tensors}
        for b in range(math.ceil(numel / batch_size)):
            for i, t in enumerate(dicts[b * batch_size : (b + 1) * batch_size]):
                for k in temp_lists:
                    temp_lists[k][i] = t[k]
                for k in string_lists:
                    string_lists[k][b * batch_size + i] = t[k]
                for k in ints_lists:
                    ints_lists[k][b * batch_size + i] = t[k]
            for k, v in temp_lists.items():
                res = self.tensors[k][b * batch_size : (b + 1) * batch_size]
                res.copy_(torch.as_tensor(v[:len(res)], dtype = self.tensors[k].dtype))
        self.string_arrays = {k : StringArray(v, encoding = string_encoding) for k, v in string_lists.items()}
        self.ints_arrays = {k : IntsArray(v, dtype = ints_dtype) for k, v in ints_lists.items()}

    def __getitem__(self, i):
        return dict(**{k : v[i].item() for k, v in self.tensors.items()}, **{k : v[i] for k, v in self.string_arrays.items()}, **{k : v[i] for k, v in self.ints_arrays.items()})

    def __len__(self):
        return len(next(iter(self.tensors.values()))) if len(self.tensors) > 0 else len(next(iter(self.string_arrays.values())))

class NamedTupleArray(DictArray):
    def __init__(self, namedtuples, *args, **kwargs):
        super().__init__([t._asdict() for t in namedtuples], *args, **kwargs)
        self.namedtuple = type(next(iter(namedtuples)))

    def __getitem__(self, index):
        return self.namedtuple(**super().__getitem__(index))

class IntsArray:
    def __init__(self, ints, dtype = torch.int64):
        tensors = [torch.as_tensor(t, dtype = dtype) for t in ints]
        self.data = torch.cat(tensors)
        self.cumlen = torch.tensor(list(map(len, tensors)), dtype = torch.int64).cumsum(dim = 0)

    def __getitem__(self, i):
        return self.data[(self.cumlen[i - 1] if i >= 1 else 0) : self.cumlen[i]]

    def __len__(self):
        return len(self.cumlen)

if __name__ == '__main__':
    a = StringArray(['asd', 'def'])
    print('len = ', len(a))
    print('data = ', list(a))

    a = DictArray([dict(a = 1, b = 'def'), dict(a = 2, b = 'klm')], types = dict(a = torch.LongTensor, b = StringArray))
    print('len = ', len(a))
    print('data = ', list(a))
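
# Hedged usage sketch (not part of the classes above; the dataset class and file name below are hypothetical placeholders):
# wrap the list of strings in a StringArray once in the parent process, so that DataLoader worker processes
# read from the shared tensor storage instead of touching millions of refcounted Python string objects.
class LinesDataset(torch.utils.data.Dataset):
    def __init__(self, lines):
        self.lines = lines  # a StringArray built before the worker processes are started

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        return self.lines[idx]

# lines = StringArray(open('corpus.txt').read().splitlines())  # 'corpus.txt' is a placeholder path
# loader = torch.utils.data.DataLoader(LinesDataset(lines), batch_size = 128, num_workers = 4)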

rgtjf commented Jul 10, 2021

Nice code! Thanks for providing these!

I found that the code is extremely slow when running this:

# coding=utf-8

import torch
import logging
import json

logger = logging.getLogger(__name__)

class StringArray:
  def __init__(self, strings,
               encoding= 'utf_16_le'):
    strings = list(strings)
    self.encoding = encoding
    self.multiplier = dict(ascii=1, utf_16_le=2, utf_32_le=4)[encoding]
    self.data = torch.ByteTensor(torch.ByteStorage.from_buffer(''.join(strings).encode(encoding)))
    self.cumlen = torch.LongTensor(list(map(len, strings))).cumsum(dim=0).mul_(self.multiplier)
    assert int(self.cumlen[-1]) == len(
      self.data), f'[{encoding}] is not enough to hold characters, use a larger character class'

  def __getitem__(self, i):
    return bytes(self.data[(self.cumlen[i - 1] if i >= 1 else 0): self.cumlen[i]]).decode(self.encoding)

  def __len__(self):
    return len(self.cumlen)

  def tolist(self):
    data_bytes, cumlen = bytes(self.data), self.cumlen.tolist()
    return [data_bytes[0:cumlen[0]].decode(self.encoding)] \
           + [data_bytes[start:end].decode(self.encoding)
              for start, end in zip(cumlen[:-1], cumlen[1:])]


if __name__ == '__main__':
  data = list(map(str, range(30000000)))
  data = [json.dumps({datum:datum*100, 'value': int(datum)}) for datum in data]

  logger.info('load data')
  for i in range(len(data)):
    try:
      datum = json.loads(data[i])
    except:
      logger.info('error, {}, {}'.format(i, data[i]))
    if i < 10 or i % 100000 == 0:
      logger.info('idx={}, {}'.format(i, 'log'))

Could someone provide any solutions?


alsm168 commented Jul 12, 2021

Hi, how can I save a dict with a tensor array?


vadimkantorov commented Jul 12, 2021

@rgtjf Your code doesn't use StringArray at all, so I can't help with it directly. That said, accessing a StringArray element by element in a loop is going to be slow, because individual tensor item access is very slow in PyTorch; for that usage you should call tolist() once and then consume its output. The fix is a tolist() method that calls tolist() on the internal tensors and then builds the resulting Python list in one go (as done in StringArray). If you need fast access to a range of elements, use tolist() or implement similar helpers that first call tolist() on the internal tensors.
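
For illustration, a minimal sketch of what I mean (the array here is just an example): decode everything once with tolist() and then iterate over ordinary Python strings, instead of indexing the StringArray element by element.

string_array = StringArray(list(map(str, range(1000000))))

# Slow: every [] access slices a tensor and decodes a single string.
# for i in range(len(string_array)):
#     s = string_array[i]

# Fast: one tolist() call, then plain list iteration.
for s in string_array.tolist():
    ...  # consume the decoded string here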

@alsm168 Could you rephrase your question? If you want to serialise it, you'd need to serialise the internal tensor arrays such as self.cumlen and self.data.
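
For example, a minimal sketch of that (string_array is assumed to be an existing StringArray, and the file name is a placeholder): save the internal tensors plus the encoding, then rebuild the wrapper on load.

state = dict(data = string_array.data, cumlen = string_array.cumlen, encoding = string_array.encoding)
torch.save(state, 'string_array.pt')

state = torch.load('string_array.pt')
restored = StringArray.__new__(StringArray)  # bypass __init__ and fill the fields directly
restored.data, restored.cumlen, restored.encoding = state['data'], state['cumlen'], state['encoding']
restored.multiplier = dict(ascii = 1, utf_16_le = 2, utf_32_le = 4)[restored.encoding]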


alsm168 commented Jul 15, 2021

@vadimkantorov I use StringArray as below, but I found that memory usage grows slowly. Is there anything wrong? Thanks.

import torch
from tensorbackeddictarray import StringArray

class SyllableDataset(torch.utils.data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        utt, wav_path, syllable_label = self.data_list[idx].split('\t')
        return utt

def get_dataset():
    map_file = r'big_syllable_data.data'
    map_data_list = []
    with open(map_file, 'r', encoding='utf-8') as fid:
        for line in fid:
            utt, wav_path, syllable_label = line.strip().split('\t')
            map_data_list.append(f'{utt}\t{wav_path}\t{syllable_label}')
    map_data_list = StringArray(map_data_list)
    mapdataset = SyllableDataset(map_data_list)
    map_loader = torch.utils.data.DataLoader(mapdataset,
                                             batch_size=128,
                                             num_workers=4,
                                             pin_memory=False,
                                             shuffle=False)
    for batch_idx, batch in enumerate(map_loader):
        utt = batch

if __name__ == '__main__':
    get_dataset()


vadimkantorov commented Jul 15, 2021

Hmm, not sure; it should normally not leak with this usage. Unfortunately, I won't have the time to debug it; you'd need to look into it yourself if you wish to find the source of the problem. Try modifying StringArray to use numpy arrays instead of PyTorch tensors.
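
As a starting point, a hedged sketch of the numpy-backed variant I mean (same interface as StringArray, np.frombuffer instead of ByteStorage; untested):

import numpy as np

class NumpyStringArray:
    # Same idea as StringArray, but backed by numpy arrays instead of torch tensors.
    def __init__(self, strings, encoding = 'utf_16_le'):
        strings = list(strings)
        self.encoding = encoding
        multiplier = dict(ascii = 1, utf_16_le = 2, utf_32_le = 4)[encoding]
        self.data = np.frombuffer(''.join(strings).encode(encoding), dtype = np.uint8)
        self.cumlen = np.cumsum(list(map(len, strings)), dtype = np.int64) * multiplier
        assert int(self.cumlen[-1]) == len(self.data), f'[{encoding}] is not enough to hold characters, use a larger character class'

    def __getitem__(self, i):
        return self.data[(self.cumlen[i - 1] if i >= 1 else 0) : self.cumlen[i]].tobytes().decode(self.encoding)

    def __len__(self):
        return len(self.cumlen)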

The function I used to measure this across the main process and all data loader worker processes:

import psutil

def compute_ram_memory_stats(byte_scaler = 1024 ** 3):
	stats = {}
	process = psutil.Process()
	children = process.children(recursive=True)
	total_pss_ram = process.memory_full_info().pss + sum(
		child.memory_full_info().pss for child in children
	)
	stats['pss_ram'] = total_pss_ram / byte_scaler
	return stats
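
For example (a sketch, assuming it is called from the data loader loop in the snippet above):

for batch_idx, batch in enumerate(map_loader):
    if batch_idx % 1000 == 0:
        print('batch', batch_idx, compute_ram_memory_stats())  # total PSS in GiB for the main process and all workers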


alsm168 commented Jul 15, 2021

not sure; it should normally not leak with this usage. Unfortunately, I won't have the time to debug it; you'd need to look into it yourself

Thank you, I'll try to find the problem.
