Created
March 26, 2024 06:24
-
-
Save Raychani1/81ca78982e67f8c05eea2cac18cd7ae5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io | |
import logging | |
import os | |
from logging import Logger | |
from keras.layers import ( | |
ConvLSTM2D, | |
Dense, | |
Dropout, | |
Flatten, | |
Input, | |
MaxPooling3D, | |
TimeDistributed, | |
) | |
from keras.models import Model, Sequential | |
# Constant variables
# Units in the final Dense layer (presumably the number of classes in the
# training subset — confirm against the data pipeline).
SUBSET_SIZE = 10
# Time steps per sample: the first axis of the model's Input shape.
SEQUENCE_LENGTH = 20
# Height and width of every image in a sequence.
IMAGE_HEIGHT, IMAGE_WIDTH = 64, 64
# Activation used by the final Dense layer.
OUTPUT_LAYER_ACTIVATION_FUNCTION = "softmax"
def setup_logger() -> Logger:
    """Return this module's logger, configured to log to a file and a stream.

    INFO-and-above messages go both to ``my_log.log`` in the current working
    directory and to a ``logging.StreamHandler`` (stderr by default). Both
    handlers use the format ``time [LEVEL] message``.

    Safe to call repeatedly: handlers are attached only on the first call,
    so later calls do not duplicate every log line.

    Returns:
        The configured logger (``logging.getLogger(__name__)``).
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # getLogger returns the same object on every call. Without this guard each
    # call would attach another handler pair and every record would be
    # emitted once per call.
    if logger.handlers:
        return logger

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) cannot
    # encode the Unicode box-drawing characters in Keras 3 model summaries.
    log_file_handler = logging.FileHandler(
        os.path.join(os.getcwd(), "my_log.log"), encoding="utf-8"
    )
    log_file_handler.setFormatter(formatter)
    logger.addHandler(log_file_handler)

    log_stream_handler = logging.StreamHandler()
    log_stream_handler.setFormatter(formatter)
    logger.addHandler(log_stream_handler)

    return logger
def get_summary_string(model: Model) -> str:
    """Capture the text that ``model.summary()`` would otherwise print.

    Each piece of text Keras passes to ``print_fn`` is collected in order.
    The pieces are then joined with newline characters into one string.
    """
    # SOURCE: https://stackoverflow.com/a/53668338/14319439
    chunks = []
    model.summary(print_fn=lambda text: chunks.append(text))
    return "\n".join(chunks)
def get_model_summary(model) -> str:
    """Return ``model.summary()`` output captured into a string.

    Every chunk passed to ``print_fn`` is written to an in-memory buffer,
    followed by a newline.

    Args:
        model: Any object with a Keras-style ``summary(print_fn=...)`` method.

    Returns:
        The captured summary text.
    """
    # SOURCE: https://stackoverflow.com/a/53937848/14319439
    # ``with`` closes the buffer even if ``summary`` raises. The value is read
    # inside the block because ``getvalue`` fails on a closed StringIO.
    with io.StringIO() as stream:
        model.summary(print_fn=lambda x: stream.write(x + "\n"))
        return stream.getvalue()
class MyModel:
    """Build a stacked ConvLSTM2D classifier and log its architecture.

    The Keras model is built eagerly in ``__init__``. Progress messages and
    the model summary are written to the logger passed in.
    """

    def __init__(self, logger: Logger) -> None:
        """Store *logger* and build the model immediately.

        Args:
            logger: Receives build-progress messages and the model summary.
        """
        self.__logger = logger
        self.__model = self.__build_model()

    @property
    def model(self) -> Model:
        """The Keras model built during ``__init__`` (read-only).

        Gives callers a public way to reach the name-mangled private
        attribute.
        """
        return self.__model

    def __build_model(self) -> Model:
        """Assemble the network, log its summary, and return it.

        Each sample has shape ``(SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3)``.
        Every ConvLSTM2D layer is followed by a MaxPooling3D that pools only
        the two spatial axes. Every block except the last also gets a
        time-distributed Dropout. A Flatten and a ``SUBSET_SIZE``-unit Dense
        layer finish the stack.

        Returns:
            The (uncompiled) Sequential model.
        """
        self.__logger.info("Building Model")

        model = Sequential()
        model.add(Input(shape=(SEQUENCE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3)))

        filter_counts = [4, 8, 14, 16]
        # Derived from the list rather than hard-coded (previously ``3``), so
        # editing ``filter_counts`` keeps "no dropout after the final block".
        last_block = len(filter_counts) - 1

        for iteration, number_of_filters in enumerate(filter_counts):
            model.add(
                ConvLSTM2D(
                    filters=number_of_filters,
                    kernel_size=(3, 3),
                    activation="tanh",
                    data_format="channels_last",
                    recurrent_dropout=0.2,
                    # Keep the time axis so the following 3D pooling and the
                    # next ConvLSTM2D still receive a sequence.
                    return_sequences=True,
                )
            )
            # pool_size=(1, 2, 2): downsample height and width by 2, never time.
            model.add(
                MaxPooling3D(
                    pool_size=(1, 2, 2), padding="same", data_format="channels_last"
                )
            )
            if iteration != last_block:
                model.add(TimeDistributed(layer=Dropout(rate=0.2)))

        model.add(Flatten())
        model.add(Dense(units=SUBSET_SIZE, activation=OUTPUT_LAYER_ACTIVATION_FUNCTION))

        self.__logger.info("Finished building Model")

        # Option 1: capture the whole summary and log it as a single record.
        # %-style arguments leave the string formatting to the logging module.
        self.__logger.info("Model architecture: %s", get_summary_string(model=model))

        # Alternative ways to route the summary into the logger, kept for
        # comparison:
        # # Option 2
        # self.__logger.info(
        #     f'Model architecture: {get_model_summary(model=model)}'
        # )
        # # Option 3
        # # SOURCE: https://stackoverflow.com/a/50077092/14319439
        # model.summary(print_fn=self.__logger.info)
        # # Option 4
        # model.summary(
        #     print_fn=lambda x: self.__logger.info(f'Model architecture: {x}')
        # )

        return model
if __name__ == "__main__":
    # Script entry point. The logger has to exist before MyModel is created,
    # because MyModel builds the network (and logs while doing so) in __init__.
    logger = setup_logger()
    model = MyModel(logger)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
absl-py==2.1.0
astunparse==1.6.3
certifi==2024.2.2
charset-normalizer==3.3.2
flatbuffers==24.3.25
gast==0.5.4
google-pasta==0.2.0
grpcio==1.62.1
h5py==3.10.0
idna==3.6
keras==3.1.1
libclang==18.1.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
ml-dtypes==0.3.2
namex==0.0.7
numpy==1.26.4
opt-einsum==3.3.0
optree==0.11.0
packaging==24.0
protobuf==4.25.3
Pygments==2.17.2
requests==2.31.0
rich==13.7.1
six==1.16.0
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-io-gcs-filesystem==0.36.0
termcolor==2.4.0
typing_extensions==4.10.0
urllib3==2.2.1
Werkzeug==3.0.1
wrapt==1.16.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment