data batcher for smartpixels samples
import glob
from typing import List, Tuple, Union

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping


class CustomDataGenerator(tf.keras.utils.Sequence):
    def __init__(self,
                 data_directory_path: str = "./",
                 labels_directory_path: str = "./",
                 is_directory_recursive: bool = False,
                 file_type: str = "csv",
                 data_format: str = "2D",
                 batch_size: int = 32,
                 file_count=None,
                 labels_list: Union[List, str] = "cotAlpha",
                 to_standardize: bool = False,
                 input_shape: Tuple = (13, 21),
                 transpose: Tuple | None = None,
                 include_y_local: bool = False,
                 files_from_end=False,
                 **kwargs,
                 ):
"""
Data Generator to streamline data input to the network direct from the directory.
Args:
data_directory_path:
labels_directory_path:
is_directory_recursive:
file_type: Default: "csv"
Adapt the data loader according to file type. For now, it only supports csv and parquet file formats.
data_format: Default: 2D
Used to refer to the relevant "recon" files, 2D for 2D pixel array, 3D for time series input,
batch_size: Default: 32
The no. of data points to be included in a single batch.
file_count: Default: None
To limit the no. of .csv files to be used for training.
If set to None, all files will be considered as legitimate inputs.
labels_list: Default: "cotAlpha"
Input column name or list of column names to be used as label input to the neural network.
to_standardize: If set to True, it ensures that batches are normalized prior to being used as inputs
for training.
Default: False
input_shape: Default: (13,21) for image input to a 2D feedforward neural network.
To reshape the input array per the requirements of the network training.
"""
        if file_type not in ["csv", "parquet"]:
            raise ValueError("file_type can only be \"csv\" or \"parquet\"!")
        self.file_type = file_type
        self.recon_files = glob.glob(
            data_directory_path + "recon" + data_format + "*." + file_type,
            recursive=is_directory_recursive
        )
        self.recon_files.sort()
        if file_count is not None:
            if not files_from_end:
                self.recon_files = self.recon_files[:file_count]
            else:
                self.recon_files = self.recon_files[-file_count:]
        self.label_files = [
            labels_directory_path + recon_file.split('/')[-1].replace("recon" + data_format, "labels")
            for recon_file in self.recon_files
        ]
        self.file_offsets = [0]
        self.dataset_mean = None
        self.dataset_std = None
        for afile in self.recon_files:
            if self.file_type == "csv":
                adf = pd.read_csv(afile).dropna()
            elif self.file_type == "parquet":
                adf = pd.read_parquet(afile).dropna()
            self.file_offsets.append(self.file_offsets[-1] + len(adf))
            x = adf.values.reshape((-1, *input_shape))
            if transpose is not None:
                x = x.transpose(transpose)
            amean, avariance = self.get_mean_and_variance(x)
            if self.dataset_mean is None:
                self.dataset_mean = amean
                self.dataset_std = avariance
            else:
                self.dataset_mean += amean
                self.dataset_std += avariance
        self.dataset_mean = self.dataset_mean / len(self.recon_files)
        self.dataset_std = np.sqrt(self.dataset_std / len(self.recon_files))
        self.file_offsets = np.array(self.file_offsets)
        self.batch_size = batch_size
        self.labels_list = labels_list
        self.input_shape = input_shape
        self.transpose = transpose
        self.to_standardize = to_standardize
        self.include_y_local = include_y_local
        self.current_file_index = None
        self.current_dataframes = None
        self.on_epoch_end()
    def get_mean_and_variance(self, x):
        """Computes the per-pixel mean and variance of a batch of inputs.
        Used internally while scanning the input files so that batches can
        later be standardized before being fed to the network.
        Args:
            x: Batch of inputs.
        Returns:
            The mean and variance of the batch (a small epsilon is added to
            the variance to avoid division by zero).
        """
        return np.mean(x, axis=0, keepdims=True), np.var(x, axis=0, keepdims=True) + 1e-10
    def standardize(self, x):
        """Applies the normalization configuration to a batch of inputs.
        Used internally to standardize images before feeding them to the network.
        Args:
            x: Batch of inputs to be normalized.
        Returns:
            The inputs, standardized with the dataset mean and standard deviation.
        """
        return (x - self.dataset_mean) / self.dataset_std
    def on_epoch_end(self):
        """
        Inherited from the parent class.
        Normally used to reset indices, but not needed in this data generator.
        """
        pass
    def __getitem__(self, batch_index):
        """
        Fetches a batch of inputs (X, y) for the network's training.
        """
        index = batch_index * self.batch_size  # absolute *event* index
        file_index = np.arange(self.file_offsets.size)[index < self.file_offsets][0] - 1  # first index is 0!
        index = index - self.file_offsets[file_index]  # relative event index in file
        # end of the batch within the file, clamped to the file length
        batch_end = min(index + self.batch_size, self.file_offsets[file_index + 1] - self.file_offsets[file_index])
        if file_index != self.current_file_index:
            self.current_file_index = file_index
            if self.file_type == "csv":
                recon_df = pd.read_csv(self.recon_files[file_index])
                labels_df = pd.read_csv(self.label_files[file_index])[self.labels_list]
            elif self.file_type == "parquet":
                recon_df = pd.read_parquet(self.recon_files[file_index])
                labels_df = pd.read_parquet(self.label_files[file_index], columns=self.labels_list)
            # drop rows with NaNs in the recon data, then shuffle the file in place
            has_nans = np.any(np.isnan(recon_df.values), axis=1)
            has_nans = np.arange(recon_df.shape[0])[has_nans]
            recon_df_raw = recon_df.drop(has_nans)
            labels_df_raw = labels_df.drop(has_nans)
            joined_df = recon_df_raw.join(labels_df_raw).sample(frac=1).reset_index(drop=True)
            self.current_dataframes = (joined_df[recon_df_raw.columns], joined_df[labels_df_raw.columns])
            del joined_df
            del recon_df
            del labels_df
            del recon_df_raw
            del labels_df_raw
        recon_df, labels_df = self.current_dataframes
        X = recon_df.iloc[index:batch_end].values.reshape((-1, *self.input_shape))
        y = labels_df.iloc[index:batch_end].values
        if self.transpose is not None:
            X = X.transpose(self.transpose)
        if self.to_standardize:
            X = self.standardize(X)
        if self.include_y_local:
            y_local = labels_df.iloc[index:batch_end]["y-local"].values
            return [X, y_local], y
        else:
            return X, y

    def __len__(self):
        return self.file_offsets[-1] // self.batch_size
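
# Illustrative sketch (not part of the original gist): how __getitem__ maps an
# absolute event index onto a (file, local index) pair using file_offsets.
# Suppose three files hold 5, 7, and 4 events, so file_offsets = [0, 5, 12, 16];
# then absolute index 9 lives in the second file (file_index 1) at local index 4.
_offsets = np.array([0, 5, 12, 16])
_index = 9
_file_index = np.arange(_offsets.size)[_index < _offsets][0] - 1  # -> 1
_local_index = _index - _offsets[_file_index]                     # -> 4
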
# set batch sizes and number of input files for train and test like this
batch_size = 100
val_batch_size = 100
train_file_size = 142
val_file_size = 16
# make your data generators (example for fully 3D data)
training_generator = CustomDataGenerator(
    data_directory_path = "/Users/lgray/smart-pixels/dataset8/unflipped-positive/data/",
    labels_directory_path = "/Users/lgray/smart-pixels/dataset8/unflipped-positive/labels/",
    is_directory_recursive = False,
    file_type = "parquet",
    data_format = "3D",
    batch_size = batch_size,
    file_count = train_file_size,
    to_standardize = True,
    include_y_local = False,
    labels_list = ['x-midplane', 'y-midplane', 'cotAlpha', 'cotBeta'],
    input_shape = (20, 13, 21),
    transpose = (0, 2, 3, 1),
)
validation_generator = CustomDataGenerator(
    data_directory_path = "/Users/lgray/smart-pixels/dataset8/unflipped-positive/data/",
    labels_directory_path = "/Users/lgray/smart-pixels/dataset8/unflipped-positive/labels/",
    is_directory_recursive = False,
    file_type = "parquet",
    data_format = "3D",
    batch_size = val_batch_size,
    file_count = val_file_size,
    to_standardize = True,
    include_y_local = False,
    labels_list = ['x-midplane', 'y-midplane', 'cotAlpha', 'cotBeta'],
    input_shape = (20, 13, 21),
    transpose = (0, 2, 3, 1),
    files_from_end = True,  # take the files from the *end* of the directory so we don't overlap statistics
)
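
# Optional sanity check (not in the original gist): pull one batch straight from the
# generator before training. With input_shape = (20, 13, 21) and transpose = (0, 2, 3, 1),
# X should come out as (batch, 13, 21, 20) and y as (batch, 4) for the four labels above.
X_check, y_check = training_generator[0]
print(X_check.shape, y_check.shape)
print("batches per epoch:", len(training_generator))
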
# finally, pass the generators to model.fit like this
# training
es = EarlyStopping(
    patience=20,
    restore_best_weights=True
)
history = model.fit(x=training_generator,
                    validation_data=validation_generator,
                    callbacks=[es],
                    epochs=1000,
                    shuffle=False,  # important!!!!! shuffling now happens inside the data loader, do not turn it on here
                    verbose=1
                    )
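
# After training, the same Sequence can be passed straight to evaluate/predict
# (a minimal sketch; assumes the trained `model` from the fit call above).
val_metrics = model.evaluate(validation_generator, verbose=0)
predictions = model.predict(validation_generator)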