Tensor-backed immutable string array and list-of-dicts, to be used in PyTorch Dataset classes to work around shared memory pages getting copied (copy-on-write triggered by Python reference-count updates in forked DataLoader workers) when using plain Python lists of strings: https://github.com/pytorch/pytorch/issues/13246
tensorbackeddictarray.py:

import math
import typing

import torch

class StringArray:
    # Stores a list of strings as one flat byte tensor plus cumulative offsets,
    # so forked DataLoader workers share the backing pages instead of
    # copy-on-write duplicating a Python list of refcounted string objects.
    def __init__(self, strings : typing.List[str], encoding : typing.Literal['ascii', 'utf_16_le', 'utf_32_le'] = 'utf_16_le'):
        strings = list(strings)
        self.encoding = encoding
        # fixed-width encodings keep byte offsets computable from character counts
        self.multiplier = dict(ascii = 1, utf_16_le = 2, utf_32_le = 4)[encoding]
        self.data = torch.ByteTensor(torch.ByteStorage.from_buffer(''.join(strings).encode(encoding)))
        self.cumlen = torch.LongTensor(list(map(len, strings))).cumsum(dim = 0).mul_(self.multiplier)
        assert int(self.cumlen[-1]) == len(self.data), f'[{encoding}] is not enough to hold characters, use a larger character class'

    def tolist(self):
        data_bytes, cumlen = bytes(self.data), self.cumlen.tolist()
        return [data_bytes[0 : cumlen[0]].decode(self.encoding)] + [data_bytes[start : end].decode(self.encoding) for start, end in zip(cumlen[:-1], cumlen[1:])]

    def __getitem__(self, i):
        return bytes(self.data[(self.cumlen[i - 1] if i >= 1 else 0) : self.cumlen[i]]).decode(self.encoding)

    def __len__(self):
        return len(self.cumlen)

class DictArray:
    # Packs a list of homogeneous dicts into per-key flat tensors,
    # StringArrays, or IntsArrays, chosen via the types mapping.
    def __init__(self, dicts : typing.List[dict], types : typing.Dict[str, typing.ClassVar] = {}, *, batch_size : int = 1024, string_encoding : typing.Literal['ascii', 'utf_16_le', 'utf_32_le'] = 'utf_16_le', ints_dtype = torch.int64):
        dicts = list(dicts)
        numel = len(dicts)
        assert numel > 0
        self.tensors = {k : t(numel) for k, t in types.items() if t != StringArray and t != IntsArray}
        string_lists = {k : [None] * numel for k, t in types.items() if t == StringArray}
        ints_lists = {k : [None] * numel for k, t in types.items() if t == IntsArray}
        temp_lists = {k : [None] * batch_size for k in self.tensors}
        # fill the preallocated tensors batch-by-batch to bound temporary memory
        for b in range(math.ceil(numel / batch_size)):
            for i, t in enumerate(dicts[b * batch_size : (b + 1) * batch_size]):
                for k in temp_lists:
                    temp_lists[k][i] = t[k]
                for k in string_lists:
                    string_lists[k][b * batch_size + i] = t[k]
                for k in ints_lists:
                    ints_lists[k][b * batch_size + i] = t[k]
            for k, v in temp_lists.items():
                res = self.tensors[k][b * batch_size : (b + 1) * batch_size]
                res.copy_(torch.as_tensor(v[:len(res)], dtype = self.tensors[k].dtype))
        self.string_arrays = {k : StringArray(v, encoding = string_encoding) for k, v in string_lists.items()}
        self.ints_arrays = {k : IntsArray(v, dtype = ints_dtype) for k, v in ints_lists.items()}

    def __getitem__(self, i):
        return dict(**{k : v[i].item() for k, v in self.tensors.items()}, **{k : v[i] for k, v in self.string_arrays.items()}, **{k : v[i] for k, v in self.ints_arrays.items()})

    def __len__(self):
        return len(next(iter(self.tensors.values()))) if len(self.tensors) > 0 else len(next(iter(self.string_arrays.values())))

class NamedTupleArray(DictArray):
    def __init__(self, namedtuples, *args, **kwargs):
        super().__init__([t._asdict() for t in namedtuples], *args, **kwargs)
        self.namedtuple = type(next(iter(namedtuples)))

    def __getitem__(self, index):
        return self.namedtuple(**super().__getitem__(index))

class IntsArray:
    # Stores a list of variable-length int sequences as one flat tensor plus cumulative offsets.
    def __init__(self, ints, dtype = torch.int64):
        tensors = [torch.as_tensor(t, dtype = dtype) for t in ints]
        self.data = torch.cat(tensors)
        self.cumlen = torch.tensor(list(map(len, tensors)), dtype = torch.int64).cumsum(dim = 0)

    def __getitem__(self, i):
        return self.data[(self.cumlen[i - 1] if i >= 1 else 0) : self.cumlen[i]]

    def __len__(self):
        return len(self.cumlen)

if __name__ == '__main__':
    a = StringArray(['asd', 'def'])
    print('len = ', len(a))
    print('data = ', list(a))

    a = DictArray([dict(a = 1, b = 'def'), dict(a = 2, b = 'klm')], types = dict(a = torch.LongTensor, b = StringArray))
    print('len = ', len(a))
    print('data = ', list(a))
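For the DataLoader scenario the gist targets, a minimal usage sketch (the class and variable names here are illustrative, not part of the gist):

import torch

class StringDataset(torch.utils.data.Dataset):
    # Illustrative: hold a StringArray instead of a Python list of strings,
    # so forked workers read the backing tensor rather than touching
    # refcounted Python objects and dirtying copy-on-write pages.
    def __init__(self, strings):
        self.strings = StringArray(strings)

    def __len__(self):
        return len(self.strings)

    def __getitem__(self, i):
        return self.strings[i]

loader = torch.utils.data.DataLoader(StringDataset(['asd', 'def'] * 10000),
                                     batch_size=128, num_workers=4)
for batch in loader:
    pass  # default collation yields each batch as a list of strings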
Hmm, not sure, it normally shouldn't be leaking with this usage. Unfortunately, I don't have the time to debug this; you'd need to look into it yourself if you wish to find the source of the problem. Try modifying StringArray to use numpy arrays instead of PyTorch tensors.
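For reference, a numpy-backed variant along those lines might look roughly like this (an untested sketch of the suggestion above, not code from the gist):

import typing
import numpy as np

class NumpyStringArray:
    # Untested sketch: same flat-buffer-plus-cumulative-offsets layout
    # as StringArray, but backed by numpy arrays instead of torch tensors.
    def __init__(self, strings: typing.List[str], encoding: str = 'utf_16_le'):
        strings = list(strings)
        self.encoding = encoding
        multiplier = dict(ascii=1, utf_16_le=2, utf_32_le=4)[encoding]
        self.data = np.frombuffer(''.join(strings).encode(encoding), dtype=np.uint8)
        self.cumlen = np.cumsum([len(s) for s in strings], dtype=np.int64) * multiplier
        assert int(self.cumlen[-1]) == len(self.data)

    def __getitem__(self, i):
        start = int(self.cumlen[i - 1]) if i >= 1 else 0
        return self.data[start : int(self.cumlen[i])].tobytes().decode(self.encoding)

    def __len__(self):
        return len(self.cumlen)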
The function I used to measure this leak across all data loader worker processes:
import psutil

def compute_ram_memory_stats(byte_scaler=1024 ** 3):
    stats = {}
    process = psutil.Process()
    children = process.children(recursive=True)
    total_pss_ram = process.memory_full_info().pss + sum(
        child.memory_full_info().pss for child in children
    )
    stats['pss_ram'] = total_pss_ram / byte_scaler
    return stats
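It can be called periodically inside the loading loop, e.g. (illustrative; the loader name is hypothetical):

# Illustrative: print PSS in GiB every 100 batches to watch for growth
for batch_idx, batch in enumerate(loader):
    if batch_idx % 100 == 0:
        print(batch_idx, compute_ram_memory_stats())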
Thank you, I'll try to find the problem.
@vadimkantorov I use StringArray as below, but I found that the memory usage grows slowly over time. Is there anything wrong? Thanks.
import torch
from tensorbackeddictarray import StringArray

class SyllableDataset(torch.utils.data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    # __len__ / __getitem__ restored here: DataLoader requires them,
    # and they were presumably lost in the comment's formatting
    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, i):
        return self.data_list[i]

def get_dataset():
    map_file = r'big_syllable_data.data'
    map_data_list = []
    with open(map_file, 'r', encoding='utf-8') as fid:
        for line in fid:
            utt, wav_path, syllable_label = line.strip().split('\t')
            map_data_list.append(f'{utt}\t{wav_path}\t{syllable_label}')
    map_data_list = StringArray(map_data_list)
    mapdataset = SyllableDataset(map_data_list)
    map_loader = torch.utils.data.DataLoader(mapdataset,
                                             batch_size=128,
                                             num_workers=4,
                                             pin_memory=False,
                                             shuffle=False)
    for batch_idx, batch in enumerate(map_loader):
        utt = batch

if __name__ == '__main__':
    get_dataset()
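As an aside, records like these could also be packed field-by-field with the gist's DictArray instead of tab-joined strings (an illustrative sketch; the record values are made up):

from tensorbackeddictarray import DictArray, StringArray

# One StringArray per field, avoiding the join/split round-trip
records = [dict(utt='utt1', wav_path='/data/a.wav', syllable_label='s1'),
           dict(utt='utt2', wav_path='/data/b.wav', syllable_label='s2')]
arr = DictArray(records, types=dict(utt=StringArray,
                                    wav_path=StringArray,
                                    syllable_label=StringArray))
print(arr[0])  # {'utt': 'utt1', 'wav_path': '/data/a.wav', 'syllable_label': 's1'}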