Created
August 27, 2021 21:12
-
-
Save JakeColor/b36cefea313852f090b4381e5de02e29 to your computer and use it in GitHub Desktop.
SimpleNumpyDataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import os | |
import random | |
from itertools import cycle | |
import torch | |
import pandas as pd | |
import numpy as np | |
from types import FunctionType | |
from multiprocessing import Queue, Process | |
logger = logging.getLogger('rf_train.datasets') | |
class SimpleNumpyDataset(torch.utils.data.Dataset):
    """
    Map-style dataset where 1 file = 1 batch.

    Each ``.npy`` file in ``dataDir`` holds one 2-D array; columns up to and
    including ``end_of_input_cols_index`` are model inputs, the remaining
    columns are targets. File names are expected to start with
    ``<rows>-...`` so the row count can be parsed from the name.
    """
    def __init__(self, dataDir, end_of_input_cols_index, testset=False, cache_enabled=False):
        """
        :param dataDir: directory containing the ``.npy`` batch files
        :param end_of_input_cols_index: the last column that contains input data
        :param testset: flag marking this dataset as a test set (stored only)
        :param cache_enabled: if True, keep loaded arrays in an in-process dict
        """
        super().__init__()
        self.dataDir = dataDir
        self.end_of_input_cols_index = end_of_input_cols_index
        self.numpy_names = [filename for filename in os.listdir(dataDir) if filename.endswith('.npy')]
        assert len(self.numpy_names) > 0, "Empty Directory"
        # Assume that numpy size is the same in all files in the directory
        self.numpy_size = int(self.numpy_names[0].split('-')[0])
        # Length = # number of numpys, NOT number of points. This effects iterations / Epochs
        self._len = len(self.numpy_names)
        self.testset = testset
        # TODO: Map Style dataset supports epochs, but dataloader with epochs is breaking the caching. Setup shared memory with multiprocessing lib to fix.
        self.cache_enabled = cache_enabled
        # Fix: start with an empty cache. The previous {"empty": 0} sentinel
        # meant a file named "empty.npy" would wrongly hit the cache and return 0.
        self.cache = {}

    def get_data(self, numpy_name, is_sample=False):
        """Return the array from cache if possible, from filesystem if not cached.

        :param numpy_name: file name (with extension) relative to ``dataDir``
        :param is_sample: unused; kept for interface compatibility
        :raises OSError: re-raised (after logging) when the file cannot be read.
            Previously the error was swallowed and ``return arr`` crashed with
            an UnboundLocalError, masking the real cause.
        """
        clean_numpy_name = numpy_name.split(".")[0]
        if clean_numpy_name in self.cache:
            return self.cache[clean_numpy_name]
        file_path = os.path.join(self.dataDir, numpy_name)
        try:
            # Keep only the raising I/O inside the try block.
            with open(file_path, 'rb') as f:
                arr = np.load(f)
        except OSError as err:
            logger.error(f"Error on loading or processing: {numpy_name}")
            logger.error("OS error: {0}".format(err))
            raise
        if self.cache_enabled:
            self.cache[clean_numpy_name] = arr
        return arr

    def __len__(self):
        # Number of files (= batches), NOT number of data points.
        return len(self.numpy_names)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` tensors for the file at ``idx``."""
        arr = self.get_data(self.numpy_names[idx])
        # casting to np.float64 is a hack for faster training speed
        # for more, see https://github.com/riskfuel/riskfuel/issues/2356
        arr_input = arr[:, :(self.end_of_input_cols_index + 1)].astype(np.float64)
        arr_output = arr[:, (self.end_of_input_cols_index + 1):].astype(np.float64)
        # torch.Tensor(...) yields float32 tensors regardless of the float64 cast above.
        return torch.Tensor(arr_input), torch.Tensor(arr_output)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.