-
-
Save zeneofa/d8fc2063e53591630c676bc9390d0ce4 to your computer and use it in GitHub Desktop.
DataLoader for large sequence data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import linecache

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
class InteractionData(Dataset):
    """Dataset over sequence-interaction pairs stored across many text files.

    Positive samples are read lazily from disk via ``linecache`` (the total
    sample count can be enormous, e.g. ~1.8e10, so nothing is pre-loaded);
    negative samples are synthesized from a uniform background over the
    known characters.

    :param inputfile: TSV with four columns and no header: (1) the number of
        lines of the file whose path is given in column (2), (3) the
        cumulative sum of column (1), and (4) a 0-based file index.
    :param property_file: tab-separated file mapping each character code to
        its numeric feature values; the first (header) line is skipped.
    :param seq_length: length of the sequence strings in the input files.
    :param negative_prob: proportion of returned samples that are negative,
        i.e. drawn from a random uniform background.
    """

    def __init__(self, inputfile, property_file, seq_length, negative_prob=0.5):
        self.inputfile = inputfile
        dd = pd.read_csv(inputfile, sep="\t", header=None,
                         names=['n_lines', 'filename', 'n_lines_cumulative', 'IDX'])
        self.dd = dd
        properties = {}
        self.num_features = 0
        with open(property_file, 'r') as fh1:
            fh1.readline()  # skip the header line
            for line in fh1:
                linesplit = line.rstrip().split()
                prop_code = linesplit.pop(0)
                properties[prop_code] = np.asarray([float(ii) for ii in linesplit])
                self.num_features = len(linesplit)
        self.properties = properties
        self.seq_length = seq_length
        self.negative_prob = negative_prob
        # Grand total of samples across all files; kept as the last entry of
        # the cumulative-sum column.
        self._total_size = dd['n_lines_cumulative'].values[-1]
        # Explicit raise rather than `assert`, which is stripped under -O.
        if self.num_features <= 0:
            raise ValueError('No features in the property file')

    def __len__(self):
        return self._total_size

    def __getitem__(self, idx):
        """Return one ``(x, y)`` training pair.

        :param idx: global sample index across all files.
        :returns: ``(torch.FloatTensor x, torch.LongTensor y)`` where ``x``
            has shape ``(seq_length, seq_length, num_features)`` and ``y``
            is ``[1]`` for a real (on-disk) pair, ``[0]`` for a synthetic
            negative.
        """
        targetFile, realLineNumber = self.checkLine(idx)
        if np.random.uniform(0, 1) < self.negative_prob:
            # Negative sample: two random character sequences from a uniform
            # background.  np.random.choice needs a real sequence, not a
            # dict_keys view, hence list().  Both sequences are drawn ONCE so
            # that x is the outer product of two fixed sequences, mirroring
            # the structure transform() produces for positive pairs.
            possible_char = list(self.properties)
            seq_a = np.random.choice(possible_char, self.seq_length)
            seq_b = np.random.choice(possible_char, self.seq_length)
            feats_a = np.stack([self.properties[c] for c in seq_a])  # (L, F)
            feats_b = np.stack([self.properties[c] for c in seq_b])  # (L, F)
            # x[i, j, :] = properties[seq_a[i]] * properties[seq_b[j]]
            x = torch.from_numpy(feats_a[:, None, :] * feats_b[None, :, :]).float()
            y = torch.LongTensor([0])
        else:
            # Positive sample: fetch the pair from disk.  The +1 skips the
            # per-file header line (linecache line numbers are 1-based).
            line = linecache.getline(targetFile, realLineNumber + 1).rstrip().split()
            seq_1 = line[0]
            seq_2 = line[1]
            # transform() is defined elsewhere in the project; per the
            # original comment it returns a
            # (seq_length, seq_length, num_features) numpy array.
            x = torch.from_numpy(transform(seq_1, seq_2, self.properties))
            y = torch.LongTensor([1])
        return x, y

    def checkLine(self, idx, hasHeaders=True):
        """Map a global sample index to ``(filename, line_number)``.

        NOTE(review): the returned line number also adds ``fileNumber`` —
        presumably compensating for header lines in the concatenated
        numbering scheme; confirm against the actual data layout.
        """
        dd = self.dd
        # Smallest file index whose (header-adjusted) cumulative line count
        # still exceeds idx, i.e. the file that contains this sample.
        fileNumber = dd.loc[idx < dd['n_lines_cumulative'] - 1 - dd['IDX'], 'IDX'].min()
        fileName = dd['filename'].values[fileNumber]
        lineNumber = idx
        if fileNumber == 0:
            return fileName, lineNumber + fileNumber + 1
        else:
            # Offset by the lines contained in all preceding files.
            lineNumber = lineNumber - dd['n_lines_cumulative'].values[fileNumber - 1]
            return fileName, lineNumber + fileNumber + 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment