@JakeColor
Created August 27, 2021 21:12
SimpleNumpyDataset
import logging
import os
import random
from itertools import cycle
from multiprocessing import Queue, Process
from types import FunctionType

import numpy as np
import pandas as pd
import torch

logger = logging.getLogger('rf_train.datasets')

class SimpleNumpyDataset(torch.utils.data.Dataset):
    """
    Dataset where 1 file = 1 batch.
    """

    def __init__(self, dataDir, end_of_input_cols_index, testset=False, cache_enabled=False):
        """
        :param end_of_input_cols_index: index of the last column that contains input data
        """
        super().__init__()
        self.dataDir = dataDir
        self.end_of_input_cols_index = end_of_input_cols_index
        self.numpy_names = [filename for filename in os.listdir(dataDir) if filename.endswith('.npy')]
        assert len(self.numpy_names) > 0, "Empty directory"
        # Assume the numpy size is the same for all files in the directory
        self.numpy_size = int(self.numpy_names[0].split('-')[0])
        # Length = number of numpy files, NOT number of points. This affects iterations / epochs.
        self._len = len(self.numpy_names)
        self.testset = testset
        # TODO: Map-style dataset supports epochs, but a dataloader with epochs breaks the caching.
        # Set up shared memory with the multiprocessing lib to fix.
        self.cache_enabled = cache_enabled
        self.cache = {"empty": 0}
    def get_data(self, numpy_name, is_sample=False):
        """Returns the array from cache if possible, from the filesystem if not cached."""
        clean_numpy_name = numpy_name.split(".")[0]
        # print(f"Loading numpy: {clean_numpy_name} ")
        if clean_numpy_name in self.cache:
            return self.cache[clean_numpy_name]
        try:
            file_path = os.path.join(self.dataDir, numpy_name)
            with open(file_path, 'rb') as f:
                arr = np.load(f)
            if self.cache_enabled:
                self.cache[clean_numpy_name] = arr
        except OSError as err:
            logger.error(f"Error on loading or processing: {numpy_name}")
            logger.error(f"OS error: {err}")
            # Re-raise so callers don't hit an undefined `arr` below
            raise
        return arr
    def __len__(self):
        return len(self.numpy_names)

    def __getitem__(self, idx):
        arr = self.get_data(self.numpy_names[idx])
        # Casting to np.float64 is a hack for faster training speed;
        # for more, see https://github.com/riskfuel/riskfuel/issues/2356
        arr_input = arr[:, :(self.end_of_input_cols_index + 1)].astype(np.float64)
        arr_output = arr[:, (self.end_of_input_cols_index + 1):].astype(np.float64)
        return torch.Tensor(arr_input), torch.Tensor(arr_output)
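
A minimal usage sketch, assuming a hypothetical directory "data/train" of .npy files and an arbitrary end_of_input_cols_index of 9. Since each file is already one batch, the DataLoader is given batch_size=None so it passes each file's tensors through without re-batching.

# Usage sketch -- the path and column index below are placeholder assumptions.
from torch.utils.data import DataLoader

dataset = SimpleNumpyDataset(
    dataDir="data/train",            # directory containing the .npy batch files
    end_of_input_cols_index=9,       # columns 0..9 are inputs, remaining columns are targets
    cache_enabled=True,              # keep arrays in memory after the first load
)

# batch_size=None disables the DataLoader's own batching, since one file = one batch.
loader = DataLoader(dataset, batch_size=None, shuffle=True)

for x, y in loader:
    # x: (rows_in_file, num_input_cols), y: (rows_in_file, num_output_cols)
    print(x.shape, y.shape)
    break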