Skip to content

Instantly share code, notes, and snippets.

@harsh-99
Last active May 6, 2021 13:46
Show Gist options
  • Save harsh-99/27e8c64ad387466c8fdaaacc51ccd5a4 to your computer and use it in GitHub Desktop.
Save harsh-99/27e8c64ad387466c8fdaaacc51ccd5a4 to your computer and use it in GitHub Desktop.
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import gensim
class Dataset_seq(Dataset):
    """Map-style dataset that yields ``(token_id_tensor, label)`` pairs.

    Each raw text sample is tokenised with ``gensim.utils.simple_preprocess``
    and every token is mapped to an integer id through ``word2id``. Tokens
    missing from the vocabulary are mapped to the id stored under
    ``'<unk>'``.

    Args:
        word2id: Mapping from token string to integer id. It needs a
            ``'<unk>'`` entry whenever a sample can contain a token that is
            not in the mapping (otherwise a ``KeyError`` is raised on the
            first such token).
        train_path: Value handed unchanged to ``reader`` to load the samples.
    """

    def __init__(self, word2id, train_path):
        self.word2id = word2id
        self.train_path = train_path
        # `reader` is not defined in this block; it must be provided by the
        # surrounding module. Presumably it returns two parallel, indexable
        # sequences (texts, labels) -- TODO confirm against its definition.
        self.data, self.label = reader(train_path)

    def __getitem__(self, index):
        """Return ``(seq, label)`` for the sample stored at ``index``."""
        return self.preprocess(self.data[index]), self.label[index]

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.data)

    def preprocess(self, text) -> torch.Tensor:
        """Tokenise ``text`` and return its token ids as a 1-D tensor.

        The tensor has one element per token produced by
        ``gensim.utils.simple_preprocess``, so its length varies from sample
        to sample.
        """
        tokens = gensim.utils.simple_preprocess(text)
        return self._encode(tokens)

    def _encode(self, tokens) -> torch.Tensor:
        """Map an iterable of tokens to a 1-D ``torch.long`` tensor of ids."""
        word2id = self.word2id
        # The '<unk>' lookup is deliberately evaluated only for unknown
        # tokens, so a vocabulary without '<unk>' still works as long as
        # every token is known.
        ids = [
            word2id[token] if token in word2id else word2id['<unk>']
            for token in tokens
        ]
        # Build the tensor with an explicit dtype: letting NumPy infer it
        # yields float64 for an empty token list and int32 on platforms
        # whose default NumPy integer is 32-bit.
        return torch.tensor(ids, dtype=torch.long)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment