Last active
June 8, 2022 14:32
-
-
Save karhunenloeve/58c77cafa21210972c98f4142745faef to your computer and use it in GitHub Desktop.
Code for a general class of autoencoders using TensorFlow.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import config as cfg | |
import tensorflow as tf | |
from tensorflow.keras import Model | |
from tensorflow.keras.preprocessing import image_dataset_from_directory | |
from keras.layers import Input, BatchNormalization, Conv2D, Conv2DTranspose, Dense, Add | |
from collections import OrderedDict | |
# Expose the settings mappings from config.py as module-level names so the
# Autoencoder class below can splat them (``**``) into layer constructors,
# ``compile`` and ``fit`` without qualifying every access with ``cfg.``.
(
    BatchNormalization_settings,
    Conv_2D_settings,
    Conv_2D_Transpose_settings,
    hyper_parameters,
    Fit_settings,
    Compilation_settings,
) = (
    cfg.BatchNormalization_settings,
    cfg.Conv_2D_settings,
    cfg.Conv_2D_Transpose_settings,
    cfg.hyper_parameters,
    cfg.Fit_settings,
    cfg.Compilation_settings,
)
class Autoencoder(Model):
    """Convolutional encoder/decoder with additive skip connections.

    The actual network is a functional ``Model`` built by ``_create_model``
    and stored in ``self._model``. This class wraps it with data loading,
    training and evaluation helpers. Layer, data and training options are
    read from the module-level settings taken from ``config.py``.

    NOTE(review): the layers are imported from the standalone ``keras``
    package while ``Model`` comes from ``tensorflow.keras``. On TensorFlow
    versions where these are distinct packages, mixing them can fail; confirm
    that the installed versions agree.
    """

    def __init__(self):
        super().__init__()
        self._model = self._create_model()
        # ``self._model`` is compiled lazily and only once (see
        # ``_ensure_compiled``). Recompiling resets metrics and, when the
        # optimizer is given by name, creates a fresh optimizer, which
        # discards its accumulated state.
        self._compiled_once = False

    def call(self, inputs, training=None):
        """Run the wrapped functional model on ``inputs``.

        Without this override, calling an ``Autoencoder`` instance directly
        would hit Keras' unimplemented ``Model.call``.
        """
        return self._model(inputs, training=training)

    @staticmethod
    def _conv_2d(x, layer):
        """Apply a ``Conv2D`` (``Conv_2D_settings``) followed by batch norm.

        ``layer`` is only used to give the convolution a unique name,
        ``"conv_2d_<layer>"``.
        """
        return BatchNormalization(**BatchNormalization_settings)(
            Conv2D(**Conv_2D_settings, name="conv_2d_" + str(layer))(x)
        )

    @staticmethod
    def _conv_2d_transpose(x, layer):
        """Apply a ``Conv2DTranspose`` (``Conv_2D_Transpose_settings``) + batch norm.

        ``layer`` is only used to give the layer a unique name,
        ``"conv_2d_transposed_<layer>"``.
        """
        return BatchNormalization(**BatchNormalization_settings)(
            Conv2DTranspose(
                **Conv_2D_Transpose_settings, name="conv_2d_transposed_" + str(layer)
            )(x)
        )

    def _encoder(self, net, layers):
        """Append ``layers`` stacked conv blocks to ``net`` in place.

        Block 1 reads ``net["input"]`` and each later block reads the one
        before it. Outputs are stored under ``"conv2d_1"`` ... ``"conv2d_<layers>"``.
        """
        for i in range(1, layers + 1):
            source = net["input"] if i == 1 else net["conv2d_" + str(i - 1)]
            net["conv2d_" + str(i)] = self._conv_2d(source, i)

    def _decoder(self, net, layers):
        """Append ``layers`` transposed-conv blocks with skip connections to ``net``.

        Block 1 transposes the deepest encoder output ``net["conv2d_<layers>"]``.
        Block ``i > 1`` transposes block ``i - 1`` and adds the encoder output
        ``net["conv2d_<layers - i + 1>"]`` to it. ``Add`` requires both operands
        to have identical shapes, so this relies on the configured conv and
        transposed-conv settings preserving shape (TODO: confirm against
        config.py). Outputs are stored under ``"conv2d_transpose_<i>"``.
        """
        for i in range(1, layers + 1):
            if i == 1:
                net["conv2d_transpose_" + str(i)] = self._conv_2d_transpose(
                    net["conv2d_" + str(layers)], i
                )
            else:
                net["conv2d_transpose_" + str(i)] = Add()(
                    [
                        self._conv_2d_transpose(
                            net["conv2d_transpose_" + str(i - 1)], i
                        ),
                        net["conv2d_" + str(layers - (i - 1))],
                    ]
                )

    def _create_model(self):
        """Build the functional encoder/decoder ``Model``.

        The graph is ``Input(shape) -> encoder -> decoder -> Dense``. Nodes are
        collected in insertion order, so the last entry is the final decoder
        output. The ``Dense`` layer acts on the last axis with
        ``units = shape[2]``; presumably ``shape`` is (height, width, channels),
        so this maps each pixel back to the input channel count (TODO confirm).
        """
        ordered_net = OrderedDict({"input": Input(hyper_parameters["shape"])})
        self._encoder(ordered_net, hyper_parameters["layers"])
        self._decoder(ordered_net, hyper_parameters["layers"])
        outputs = Dense(units=hyper_parameters["shape"][2])(
            ordered_net[next(reversed(ordered_net))]
        )
        return Model(inputs=ordered_net["input"], outputs=outputs)

    @classmethod
    def _create_data(cls, path, shuffle=True):
        """Load the unlabeled images under ``path`` as a batched dataset.

        Args:
            path: Directory passed to ``image_dataset_from_directory``.
            shuffle: Forwarded to the loader. When ``False``, files are read
                in alphanumeric order, which ``_create_paired_data`` relies on
                to align inputs with targets.
        """
        return image_dataset_from_directory(
            path,
            labels=None,
            label_mode=None,
            color_mode=hyper_parameters["colormode"],
            batch_size=hyper_parameters["batch_size"],
            image_size=hyper_parameters["image_size"],
            shuffle=shuffle,
        )

    @classmethod
    def _create_paired_data(cls, x_dir, y_dir, shuffle=True):
        """Build an aligned ``(input, target)`` dataset from two directories.

        Both directories are loaded unshuffled (alphanumeric file order) and
        zipped, so the n-th input is paired with the n-th target. Shuffling,
        when requested, is applied to the zipped pairs so they stay aligned.
        Shuffling each directory independently would pair images with the
        wrong targets.

        Assumes both directories hold the same number of correspondingly
        named images (TODO confirm). ``Dataset.zip`` stops at the shorter one.
        """
        paired = tf.data.Dataset.zip(
            (cls._create_data(x_dir, shuffle=False), cls._create_data(y_dir, shuffle=False))
        )
        if not shuffle:
            return paired
        batch_size = hyper_parameters["batch_size"]
        if batch_size is None:
            # Unbatched samples: shuffle them directly.
            return paired.shuffle(buffer_size=1024)
        # Split batches into single pairs so shuffling mixes samples across
        # batches, then rebatch to the configured size.
        return paired.unbatch().shuffle(buffer_size=batch_size * 8).batch(batch_size)

    @staticmethod
    def _get_layer(model, name=None, index=None):
        """Return the layer of ``model`` selected by ``name`` or ``index``."""
        return model.get_layer(name, index)

    @staticmethod
    def _combined_generator(generator_x, generator_y):
        """Yield ``(batch_x, batch_y)`` pairs in lockstep, stopping at the shorter input.

        Not used internally; kept for external callers. A Python generator
        can only be iterated once.
        """
        for batch_x, batch_y in zip(generator_x, generator_y):
            yield batch_x, batch_y

    def _ensure_compiled(self):
        """Compile ``self._model`` with ``Compilation_settings`` on first use only."""
        if not getattr(self, "_compiled_once", False):
            self._model.compile(**Compilation_settings)
            self._compiled_once = True

    @tf.autograph.experimental.do_not_convert
    def train(self):
        """Train on the paired ``x_train_dir`` / ``y_train_dir`` images.

        ``Fit_settings`` (including its ``epochs``) is forwarded to ``fit``
        unchanged. Unlike a Python generator, a ``tf.data.Dataset`` is
        re-iterated on every epoch, so a single ``fit`` call covers all epochs.

        Returns:
            The ``History`` object returned by ``fit``.
        """
        self._ensure_compiled()
        self._model.summary()
        train_data = self._create_paired_data(
            hyper_parameters["x_train_dir"], hyper_parameters["y_train_dir"]
        )
        return self._model.fit(train_data, **Fit_settings)

    def validate(self):
        """Evaluate on the paired ``x_val_dir`` / ``y_val_dir`` images.

        Compiles the model first if ``train`` has not already done so.

        Returns:
            Whatever ``Model.evaluate`` returns.
        """
        self._ensure_compiled()
        val_data = self._create_paired_data(
            hyper_parameters["x_val_dir"], hyper_parameters["y_val_dir"], shuffle=False
        )
        return self._model.evaluate(val_data)
if __name__ == "__main__":
    # Script entry point: build the network and run training.
    Autoencoder().train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Here is the config.py: