@vadimkantorov
Last active December 7, 2023 23:43
Tensor-backed immutable string array and list-of-dicts to be used in PyTorch Dataset classes, to work around copy-on-write of shared memory pages when using Python lists of strings https://github.com/pytorch/pytorch/issues/13246
import math
import typing
import torch

class StringArray:
    def __init__(self, strings : typing.List[str], encoding : typing.Literal['ascii', 'utf_16_le', 'utf_32_le'] = 'utf_16_le'):
        strings = list(strings)
        self.encoding = encoding
        self.multiplier = dict(ascii = 1, utf_16_le = 2, utf_32_le = 4)[encoding]
        # all strings are concatenated into one flat byte tensor; cumlen holds the end offset (in bytes) of every string
        self.data = torch.ByteTensor(torch.ByteStorage.from_buffer(''.join(strings).encode(encoding)))
        self.cumlen = torch.LongTensor(list(map(len, strings))).cumsum(dim = 0).mul_(self.multiplier)
        assert int(self.cumlen[-1]) == len(self.data), f'[{encoding}] is not enough to hold characters, use a larger character class'

    def tolist(self):
        data_bytes, cumlen = bytes(self.data), self.cumlen.tolist()
        return [data_bytes[0 : cumlen[0]].decode(self.encoding)] + [data_bytes[start : end].decode(self.encoding) for start, end in zip(cumlen[:-1], cumlen[1:])]

    def __getitem__(self, i):
        return bytes(self.data[(self.cumlen[i - 1] if i >= 1 else 0) : self.cumlen[i]]).decode(self.encoding)

    def __len__(self):
        return len(self.cumlen)

class DictArray:
    def __init__(self, dicts : typing.List[dict], types : typing.Dict[str, type] = {}, *, batch_size : int = 1024, string_encoding : typing.Literal['ascii', 'utf_16_le', 'utf_32_le'] = 'utf_16_le', ints_dtype = torch.int64):
        dicts = list(dicts)
        numel = len(dicts)
        assert numel > 0
        # numeric fields go into preallocated tensors, string fields into StringArray, variable-length int fields into IntsArray
        self.tensors = {k : t(numel) for k, t in types.items() if t != StringArray and t != IntsArray}
        string_lists = {k : [None] * numel for k, t in types.items() if t == StringArray}
        ints_lists = {k : [None] * numel for k, t in types.items() if t == IntsArray}
        temp_lists = {k : [None] * batch_size for k in self.tensors}
        # numeric fields are copied batch by batch to avoid building one huge temporary list per field
        for b in range(math.ceil(numel / batch_size)):
            for i, t in enumerate(dicts[b * batch_size : (b + 1) * batch_size]):
                for k in temp_lists:
                    temp_lists[k][i] = t[k]
                for k in string_lists:
                    string_lists[k][b * batch_size + i] = t[k]
                for k in ints_lists:
                    ints_lists[k][b * batch_size + i] = t[k]
            for k, v in temp_lists.items():
                res = self.tensors[k][b * batch_size : (b + 1) * batch_size]
                res.copy_(torch.as_tensor(v[:len(res)], dtype = self.tensors[k].dtype))
        self.string_arrays = {k : StringArray(v, encoding = string_encoding) for k, v in string_lists.items()}
        self.ints_arrays = {k : IntsArray(v, dtype = ints_dtype) for k, v in ints_lists.items()}

    def __getitem__(self, i):
        return dict(**{k : v[i].item() for k, v in self.tensors.items()}, **{k : v[i] for k, v in self.string_arrays.items()}, **{k : v[i] for k, v in self.ints_arrays.items()})

    def __len__(self):
        return len(next(iter(self.tensors.values()))) if len(self.tensors) > 0 else len(next(iter(self.string_arrays.values())))

class NamedTupleArray(DictArray):
    def __init__(self, namedtuples, *args, **kwargs):
        super().__init__([t._asdict() for t in namedtuples], *args, **kwargs)
        self.namedtuple = type(next(iter(namedtuples)))

    def __getitem__(self, index):
        return self.namedtuple(**super().__getitem__(index))

class IntsArray:
    def __init__(self, ints, dtype = torch.int64):
        tensors = [torch.as_tensor(t, dtype = dtype) for t in ints]
        self.data = torch.cat(tensors)
        self.cumlen = torch.tensor(list(map(len, tensors)), dtype = torch.int64).cumsum(dim = 0)

    def __getitem__(self, i):
        return self.data[(self.cumlen[i - 1] if i >= 1 else 0) : self.cumlen[i]]

    def __len__(self):
        return len(self.cumlen)

if __name__ == '__main__':
    a = StringArray(['asd', 'def'])
    print('len = ', len(a))
    print('data = ', list(a))

    a = DictArray([dict(a = 1, b = 'def'), dict(a = 2, b = 'klm')], types = dict(a = torch.LongTensor, b = StringArray))
    print('len = ', len(a))
    print('data = ', list(a))
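
A minimal sketch of the intended use inside a torch.utils.data.Dataset (illustrative; the field names and record layout are made up):

# Keeping per-sample metadata packed in a DictArray instead of a Python list of dicts
# avoids copy-on-write of refcounted list items in DataLoader worker processes.
class PackedDataset(torch.utils.data.Dataset):
    def __init__(self, records):
        # records: list of dicts like dict(label = 0, path = 'a.wav')
        self.records = DictArray(records, types = dict(label = torch.LongTensor, path = StringArray))

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        r = self.records[idx]
        return r['label'], r['path']
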
@jinmingyi1998

can this work for dict in dict in dict .... ?

@vadimkantorov
Author

Nope, it doesn't support this mode

@akashrajkn

akashrajkn commented Nov 18, 2020

It throws the following error:

len =  2
data =  ['asd', 'def']
Traceback (most recent call last):
  File "dictarray.py", line 52, in <module>
    a = DictArray([dict(a = 1, b = 'def'), dict(a = 2, b = 'klm')], types = dict(a = torch.LongTensor, b = StringArray))
  File "dictarray.py", line 17, in __init__
    string_lists.append(t[k])
AttributeError: 'dict' object has no attribute 'append'

@vadimkantorov
Author

I fixed it to be string_lists[k][i] = t[k]. Please let me know if it works now. StringArray is tested, DictArray is unfortunately not tested

@akashrajkn

akashrajkn commented Nov 18, 2020

Awesome! It works.
Thanks for the super quick reply 👍

@vadimkantorov
Author

I did one more bug fix.

Let me know if it solves your memory leak problems. The code could be made simpler without batching, but that would require several passes over the big array, which can be slow if there are several fields (though I haven't profiled it).
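
For illustration, the unbatched variant could be as small as this (a hypothetical pack_unbatched sketch, not part of the gist):

# Hypothetical unbatched packing: one full temporary list (and one pass
# over the whole input) per field, instead of the batched loop above.
# StringArray, IntsArray and tensor classes like torch.LongTensor all accept a list.
def pack_unbatched(dicts, types):
    return {k : t([d[k] for d in dicts]) for k, t in types.items()}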

@akashrajkn

Thank you for the fix.
It doesn't solve my memory leak issue - probably something else is going wrong in my Dataset.

@vadimkantorov
Author

A good thing to measure is "pss"; you can get these numbers from psutil. After packing, the DictArray holds only tensors and objects holding tensors, so it should share well across workers.

But at least StringArray worked well for my use case.

@ssss1029

ssss1029 commented Nov 18, 2020

Hi, thanks for this fix. I tried out StringArray, but I observed higher CPU usage during __getitem__ compared to using a simple Python list. Is this expected?

@vadimkantorov
Author

vadimkantorov commented Nov 19, 2020

As you can see, it slices a torch tensor, calls bytes() (I'm not sure whether that copies) and then constructs a string; all of these are more costly than plain list indexing. But it does not seem to cause copy-on-write.

If you know that your strings are ascii, you can set encoding = 'ascii' to halve the storage.
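
A rough way to see the per-item overhead (illustrative micro-benchmark, not from the gist):

import timeit

strings = [str(i) for i in range(100000)]
packed = StringArray(strings)
# per-item access: plain list indexing vs tensor slice + bytes() + decode()
print('list  :', timeit.timeit(lambda: strings[50000], number = 100000))
print('packed:', timeit.timeit(lambda: packed[50000], number = 100000))

# if the strings are known to be ascii, storage is halved:
packed_ascii = StringArray(strings, encoding = 'ascii')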

@akashrajkn

@vadimkantorov, I solved the memory issue in my code (the issue was elsewhere). I did have time to try out the DictArray (DictArray with StringArrays, DictArray with DictArrays) with the multiprocessing module and there were no memory leaks.

I really appreciate that you replied with the fixes so soon, thanks!

@Erotemic

There is a missing ) on the res.copy_ line.

@vadimkantorov
Author

Fixed! Thanks!

@dmus

dmus commented Dec 25, 2020

I am facing the same problem: RuntimeError: unable to open shared memory object </torch_37358_1335134282> in read-write mode

I see some solutions for when the dataset is a list of strings or a list of dicts, but is there also a workaround for when I have a list of custom objects?

@vadimkantorov
Author

You'd have to do something custom. But the general approach "manually pack/unpack all data to tensors" would still work.
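
For example, for simple objects one hypothetical route is to go through dicts (assuming every object carries the same scalar/string attributes):

class Sample:  # stand-in for your custom object
    def __init__(self, label, path):
        self.label, self.path = label, path

samples = [Sample(0, 'a.wav'), Sample(1, 'b.wav')]
# pack: convert objects to dicts, then reuse DictArray
packed = DictArray([vars(s) for s in samples], types = dict(label = torch.LongTensor, path = StringArray))
# unpack: rebuild the object from the returned dict
restored = Sample(**packed[0])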

@dmus

dmus commented Dec 28, 2020

I get the following error when trying this script:

Traceback (most recent call last):
  File "tensorbackeddictarray.py", line 6, in <module>
    class DictArray:
  File "tensorbackeddictarray.py", line 10, in DictArray
    ints_dtype=torch.int64):
  File "/usr/lib/python3.6/typing.py", line 682, in inner
    return func(*args, **kwds)
  File "/usr/lib/python3.6/typing.py", line 1107, in __getitem__
    params = tuple(_type_check(p, msg) for p in params)
  File "/usr/lib/python3.6/typing.py", line 1107, in <genexpr>
    params = tuple(_type_check(p, msg) for p in params)
  File "/usr/lib/python3.6/typing.py", line 374, in _type_check
    raise TypeError(msg + " Got %.100r." % (arg,))
TypeError: Parameters to generic types must be types. Got typing.ClassVar.

Is this because of the Python version?

@vadimkantorov
Author

vadimkantorov commented Dec 28, 2020 via email

@rgtjf

rgtjf commented Jul 10, 2021

Nice code! Thanks for providing these!

I found this code is extremely slow to run:

# coding=utf-8

import torch
import logging
import json

logger = logging.getLogger(__name__)

class StringArray:
  def __init__(self, strings,
               encoding= 'utf_16_le'):
    strings = list(strings)
    self.encoding = encoding
    self.multiplier = dict(ascii=1, utf_16_le=2, utf_32_le=4)[encoding]
    self.data = torch.ByteTensor(torch.ByteStorage.from_buffer(''.join(strings).encode(encoding)))
    self.cumlen = torch.LongTensor(list(map(len, strings))).cumsum(dim=0).mul_(self.multiplier)
    assert int(self.cumlen[-1]) == len(
      self.data), f'[{encoding}] is not enough to hold characters, use a larger character class'

  def __getitem__(self, i):
    return bytes(self.data[(self.cumlen[i - 1] if i >= 1 else 0): self.cumlen[i]]).decode(self.encoding)

  def __len__(self):
    return len(self.cumlen)

  def tolist(self):
    data_bytes, cumlen = bytes(self.data), self.cumlen.tolist()
    return [data_bytes[0:cumlen[0]].decode(self.encoding)] \
           + [data_bytes[start:end].decode(self.encoding)
              for start, end in zip(cumlen[:-1], cumlen[1:])]


if __name__ == '__main__':
  data = list(map(str, range(30000000)))
  data = [json.dumps({datum:datum*100, 'value': int(datum)}) for datum in data]

  logger.info('load data')
  for i in range(len(data)):
    try:
      datum = json.loads(data[i])
    except:
      logger.info('error, {}, {}'.format(i, data[i]))
    if i < 10 or i % 100000 == 0:
      logger.info('idx={}, {}'.format(i, 'log'))

Could someone provide any solutions?

@alsm168

alsm168 commented Jul 12, 2021

Hi, how can I save a dict with a tensor array?

@vadimkantorov
Author

vadimkantorov commented Jul 12, 2021

@rgtjf Your code doesn't use StringArray at all, so I can't help. That being said, accessing StringArray element by element in a loop is going to be slow, because tensor item access is terribly slow in PyTorch; for that usage you need to call tolist() first and then consume its output. The same fix applies to DictArray: implement a tolist() method that calls tolist() on the internal tensors and then populates the resulting list of dicts (as is done in StringArray). If you need fast access to a range of elements, use tolist() or implement similar helpers that first call tolist() on the internal tensors.
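
For example (illustrative; process is a stand-in for whatever you do per string):

# Slow: every access goes through tensor indexing, bytes() and decode()
for i in range(len(string_array)):
    process(string_array[i])

# Fast: one bulk conversion back to a plain Python list, then cheap list access
for s in string_array.tolist():
    process(s)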

@alsm168 Could you rephrase your question? If you want to serialise it, you'd need to serialise the internal tensor arrays such as self.cumlen and self.data

@alsm168

alsm168 commented Jul 15, 2021

@vadimkantorov I use StringArray as below, but I found the memory usage is growing slowly. Is there anything wrong? Thanks

import torch
from tensorbackeddictarray import StringArray

class SyllableDataset(torch.utils.data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        utt, wav_path, syllable_label = self.data_list[idx].split('\t')
        return utt

def get_dataset():
    map_file = r'big_syllable_data.data'
    map_data_list = []
    with open(map_file, 'r', encoding='utf-8') as fid:
        for line in fid:
            utt, wav_path, syllable_label = line.strip().split('\t')
            map_data_list.append(f'{utt}\t{wav_path}\t{syllable_label}')
    map_data_list = StringArray(map_data_list)
    mapdataset = SyllableDataset(map_data_list)
    map_loader = torch.utils.data.DataLoader(mapdataset,
                                             batch_size=128,
                                             num_workers=4,
                                             pin_memory=False,
                                             shuffle=False)
    for batch_idx, batch in enumerate(map_loader):
        utt = batch

if __name__ == '__main__':
    get_dataset()

@vadimkantorov
Author

vadimkantorov commented Jul 15, 2021

Hmm, not sure; it should normally not leak with this usage. Unfortunately, I don't have the time to debug it; you'd need to look into it yourself if you wish to find the source of the problem. Try modifying StringArray to use numpy arrays instead of PyTorch tensors.

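A minimal sketch of such a numpy-backed variant (illustrative and untested; same interface and layout as StringArray above):

import numpy as np

class NumpyStringArray:
    def __init__(self, strings, encoding = 'utf_16_le'):
        strings = list(strings)
        self.encoding = encoding
        multiplier = dict(ascii = 1, utf_16_le = 2, utf_32_le = 4)[encoding]
        # same layout as StringArray: one flat uint8 buffer plus cumulative end offsets in bytes
        self.data = np.frombuffer(''.join(strings).encode(encoding), dtype = np.uint8)
        self.cumlen = np.cumsum([len(s) for s in strings]) * multiplier

    def __getitem__(self, i):
        start = self.cumlen[i - 1] if i >= 1 else 0
        return self.data[start : self.cumlen[i]].tobytes().decode(self.encoding)

    def __len__(self):
        return len(self.cumlen)
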
The function I used to measure this leak across all data loader worker processes:

import psutil

def compute_ram_memory_stats(byte_scaler = 1024 ** 3):
    stats = {}
    process = psutil.Process()
    children = process.children(recursive = True)
    # pss ("proportional set size") attributes shared pages fractionally, which makes it the right metric for copy-on-write leaks
    total_pss_ram = process.memory_full_info().pss + sum(
        child.memory_full_info().pss for child in children
    )
    stats['pss_ram'] = total_pss_ram / byte_scaler
    return stats

@alsm168

alsm168 commented Jul 15, 2021

not sure, it should normally be not leaking with this usage. Unfortunately, I wouldn't have the time to debug, you'd need to look int

Thank you, I'll try to find the problem.
