Last active
July 12, 2023 21:17
-
-
Save lgray/68ef326d1f0ac59f44147fbda1cda3cc to your computer and use it in GitHub Desktop.
data batcher for smartpixels samples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves smart-pixels batches straight from disk.

    Input ("recon") files are discovered with ``glob``; each one is paired
    with a label file of the same name in which ``"recon" + data_format`` is
    replaced by ``"labels"``. Only one file pair is held in memory at a time.
    Its rows are shuffled whenever it is (re)loaded, which is why callers
    should not ask Keras to shuffle again.

    NOTE: batches are addressed by absolute event index
    (``batch_index * batch_size``) and never span two files. A batch that
    reaches the end of a file is truncated there, and the following batch
    starts part-way into the next file, so some events at the start of a file
    may not be served in a given pass.
    """

    def __init__(
        self,
        data_directory_path: str = "./",
        labels_directory_path: str = "./",
        is_directory_recursive: bool = False,
        file_type: str = "csv",
        data_format: str = "2D",
        batch_size: int = 32,
        file_count=None,
        labels_list: Union[List, str] = "cotAlpha",
        to_standardize: bool = False,
        input_shape: Tuple = (13, 21),
        transpose: Tuple | None = None,
        include_y_local: bool = False,
        files_from_end=False,
        **kwargs,
    ):
        """
        Data Generator to streamline data input to the network direct from the directory.

        Args:
            data_directory_path: Directory searched for the "recon" input files.
            labels_directory_path: Directory holding the matching "labels" files.
            is_directory_recursive: Forwarded to ``glob.glob`` as ``recursive``.
            file_type: Default: "csv"
                Adapt the data loader according to file type. For now, it only
                supports csv and parquet file formats.
            data_format: Default: 2D
                Used to refer to the relevant "recon" files, 2D for 2D pixel
                array, 3D for time series input.
            batch_size: Default: 32
                The no. of data points to be included in a single batch.
            file_count: Default: None
                To limit the no. of files to be used.
                If set to None, all files will be considered as legitimate inputs.
            labels_list: Default: "cotAlpha"
                Input column name or list of column names to be used as label
                input to the neural network.
            to_standardize: Default: False
                If set to True, batches are standardized with the dataset-wide
                mean and standard deviation before being returned.
            input_shape: Default: (13,21) for image input to a 2D feedforward neural network.
                To reshape the input array per the requirements of the network training.
            transpose: Default: None
                Optional axis permutation applied to every reshaped batch
                (the batch axis is included in the permutation).
            include_y_local: Default: False
                If True, ``__getitem__`` returns ``([X, y_local], y)`` where
                ``y_local`` is the "y-local" label column of the batch rows.
            files_from_end: Default: False
                If True, ``file_count`` selects the last files of the sorted
                file list instead of the first ones.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``file_type`` is unsupported or ``batch_size`` < 1.
            FileNotFoundError: If no input file is selected.
        """
        # Local import so this class does not depend on a module-level `import os`.
        import os

        if file_type not in ("csv", "parquet"):
            raise ValueError("file_type can only be \"csv\" or \"parquet\"!")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer!")
        super().__init__()

        self.file_type = file_type

        # os.path.join tolerates directory paths with or without a trailing separator.
        pattern = os.path.join(
            data_directory_path, "recon" + data_format + "*." + file_type
        )
        self.recon_files = sorted(glob.glob(pattern, recursive=is_directory_recursive))
        if file_count is not None:
            if files_from_end:
                # An explicit start index avoids the `[-0:]` pitfall, which
                # would select every file when file_count is 0.
                first_kept = max(len(self.recon_files) - file_count, 0)
                self.recon_files = self.recon_files[first_kept:]
            else:
                self.recon_files = self.recon_files[:file_count]
        if not self.recon_files:
            raise FileNotFoundError(
                f"No input files selected (pattern={pattern!r}, file_count={file_count!r})."
            )

        self.label_files = [
            os.path.join(
                labels_directory_path,
                os.path.basename(recon_file).replace("recon" + data_format, "labels"),
            )
            for recon_file in self.recon_files
        ]

        # One pass over all input files: cumulative event offsets plus the
        # per-file statistics needed for the dataset-wide mean and std.
        self.dataset_mean = None
        self.dataset_std = None
        offsets = [0]
        per_file_stats = []
        for recon_file in self.recon_files:
            frame = self._read_frame(recon_file).dropna()
            offsets.append(offsets[-1] + len(frame))
            if len(frame) == 0:
                # Nothing to contribute; averaging an empty array would give NaN.
                continue
            x = frame.values.reshape((-1, *input_shape))
            if transpose is not None:
                x = x.transpose(transpose)
            file_mean, file_variance = self.get_mean_and_variance(x)
            per_file_stats.append((x.shape[0], file_mean, file_variance))

        if per_file_stats:
            # Pool the per-file statistics weighted by event count. The
            # (mean - dataset_mean)**2 term accounts for the spread between
            # file means, which a plain average of variances would miss.
            total_events = sum(count for count, _, _ in per_file_stats)
            self.dataset_mean = (
                sum(count * mean for count, mean, _ in per_file_stats) / total_events
            )
            pooled_variance = (
                sum(
                    count * (variance + (mean - self.dataset_mean) ** 2)
                    for count, mean, variance in per_file_stats
                )
                / total_events
            )
            self.dataset_std = np.sqrt(pooled_variance)

        self.file_offsets = np.array(offsets)
        self.batch_size = batch_size
        # Always keep the label columns as a list so that column selection
        # yields a DataFrame even for a single label name.
        if isinstance(labels_list, str):
            self.labels_list = [labels_list]
        else:
            self.labels_list = list(labels_list)
        self.input_shape = input_shape
        self.transpose = transpose
        self.to_standardize = to_standardize
        self.include_y_local = include_y_local

        # Cache of the most recently loaded (and shuffled) file pair.
        self.current_file_index = None
        self.current_dataframes = None
        self.on_epoch_end()

    def _read_frame(self, path, columns=None):
        """Read one csv/parquet file, optionally restricted to `columns`."""
        if self.file_type == "csv":
            frame = pd.read_csv(path)
            return frame if columns is None else frame[columns]
        return pd.read_parquet(path, columns=columns)

    def _load_file(self, file_index):
        """Load, clean and shuffle file pair `file_index` into the cache."""
        label_columns = list(self.labels_list)
        if self.include_y_local and "y-local" not in label_columns:
            label_columns.append("y-local")

        recon_df = self._read_frame(self.recon_files[file_index])
        labels_df = self._read_frame(self.label_files[file_index], columns=label_columns)

        # Drop events whose input contains NaNs from both frames, so the row
        # count matches what __init__ counted with dropna().
        nan_rows = recon_df.index[recon_df.isna().any(axis=1)]
        recon_df = recon_df.drop(nan_rows)
        labels_df = labels_df.drop(nan_rows)

        # Shuffle inputs and labels together, then split them apart again.
        shuffled = recon_df.join(labels_df).sample(frac=1).reset_index(drop=True)
        self.current_dataframes = (
            shuffled[recon_df.columns],
            shuffled[labels_df.columns],
        )
        # Only mark the file as current once it has loaded successfully.
        self.current_file_index = file_index

    def get_mean_and_variance(self, x):
        """Return the per-element mean and variance of `x` over axis 0.

        Args:
            x: Array whose first axis indexes events.
        Returns:
            Tuple ``(mean, variance)``, both with a leading axis of length 1.
            A small constant (1e-10) is added to the variance so that a later
            division by the standard deviation cannot divide by zero.
        """
        return np.mean(x, axis=0, keepdims=True), np.var(x, axis=0, keepdims=True) + 1e-10

    def standardize(self, x):
        """Return `x` standardized with the dataset-wide mean and std.

        `x` itself is not modified; a new array is returned.

        Args:
            x: Batch of inputs to be normalized.
        Returns:
            The inputs, normalized.
        """
        return (x - self.dataset_mean) / self.dataset_std

    def on_epoch_end(self):
        """
        Inherited from the parent class.
        Used to reset indices but not of significance in this datagenerator.
        """
        pass

    def __getitem__(self, batch_index):
        """
        Used to fetch a batch of inputs (X,y) for the network's training.

        Returns ``(X, y)``, or ``([X, y_local], y)`` when `include_y_local`
        is set. The batch may hold fewer than `batch_size` events when it
        reaches the end of a file.

        Raises:
            IndexError: If `batch_index` does not address any event.
        """
        start = batch_index * self.batch_size  # absolute *event* index
        if not 0 <= start < self.file_offsets[-1]:
            raise IndexError(f"batch index {batch_index} is out of range")

        # Index of the file containing `start`: the last offset that is <= start.
        file_index = int(np.searchsorted(self.file_offsets, start, side="right")) - 1
        events_in_file = int(
            self.file_offsets[file_index + 1] - self.file_offsets[file_index]
        )
        first = int(start - self.file_offsets[file_index])  # relative event index in file
        last = min(first + self.batch_size, events_in_file)  # never read past the file

        if file_index != self.current_file_index:
            self._load_file(file_index)
        recon_df, labels_df = self.current_dataframes

        X = recon_df.iloc[first:last].values.reshape((-1, *self.input_shape))
        batch_labels = labels_df.iloc[first:last]
        y = batch_labels[self.labels_list].values

        if self.transpose is not None:
            X = X.transpose(self.transpose)
        if self.to_standardize:
            X = self.standardize(X)

        if self.include_y_local:
            y_local = batch_labels["y-local"].values
            return [X, y_local], y
        return X, y

    def __len__(self):
        """Number of full batches available across all selected files."""
        return int(self.file_offsets[-1] // self.batch_size)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Batch sizes and number of input files for the training and validation sets.
batch_size = 100
val_batch_size = 100
train_file_size = 142
val_file_size = 16

# Options shared by both data generators (example for fully 3D data).
shared_generator_options = dict(
    data_directory_path="/Users/lgray/smart-pixels/dataset8/unflipped-positive/data/",
    labels_directory_path="/Users/lgray/smart-pixels/dataset8/unflipped-positive/labels/",
    is_directory_recursive=False,
    file_type="parquet",
    data_format="3D",
    to_standardize=True,
    include_y_local=False,
    labels_list=['x-midplane', 'y-midplane', 'cotAlpha', 'cotBeta'],
    input_shape=(20, 13, 21),
    transpose=(0, 2, 3, 1),
)

# Training reads the first `train_file_size` files of the sorted listing.
training_generator = CustomDataGenerator(
    batch_size=batch_size,
    file_count=train_file_size,
    **shared_generator_options,
)

# Validation takes its files from the *end* of the listing so that the two
# samples do not overlap.
validation_generator = CustomDataGenerator(
    batch_size=val_batch_size,
    file_count=val_file_size,
    files_from_end=True,
    **shared_generator_options,
)

# Later on, the generators are handed to model.fit like this.
# training
es = EarlyStopping(patience=20, restore_best_weights=True)

history = model.fit(
    x=training_generator,
    validation_data=validation_generator,
    callbacks=[es],
    epochs=1000,
    # Important: shuffling already happens inside the data loader, so it must
    # stay switched off here.
    shuffle=False,
    verbose=1,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment