Last active
October 4, 2021 14:35
-
-
Save jizhang02/ef8eb45450f3d943fea37c6544d3808c to your computer and use it in GitHub Desktop.
Calculate the theoretical memory usage of a Keras model.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
-----------------------------------------------
File Name: memory_usage
Description: Theoretical memory usage of a Keras model
Author: Jing
Reference: https://stackoverflow.com/questions/43137288/how-to-determine-needed-memory-of-keras-model
Date: 9/28/2021
-----------------------------------------------
'''
# Bytes per stored number for each Keras float type name; any other name
# falls back to 4 bytes (float32), as the original code did.
_BYTES_PER_FLOAT = {
    'float16': 2.0,
    'bfloat16': 2.0,
    'float32': 4.0,
    'float64': 8.0,
}


def _backend_floatx():
    """Return the default float type name reported by the Keras backend."""
    try:
        from keras import backend as K
    except ImportError:
        from tensorflow.keras import backend as K
    return K.floatx()


def _is_model(layer):
    """Return True if ``layer`` is itself a (nested) Keras model.

    Any class named 'Model' anywhere in the MRO counts. This also matches
    subclasses whose own class name is different (e.g. ``Functional`` or
    ``Sequential``), and it does not depend on which keras package built
    the model.
    """
    return any(cls.__name__ == 'Model' for cls in type(layer).__mro__)


def _activation_count(model):
    """Return the per-sample number of output elements over all layers.

    Nested models are recursed into so their internal activations are
    counted too. Dimensions reported as None (typically the batch axis)
    are skipped.
    """
    total = 0
    for layer in model.layers:
        if _is_model(layer):
            total += _activation_count(layer)
        shapes = layer.output_shape
        # A layer may report a list of shapes (e.g. several outputs);
        # count every listed output rather than only the first.
        if not isinstance(shapes, list):
            shapes = [shapes]
        for shape in shapes:
            elements = 1
            for dim in shape:
                if dim is not None:
                    elements *= dim
            total += elements
    return total


def _param_count(weights):
    """Return the total number of scalars held by the given weight tensors."""
    import numpy as np
    return sum(int(np.prod(tuple(w.shape))) for w in weights)


def get_model_memory_usage(batch_size, model, floatx=None):
    """Estimate the theoretical memory, in GiB, needed for one batch.

    The estimate is::

        bytes_per_number * (batch_size * activations + parameters)

    - *activations* is the per-sample number of elements in every layer's
      output, including layers inside nested models.
    - *parameters* is the number of trainable plus non-trainable weights.

    Gradients, optimizer state and framework overhead are not counted.

    Args:
        batch_size: Number of samples processed at once.
        model: A built Keras model. It must expose ``layers`` (each with an
            ``output_shape``), ``trainable_weights`` and
            ``non_trainable_weights``.
        floatx: Float type name used to size each number ('float16',
            'bfloat16', 'float32' or 'float64'; anything else counts as
            4 bytes). Defaults to the Keras backend's ``floatx()``.

    Returns:
        The estimate in GiB (1024 ** 3 bytes), rounded to 3 decimals.
    """
    import numpy as np

    if floatx is None:
        floatx = _backend_floatx()
    number_size = _BYTES_PER_FLOAT.get(floatx, 4.0)

    activations = _activation_count(model)
    # A model's weight lists already include the weights of its nested
    # models. Counting them once here avoids the double count the old
    # recursive call introduced.
    params = (_param_count(model.trainable_weights)
              + _param_count(model.non_trainable_weights))

    total_bytes = number_size * (batch_size * activations + params)
    return np.round(total_bytes / (1024.0 ** 3), 3)
if __name__ == "__main__":
    # Example usage. Replace the toy model and batch size with your own.
    # Guarded so that importing this module does not run the example.
    try:
        from keras import layers, models
    except ImportError:
        from tensorflow.keras import layers, models

    model = models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        layers.Dense(10),
    ])
    batch_size = 32
    gigabytes = get_model_memory_usage(batch_size, model)
    print("Memory usage:", gigabytes)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment