Skip to content

Instantly share code, notes, and snippets.

@Raychani1
Created March 26, 2024 06:24
Show Gist options
  • Save Raychani1/81ca78982e67f8c05eea2cac18cd7ae5 to your computer and use it in GitHub Desktop.
Save Raychani1/81ca78982e67f8c05eea2cac18cd7ae5 to your computer and use it in GitHub Desktop.
import io
import logging
import os
from logging import Logger
from keras.layers import (
ConvLSTM2D,
Dense,
Dropout,
Flatten,
Input,
MaxPooling3D,
TimeDistributed,
)
from keras.models import Model, Sequential
# Constant variables
# Width of the classifier head (Dense units); presumably the number of classes.
SUBSET_SIZE = 10
# Length of the time axis of each model input sample.
SEQUENCE_LENGTH = 20
# Spatial size (height, width) of every frame in an input sequence.
IMAGE_HEIGHT, IMAGE_WIDTH = 64, 64
# Activation applied by the final Dense layer.
OUTPUT_LAYER_ACTIVATION_FUNCTION = "softmax"
def setup_logger(log_file: str = "my_log.log") -> Logger:
    """Return the module logger, writing to a log file and to stderr.

    Both handlers use the format ``"%(asctime)s [%(levelname)s] %(message)s"``
    and the logger level is set to INFO.

    Args:
        log_file: Path of the log file. A relative path is resolved against
            the current working directory at call time, so the default keeps
            the original ``./my_log.log`` location. An absolute path is used
            as-is.

    Returns:
        The ``logging.getLogger(__name__)`` logger.

    Calling this more than once is safe. If the logger already has handlers,
    it is returned without adding new ones (and ``log_file`` is ignored), so
    records are not emitted twice.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # getLogger returns the same object on every call; adding handlers again
    # would duplicate each record in both the file and the console.
    if logger.handlers:
        return logger

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    # Keras 3 model summaries can contain box-drawing characters. An explicit
    # UTF-8 encoding keeps the file handler from failing on platforms whose
    # default locale encoding cannot represent them (e.g. cp1252).
    log_file_handler = logging.FileHandler(
        os.path.join(os.getcwd(), log_file), encoding="utf-8"
    )
    log_file_handler.setFormatter(formatter)
    logger.addHandler(log_file_handler)

    # StreamHandler() with no argument writes to sys.stderr.
    log_stream_handler = logging.StreamHandler()
    log_stream_handler.setFormatter(formatter)
    logger.addHandler(log_stream_handler)
    return logger
def get_summary_string(model: "Model") -> str:
    """Capture the text of ``model.summary()`` and return it as one string.

    Each chunk passed to ``print_fn`` is collected, and the chunks are joined
    with ``"\\n"``.

    Args:
        model: Object exposing ``summary(print_fn=...)``, such as a Keras model.

    Returns:
        The captured summary text.
    """
    # SOURCE: https://stackoverflow.com/a/53668338/14319439
    string_list = []

    def _collect(text: str, line_break: bool = True) -> None:
        # Older Keras calls print_fn(line) once per line. Keras 3 calls
        # print_fn(text, line_break=False) with the whole captured summary.
        # A one-argument lambda raises TypeError on the keyword form, so
        # accept it. The chunk is stored either way.
        string_list.append(text)

    model.summary(print_fn=_collect)
    return "\n".join(string_list)
def get_model_summary(model: "Model") -> str:
    """Capture the text of ``model.summary()`` into an in-memory stream.

    A chunk passed with ``line_break`` true (the default) is followed by
    ``"\\n"``. A chunk passed with ``line_break=False`` is written unchanged.

    Args:
        model: Object exposing ``summary(print_fn=...)``, such as a Keras model.

    Returns:
        The captured summary text.
    """
    # SOURCE: https://stackoverflow.com/a/53937848/14319439
    with io.StringIO() as stream:

        def _write(text: str, line_break: bool = True) -> None:
            # Keras 3 passes line_break=False with already-terminated text;
            # older Keras passes one line at a time with no keyword.
            stream.write((text + "\n") if line_break else text)

        model.summary(print_fn=_write)
        # Read the buffer before the with-block closes it.
        return stream.getvalue()
class MyModel:
    """Build a ConvLSTM2D-based sequence classifier and log its architecture.

    Input samples have shape ``(SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3)``,
    presumably RGB frames. The output is a ``SUBSET_SIZE``-unit Dense layer
    using ``OUTPUT_LAYER_ACTIVATION_FUNCTION``. The model is built eagerly in
    ``__init__``.
    """

    # Filter counts of the stacked ConvLSTM2D blocks, in build order.
    _CONV_LSTM_FILTERS = (4, 8, 14, 16)

    def __init__(self, logger: Logger) -> None:
        """Store ``logger`` and build the model right away.

        Args:
            logger: Logger that receives build progress and the model summary.
        """
        self.__logger = logger
        self.__model = self.__build_model()

    @property
    def model(self) -> Model:
        """The Keras ``Sequential`` model built during construction."""
        return self.__model

    def __build_model(self) -> Model:
        """Assemble the ConvLSTM2D stack and classifier head, then log it.

        Each block is ConvLSTM2D (3x3 kernel, tanh, recurrent dropout 0.2,
        full sequence returned), then MaxPooling3D that halves only the two
        spatial axes. Every block except the last is followed by a
        per-timestep Dropout(0.2). The output is flattened and fed to the Dense
        classifier.

        Returns:
            The assembled, uncompiled ``Sequential`` model.
        """
        self.__logger.info("Building Model")
        model = Sequential()
        model.add(Input(shape=(SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3)))

        # Derived from the filter list so the "no dropout after the final
        # block" rule still holds if blocks are added or removed.
        last_block = len(self._CONV_LSTM_FILTERS) - 1
        for iteration, number_of_filters in enumerate(self._CONV_LSTM_FILTERS):
            model.add(
                ConvLSTM2D(
                    filters=number_of_filters,
                    kernel_size=(3, 3),
                    activation="tanh",
                    data_format="channels_last",
                    recurrent_dropout=0.2,
                    return_sequences=True,
                )
            )
            # pool_size (1, 2, 2): keep the time axis and halve height/width.
            model.add(
                MaxPooling3D(
                    pool_size=(1, 2, 2), padding="same", data_format="channels_last"
                )
            )
            if iteration < last_block:
                model.add(TimeDistributed(layer=Dropout(rate=0.2)))

        model.add(Flatten())
        model.add(Dense(units=SUBSET_SIZE, activation=OUTPUT_LAYER_ACTIVATION_FUNCTION))
        self.__logger.info("Finished building Model")

        # Option 1
        self.__logger.info("Model architecture: %s", get_summary_string(model=model))

        # # Option 2
        # self.__logger.info(
        #     f'Model architecture: {get_model_summary(model=model)}'
        # )

        # # Option 3
        # # SOURCE: https://stackoverflow.com/a/50077092/14319439
        # # NOTE(review): Keras 3 calls print_fn(text, line_break=False).
        # # Logger.info forwards that keyword to Logger._log, which does not
        # # accept it, so this likely raises TypeError. Confirm.
        # model.summary(print_fn=self.__logger.info)

        # Option 4
        # # NOTE(review): this lambda takes only `x`, so under Keras 3 it likely
        # # raises TypeError on the line_break keyword. Confirm.
        # model.summary(
        #     print_fn=lambda x: self.__logger.info(f'Model architecture: {x}')
        # )
        return model
if __name__ == "__main__":
    # Building the model logs its architecture as part of construction.
    model = MyModel(logger=setup_logger())
absl-py==2.1.0
astunparse==1.6.3
certifi==2024.2.2
charset-normalizer==3.3.2
flatbuffers==24.3.25
gast==0.5.4
google-pasta==0.2.0
grpcio==1.62.1
h5py==3.10.0
idna==3.6
keras==3.1.1
libclang==18.1.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
ml-dtypes==0.3.2
namex==0.0.7
numpy==1.26.4
opt-einsum==3.3.0
optree==0.11.0
packaging==24.0
protobuf==4.25.3
Pygments==2.17.2
requests==2.31.0
rich==13.7.1
six==1.16.0
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-io-gcs-filesystem==0.36.0
termcolor==2.4.0
typing_extensions==4.10.0
urllib3==2.2.1
Werkzeug==3.0.1
wrapt==1.16.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment