Skip to content

Instantly share code, notes, and snippets.

@seanie12
Last active February 8, 2020 06:28
Show Gist options
Save seanie12/52a45202d862c7b612d1b83ec242cc64 to your computer and use it in GitHub Desktop.
import linecache
import subprocess
import torch
from torch.utils.data import Dataset, DataLoader
class EEGDataset(Dataset):
    """Map-style dataset that reads one EEG sample per line of a text file.

    Each line is expected to look like ``<features>\t<label>``, where
    ``<features>`` is a whitespace-separated list of floats and ``<label>``
    is an integer.

    Args:
        filename: Path to the text file.
        shape: Shape each feature vector is reshaped to. Defaults to
            ``(2500, 2, 3)``, so by default every line must carry
            2500 * 2 * 3 = 15000 feature values.
    """

    def __init__(self, filename, shape=(2500, 2, 3)):
        self.filename = filename
        self.shape = tuple(shape)
        # Count lines in Python instead of shelling out to `wc -l`. This avoids
        # shell injection through the filename. It also counts a final line
        # that lacks a trailing newline (wc -l only counts newline characters).
        with open(filename, encoding="utf-8") as f:
            self.total_size = sum(1 for _ in f)

    def __getitem__(self, idx):
        """Return ``(x, y)``: a float tensor of ``self.shape`` and an int label.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
            ValueError: if the line is malformed (no single tab, non-numeric
                values, or a value count that does not match ``self.shape``).
        """
        if not 0 <= idx < self.total_size:
            raise IndexError(
                f"index {idx} out of range for dataset of size {self.total_size}"
            )
        # linecache line numbers are 1-based. Note: linecache reads and caches
        # the whole file on first access (once per process).
        line = linecache.getline(self.filename, idx + 1)
        str_x, str_y = line.rstrip("\n").split("\t")
        y = int(str_y)
        # split() with no argument splits on runs of whitespace; the previous
        # split("") always raised ValueError("empty separator").
        x = torch.tensor([float(v) for v in str_x.split()], dtype=torch.float)
        return x.view(*self.shape), y

    def __len__(self):
        """Return the number of lines (samples) in the file."""
        return self.total_size
if __name__ == "__main__":
    # Smoke run: stream the dataset in file order, 32 samples per batch,
    # and report the tensor sizes of each batch.
    loader = DataLoader(EEGDataset("./data/eeg.txt"), batch_size=32, shuffle=False)
    for features, labels in loader:
        print(features.size(), labels.size())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment