@AdityaSoni19031997
Created April 21, 2020 02:19
In this gist I have tried to explain a very smart way of loading datasets into PyTorch by streaming them from bytes. It can be achieved in multiple ways, but here my focus is confined to David's idea of streaming records from a bytes file.
import torch
import io
import pandas as pd
import gc
import numpy as np
import transformers
r'''
Original Code Author [@dlibenzi](https://github.com/dlibenzi)
Complete Colab Example [link](https://colab.research.google.com/drive/1IvCxIg-Q_DlI7UNJuajpl4UZXNiW5jMg)
Feel free to explore this [issue](https://github.com/pytorch/xla/issues/1870)
NB: I have trimmed this down to the parts whose ideas were new to me; please refer to the Colab
for the rest of the code blocks.
The code is tricky to understand if you have NEVER worked with files directly before.
On top of the usual file-handling concepts, the data is in binary format, so you need to be a little more careful
when writing and loading it (endianness matters, i.e. the order in which the bytes of a multi-byte value are stored).
So in short, it ties multiple concepts together!
Personal Notes
----------------
##############################################################
###### PLEASE CONSIDER READING THIS BEFORE EXECUTING THE CODE ####
###### RUNNING CODE WITHOUT UNDERSTANDING IT IS USELESS ####
##############################################################
# A binary file is just a sequence of bytes; none of them has any special meaning
# of the kind a text reader would give them.
# A byte, or a group of bytes, can encode an ASCII character, an integer, part of a serialized tensor, etc.
# How you write the data to the file and how you read it back determines everything.
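To make "how you read it back determines everything" concrete, here are the same four bytes read three different ways. A small illustration using only built-ins; the byte values are arbitrary.

```python
data = bytes([0x41, 0x42, 0x43, 0x44])              # four raw bytes

print(data.decode("ascii"))                          # ABCD  (read as ASCII text)
print(int.from_bytes(data, byteorder="big"))         # 1094861636  (read as a big-endian integer)
print(int.from_bytes(data, byteorder="little"))      # 1145258561  (read as a little-endian integer)
print(list(data))                                    # [65, 66, 67, 68]  (read as individual byte values)
```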
# In Python, io.BytesIO inherits from io.BufferedIOBase and thus comes with methods like read(), write(), seek() and tell(),
# plus getvalue() and getbuffer() of its own.
# Simply put, io.BytesIO is an in-memory buffer of bytes that you can use like a file opened in binary mode.
# Also, I hope you are aware that binary data (bytes) and strings (str) are different types,
# so a str must be encoded to bytes using ascii, utf-8, or another encoding.
# getvalue() returns a bytes object containing the entire contents of the buffer,
# regardless of the current stream position.
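A short sketch of the str/bytes distinction and of io.BytesIO behaving like a binary file; the sample string is arbitrary.

```python
import io

text = "héllo"                          # a str: a sequence of Unicode characters
data = text.encode("utf-8")             # bytes: 'é' takes two bytes in UTF-8
print(len(text), len(data))             # 5 6

buf = io.BytesIO()                      # an in-memory binary stream
print(buf.write(data), buf.tell())      # 6 6  (write() returns the byte count; position is now at the end)
print(buf.read())                       # b''  (nothing left to read from the end)
buf.seek(0)                             # rewind to the beginning
first_three = buf.read(3)               # 'h' plus the two bytes of 'é'
print(buf.getvalue().decode("utf-8"))   # héllo  (getvalue() ignores the current position)
```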
# Seeking to a specific position in a file
# You can move to a specific position in a file before reading or writing using seek().
# With a single argument, seek() moves to that position relative to the beginning of the file.
# seek() can be called in one of two ways:
# x.seek(offset)
# x.seek(offset, whence)
# The offset is interpreted relative to the position indicated by whence,
# which can be 0, 1, or 2 (also available as io.SEEK_SET, io.SEEK_CUR, io.SEEK_END):
# 0 - Default. Offset relative to the beginning of the file
# 1 - Offset relative to the current position in the file
# 2 - Offset relative to the end of the file (usually zero or negative)
# seek() returns the new absolute position.
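The three whence values can be tried on an in-memory buffer; the same calls work on a file opened in binary mode. A minimal sketch, where the ten-byte buffer is just an example:

```python
import io

buf = io.BytesIO(bytes(range(10)))    # ten bytes with the values 0..9

p1 = buf.seek(3)                      # whence=0 (io.SEEK_SET): absolute position 3
b1 = buf.read(1)[0]                   # 3; reading advances the position to 4
p2 = buf.seek(2, io.SEEK_CUR)         # whence=1: two bytes forward from position 4
b2 = buf.read(1)[0]                   # 6
p3 = buf.seek(-2, io.SEEK_END)        # whence=2: two bytes back from the end
rest = list(buf.read())               # [8, 9]
print(p1, p2, p3)                     # 3 6 8
```

In text mode, Python only allows seek(0, 2) or offsets previously returned by tell(), which is one more reason the dataset files in this gist are opened with 'rb'/'wb'.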
# Example Illustrating Conversion Of An Int to Bytes
# '\xhh' is an escape sequence that describes the byte with that hexadecimal value.
# b'\x00\xff' -> two bytes with the values 0 and 255
i = 16
i.to_bytes(1, byteorder='big', signed=True) # b'\x10'
i.to_bytes(4, byteorder='big', signed=True) # b'\x00\x00\x00\x10'
i.to_bytes(4, byteorder='little', signed=True) # b'\x10\x00\x00\x00'
# From https://docs.python.org/3/library/stdtypes.html#int.from_bytes,
If byteorder is "big", the most significant byte is at the beginning of the byte array.
If byteorder is "little", the most significant byte is at the end of the byte array.
# int.from_bytes(b'\x00\x10', byteorder='big') # 16
# int.from_bytes(b'\x00\x10', byteorder='little') # 4096
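The index file used in the code below stores each record's start offset as a fixed-width 8-byte little-endian integer. A quick round trip, with an arbitrary offset value, shows why the byte order used for writing and reading must match; struct.pack with the '<Q' format is shown only as an equivalent alternative, not what the gist uses.

```python
import struct

offset = 1_000_000                                    # e.g. a byte offset into the .data file
raw = offset.to_bytes(8, byteorder="little")          # fixed width: always exactly 8 bytes
print(len(raw), raw.hex())                            # 8 40420f0000000000
print(int.from_bytes(raw, byteorder="little"))        # 1000000
print(int.from_bytes(raw, byteorder="big"))           # wrong byte order: a completely different number
print(struct.pack("<Q", offset) == raw)               # True
```

Because every entry is exactly 8 bytes, the entry for sample idx always starts at byte idx * 8 of the index file.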
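How can I read the last 10 bytes of a (text) file? [Try to solve it yourself first.]
# Here's my solution (open the file in binary mode, e.g. f = open(path, 'rb'), so that seeking to an arbitrary byte offset is allowed)
f.seek(0, 2)  # move to the end of the file (just past the last byte)
nbytes = f.tell()  # current position == file size in bytes
f.seek(max(nbytes - 10, 0))  # 10 bytes before the end (or the start, for shorter files)
last_ten = f.read(10)
# f.tell() returns the current stream position.
# f.read(k) reads and returns up to k bytes.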
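The same idea as a reusable, runnable function. read_last_bytes is a hypothetical helper name and the temporary file exists only for the demo; it also relies on seek() returning the new position, which saves the separate tell() call.

```python
import os
import tempfile

def read_last_bytes(path, n=10):
    # Hypothetical helper wrapping the solution above. The file is opened in binary
    # mode ('rb') so that seeking to an arbitrary byte position is well defined.
    with open(path, "rb") as f:
        size = f.seek(0, 2)           # seek() returns the new position: the file size
        f.seek(max(size - n, 0))      # n bytes before the end, or the start for short files
        return f.read(n)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.txt")
    with open(path, "wb") as f:
        f.write(b"0123456789abcdefghij")
    last_ten = read_last_bytes(path)
    whole = read_last_bytes(path, 50)
    print(last_ten)                   # b'abcdefghij'
    print(whole)                      # b'0123456789abcdefghij'  (shorter than n: the whole file)
```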
# Saving Tensors to a file
x = torch.tensor([0, 1, 2, 3, 4])
torch.save(x, 'tensor.pt')
# Saving Tensors to an io.BytesIO buffer
buffer = io.BytesIO()
torch.save(x, buffer)
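A quick round trip shows that what torch.save writes into the buffer is just a bytes object, which is exactly what write_sample() appends to the .data file, so len(buffer.getvalue()) is that record's size. A minimal sketch, assuming a recent PyTorch install:

```python
import io
import torch

x = torch.tensor([0, 1, 2, 3, 4])

buffer = io.BytesIO()
torch.save(x, buffer)             # serialize the tensor into the in-memory buffer
raw = buffer.getvalue()           # the serialized tensor as a plain bytes object
print(type(raw).__name__)         # bytes

y = torch.load(io.BytesIO(raw))   # deserialize from a fresh buffer positioned at 0
print(torch.equal(x, y))          # True
```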
'''
model="xlm-roberta-large"
batch_size=2
splits="8,16,32,64"
train_ds="train_dataset"
valid_ds="valid_dataset"
class FileDataset(object):
    def __init__(self, path):
        # open the binary files produced by create_dataset() for reading
        self._data_file = open(path + '.data', 'rb')
        self._index_file = open(path + '.index', 'rb')
        self._index_file.seek(0, 2)  # move to the end of the index file
        self._index_size = self._index_file.tell()  # position at the end == index file size in bytes
        assert self._index_size % 8 == 0  # one 8-byte offset per sample
        self._data_file.seek(0, 2)  # move to the end of the data file
        self._data_size = self._data_file.tell()  # data file size in bytes

    def read_sample(self, idx):
        '''
        The index file holds one 8-byte little-endian integer per sample: the byte offset at
        which that sample starts in the data file. To read sample `idx`, seek to idx * 8 in the
        index file and read its start offset. If another entry follows (idx * 8 + 16 bytes still
        fit in the index file), the next 8 bytes give the start of the next sample, which is
        where this one ends; otherwise this is the last sample and it ends at the end of the
        data file.
        '''
        index_offset = idx * 8
        assert index_offset < self._index_size
        self._index_file.seek(index_offset)  # absolute position of this sample's index entry
        data_offset = int.from_bytes(self._index_file.read(8), byteorder='little')  # where the sample starts
        if index_offset + 16 <= self._index_size:
            next_offset = int.from_bytes(self._index_file.read(8), byteorder='little')  # where the next sample starts
        else:
            next_offset = self._data_size  # last sample: it runs to the end of the data file
        self._data_file.seek(data_offset)  # jump to the start of the sample in the data file
        sample_data = self._data_file.read(next_offset - data_offset)  # read exactly this sample's bytes
        return torch.load(io.BytesIO(sample_data))  # deserialize the (tokens, targets) tuple of tensors

    def get_num_samples(self):
        return self._index_size // 8
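def bytes_from_file(fname, ck_sz=8192):
    '''
    Generator that streams the given file in chunks of ck_sz bytes
    and yields one byte at a time (as an int in the range 0-255).
    '''
    with open(fname, "rb") as f:
        while True:
            chunk = f.read(ck_sz)
            if chunk:
                for b in chunk:  # iterating over a bytes object yields ints
                    yield b
            else:
                break  # read() returned b'': end of file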
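def regular_encode_on_fly(texts, tokenizer, splits):
    '''
    Pad each batch only to the length it needs (aka bucketing): take the longest text in the
    batch (counted in whitespace-separated words, a rough proxy for the token count), round it
    up to the first bucket length in `splits` (sorted ascending) that fits, and cap it at the
    largest bucket.
    '''
    max_len = max(len(x.split()) for x in texts)
    for l in splits:
        if l >= max_len:
            max_len = l
            break
    max_len = min(max_len, splits[-1])
    # Option names for newer transformers releases; the original (2020) code
    # used return_attention_masks=True and pad_to_max_length=True.
    enc_di = tokenizer.batch_encode_plus(
        list(texts),
        add_special_tokens=True,
        return_attention_mask=True,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,  # subword tokens usually outnumber words, so cut anything past max_len
        max_length=max_len,
    )
    return np.array(enc_di['input_ids']), np.array(enc_di["attention_mask"])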
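def indices_for_ordinal(ordinal, world_size, count, shuffle=True):
    '''
    A 3-line sampler that gives each TPU core its own disjoint share of the samples.
    ordinal: index of this TPU core (0 .. world_size - 1)
    world_size: number of TPU cores
    count: number of samples (batches) stored in the dataset file
    '''
    count = (count // world_size) * world_size  # drop the remainder so every core gets the same number of samples
    indices = list(range(ordinal, count, world_size))  # start, stop, step: every world_size-th sample from `ordinal`
    if shuffle:
        np.random.shuffle(indices)
    return indices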
def prep_data(bs, df):
    '''
    Build the batches ourselves: sort the texts by length (in words) so that each batch holds
    texts of similar length, which keeps the dynamic padding done per batch by
    regular_encode_on_fly() (with the `splits` buckets) small.
    Samples that don't fill a whole batch are dropped.
    '''
    sentences = df['comment_text'].astype(str).values
    sort_idx = np.argsort(np.array([len(x.split()) for x in sentences]))
    sentences = sentences[sort_idx]
    targets = df['toxic'].values[sort_idx]
    num_samples = (len(sentences) // bs) * bs
    sentences = sentences[: num_samples]
    targets = targets[: num_samples]
    return sentences.reshape(len(sentences) // bs, bs), targets.reshape(len(targets) // bs, bs)
def write_sample(s, data_file, index_file):
    bio = io.BytesIO()  # in-memory buffer
    torch.save(s, bio)  # serialize the sample into the buffer
    offset = data_file.tell()  # [int] current position in the data file == where this sample will start
    index_file.write(offset.to_bytes(8, byteorder='little'))  # index entry: 8-byte little-endian start offset
    data_file.write(bio.getvalue())  # append the serialized bytes to the data file
def create_dataset(df, tokenizer, batch_size, splits, path):
    x, y = prep_data(batch_size, df)  # grab the batches (raw text, targets)
    # tokenize each batch; every batch is dynamically padded to its own bucket length.
    # np.stack turns the (input_ids, attention_mask) pair into one (2, batch_size, max_len) array
    xt = [torch.from_numpy(np.stack(regular_encode_on_fly(t, tokenizer, splits))) for t in x]
    yt = [torch.tensor(t, dtype=torch.float) for t in y]  # targets
    with open(path + '.data', 'wb') as data_file:
        with open(path + '.index', 'wb') as index_file:
            for s in zip(xt, yt):
                # zip pairs each token batch with its targets, so s == (tokens, targets)
                write_sample(s, data_file, index_file)
def generate_index():
    global splits
    tokenizer = transformers.XLMRobertaTokenizer.from_pretrained(model)
    train1 = pd.read_csv(
        './jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train.csv',
        usecols=["comment_text", "toxic"],
        nrows=8,  # read only the first 8 rows
    )
    all_train = train1[['comment_text', 'toxic']]
    del train1
    gc.collect()
    # random sample of rows, trimmed to a multiple of batch_size
    all_train = all_train.sample((all_train.shape[0] // batch_size) * batch_size)
    print('DF:', all_train.shape)
    splits = sorted([int(x) for x in splits.split(',')])  # "8,16,32,64" -> [8, 16, 32, 64]
    create_dataset(all_train, tokenizer, batch_size, splits, train_ds)
if __name__ == "__main__":
    generate_index()
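The two halves of the dynamic-padding trick, prep_data() (sort by length, cut into batches) and the bucket choice at the top of regular_encode_on_fly(), can be sketched in plain Python without pandas or a tokenizer. This is a minimal sketch under two assumptions: texts are measured in whitespace-separated words, as in the original, and make_sorted_batches and choose_bucket are hypothetical names. Python's sorted() is stable while np.argsort's default sort is not, so texts of equal length may land in a different order.

```python
def make_sorted_batches(sentences, targets, bs):
    # Like prep_data(): order samples by length (in words) so each batch holds
    # similar-length texts, drop what doesn't fill a whole batch, cut into batches of bs.
    order = sorted(range(len(sentences)), key=lambda i: len(sentences[i].split()))
    usable = (len(order) // bs) * bs
    order = order[:usable]  # after sorting, the longest texts are the ones dropped
    text_batches = [[sentences[i] for i in order[k:k + bs]] for k in range(0, usable, bs)]
    target_batches = [[targets[i] for i in order[k:k + bs]] for k in range(0, usable, bs)]
    return text_batches, target_batches


def choose_bucket(texts, splits):
    # Like the start of regular_encode_on_fly(): round the longest text in the batch
    # up to the first bucket that fits, capped at the largest bucket.
    max_len = max(len(t.split()) for t in texts)
    for bucket in sorted(splits):
        if bucket >= max_len:
            return bucket
    return max(splits)


lengths = [3, 20, 5, 9, 40, 12, 70]                  # words per hypothetical comment
sentences = [" ".join(["w"] * n) for n in lengths]
targets = [0, 1, 0, 0, 1, 0, 1]

batches, target_batches = make_sorted_batches(sentences, targets, bs=2)
buckets = [choose_bucket(b, [8, 16, 32, 64]) for b in batches]
print(buckets)          # [8, 16, 64]
print(target_batches)   # [[0, 0], [0, 0], [1, 1]]
```

Because short texts end up together, the first batch is padded only to its 8-token bucket instead of to the length of the longest comment in the whole dataset.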
@AdityaSoni19031997
Author

Feel free to comment if something seems incorrect or isn't explained!

@AdityaSoni19031997
Author

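The *.index file stores a running (cumulative) count of bytes: for every record, the absolute byte offset in the *.data file at which that record starts, written as an 8-byte little-endian integer (that is what write_sample() writes via data_file.tell()). The size of a record is not stored; read_sample() recovers it as the difference between two consecutive offsets, or between the last offset and the size of the .data file.
Like, let's assume your first serialized batch occupies bytes 0-44 of the .data file (45 bytes), your second batch occupies bytes 45-74 (30 bytes), and a third batch starts right after it; the index file then holds something like below,

0
45  (0+45 = 45)
75  (45+30 = 75)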
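To see the whole storage format end to end, here is a minimal sketch that writes variable-size records with the same .data/.index layout, reads them back with the same offset arithmetic as FileDataset.read_sample(), and splits the record indices across cores the way indices_for_ordinal() does. Assumptions: pickle stands in for torch.save/torch.load so it runs without PyTorch, random.shuffle replaces np.random.shuffle, and write_records, read_record and shard_indices are hypothetical names.

```python
import os
import pickle
import random
import tempfile

def write_records(records, path):
    # Same layout as create_dataset()/write_sample(), with pickle standing in for torch.save:
    # serialized bytes go to <path>.data, each record's start offset to <path>.index.
    with open(path + ".data", "wb") as data_file, open(path + ".index", "wb") as index_file:
        for rec in records:
            index_file.write(data_file.tell().to_bytes(8, byteorder="little"))
            data_file.write(pickle.dumps(rec))

def read_record(path, idx):
    # Same lookup as FileDataset.read_sample().
    with open(path + ".index", "rb") as index_file, open(path + ".data", "rb") as data_file:
        index_size = index_file.seek(0, 2)      # seek() returns the new position == file size
        data_size = data_file.seek(0, 2)
        assert idx * 8 < index_size
        index_file.seek(idx * 8)
        start = int.from_bytes(index_file.read(8), byteorder="little")
        if idx * 8 + 16 <= index_size:          # a next entry exists: it marks where this record ends
            end = int.from_bytes(index_file.read(8), byteorder="little")
        else:                                   # last record: it runs to the end of the data file
            end = data_size
        data_file.seek(start)
        return pickle.loads(data_file.read(end - start))

def shard_indices(ordinal, world_size, count, shuffle=True):
    # Same logic as indices_for_ordinal(), with random.shuffle instead of np.random.shuffle.
    count = (count // world_size) * world_size
    indices = list(range(ordinal, count, world_size))
    if shuffle:
        random.shuffle(indices)
    return indices

records = [list(range(n)) for n in (3, 1, 4, 1, 5, 9, 2)]   # 7 records of different sizes
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train_dataset")
    write_records(records, path)
    num_samples = os.path.getsize(path + ".index") // 8
    shards = [shard_indices(core, 2, num_samples) for core in range(2)]
    seen = sorted(i for shard in shards for i in shard)
    roundtrip_ok = all(read_record(path, i) == records[i] for i in seen)
    print(num_samples, seen)   # 7 [0, 1, 2, 3, 4, 5]  (record 6 is dropped so both cores get 3)
```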
