-
-
Save dryglicki/a17c2df4fd89f8681b0fc69d1eaef95e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Trying a new approach. Modelling this system like I'm seeing this: | |
# https://github.com/keras-team/keras/blob/master/keras/layers/rnn/conv_lstm.py | |
import os ; import sys | |
import torch | |
os.environ['KERAS_BACKEND'] = 'torch' | |
#import tensorflow as tf | |
#os.environ['KERAS_BACKEND'] = 'tensorflow' | |
import keras as K | |
import keras.layers as KL | |
from keras.layers import RNN, InputSpec | |
from keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell | |
from keras.src.ops import operation_utils | |
class ConvGRU2DCell(KL.Layer):
    '''
    Single-timestep cell for a 2D convolutional GRU.

    Built from the PyTorch example found here:
    https://github.com/openclimatefix/skillful_nowcasting/blob/main/dgmr/layers/ConvGRU.py

    For an input ``x`` and previous state ``h`` one step computes

        xh = concat([x, h])                                  (last axis)
        r  = sigmoid(read_gate_conv(xh))
        u  = sigmoid(update_gate_conv(xh))
        c  = output_gate_activation(output_conv(concat([x, r * h])))
        h' = (1 - u) * h + u * c

    and returns ``h'`` both as the output and as the single new state. All three
    convolutions use 'same' padding and stride 1, so the state keeps the
    input's spatial size and has ``filters`` channels.

    Args:
        filters: Number of output / state channels.
        rank: Number of spatial dimensions (2 for images).
        kernel_size: Int, or a sequence of ``rank`` ints.
        strides: Int or sequence of ``rank`` ints. Stored and reported in the
            config only; the convolutions always use stride 1 because the state
            must keep the input's spatial size.
        dilation_rate: Stored and reported in the config only; not applied.
        use_bias: Whether the convolutions add a bias term.
        output_gate_activation: Activation applied to the candidate state.
        data_format: Stored and used for shape inference only. The
            concatenations are on the last axis, so inputs are expected to be
            channels-last.
        power_iterations: Power iterations for ``SpectralNormalization``.
        use_spectral_normalization: Wrap every convolution in
            ``SpectralNormalization`` when True.
        dropout: Clipped to [0, 1] and stored; not applied by this cell.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    '''
    def __init__(self,
                 filters: int,
                 rank: int = 2,
                 kernel_size: int | list = 3,
                 strides: int = 1,
                 dilation_rate: int = 1,
                 use_bias: bool = True,
                 output_gate_activation: str = 'relu',
                 data_format: str = 'channels_last',
                 power_iterations: int = 1,
                 use_spectral_normalization: bool = True,
                 dropout: float = 0.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.rank = rank
        # Normalise to tuples so shape arithmetic such as
        # ``kernel_size + (...)`` works for list arguments and for configs
        # reloaded from JSON (where tuples come back as lists).
        self.kernel_size = self._as_tuple(kernel_size, rank)
        self.strides = self._as_tuple(strides, rank)
        self.dilation_rate = dilation_rate
        self.data_format = data_format
        self.output_gate_activation = output_gate_activation
        self.use_bias = use_bias
        self.use_spectral_normalization = use_spectral_normalization
        self.power_iterations = power_iterations
        self.dropout = min(1.0, max(0.0, dropout))
        self.padding = 'same'
        # Kept for backward compatibility only; every conv gets its own
        # initializer instance in ``_make_conv``.
        self.initializer = K.initializers.GlorotUniform()

        self.read_gate_conv = self._make_conv()
        # The update-gate bias is initialised to 1.0; the other convs use the
        # Conv2D default bias initializer.
        self.update_gate_conv = self._make_conv(
            bias_initializer=K.initializers.Constant(1.0))
        self.output_conv = self._make_conv()

        self.sigmoid = KL.Activation('sigmoid')
        self.oga = KL.Activation(self.output_gate_activation)
        self.input_spec = InputSpec(ndim=self.rank + 2)
        # Sentinel (same convention as keras' ConvLSTMCell): the state shape
        # depends on the input's spatial size, so it is derived in
        # compute_output_shape / get_initial_state instead.
        self.state_size = -1

    @staticmethod
    def _as_tuple(value, rank):
        '''Return ``value`` as a tuple, broadcasting a bare int to ``rank`` entries.'''
        if isinstance(value, int):
            return (value,) * rank
        return tuple(value)

    def _make_conv(self, **conv_kwargs):
        '''Build one 'same'-padded, stride-1 Conv2D, spectrally normalised if enabled.'''
        conv = KL.Conv2D(
            self.filters,
            kernel_size=self.kernel_size,
            strides=1,
            padding=self.padding,
            # Fresh initializer per conv: Keras 3 documents that an unseeded
            # initializer instance produces the same values on every call, so
            # one shared instance would give these identically-shaped kernels
            # identical starting weights.
            kernel_initializer=K.initializers.GlorotUniform(),
            use_bias=self.use_bias,
            **conv_kwargs)
        if self.use_spectral_normalization:
            return KL.SpectralNormalization(
                conv, power_iterations=self.power_iterations)
        return conv

    def build(self, inputs_shape, states_shape=None):
        '''Record the per-step input geometry.

        The wrapping RNN layer builds the cell with the per-timestep shape
        (batch, *spatial, channels), i.e. with the time axis already removed,
        which is why the spec below has ``rank + 2`` dimensions.
        '''
        inputs_shape = tuple(inputs_shape)
        self.spatial_dims = inputs_shape[1:-1]
        self.input_spec = InputSpec(
            ndim=self.rank + 2, shape=(None,) + inputs_shape[1:])
        self.input_dim = inputs_shape[-1]
        self.kernel_shape = self.kernel_size + (self.input_dim, self.filters)
        self.built = True

    def call(self, inputs, hidden_state):
        '''Advance the GRU by one step.

        Args:
            inputs: Input tensor for this timestep.
            hidden_state: List whose first element is the previous state.

        Returns:
            ``(new_state, [new_state])``: the output and the updated state list.
        '''
        prev_state = hidden_state[0]
        xh = K.ops.concatenate([inputs, prev_state], axis=-1)
        read_gate = self.sigmoid(self.read_gate_conv(xh))
        update_gate = self.sigmoid(self.update_gate_conv(xh))
        gated_input = K.ops.concatenate([inputs, read_gate * prev_state], axis=-1)
        candidate = self.oga(self.output_conv(gated_input))
        new_state = (1.0 - update_gate) * prev_state + update_gate * candidate
        return new_state, [new_state]

    def compute_output_shape(self, inputs_shape, states_shape=None):
        '''Return ``(output_shape, [state_shape])`` for a per-step input shape.'''
        # Use exactly the geometry the convs are built with (stride 1, no
        # dilation); the stored ``strides``/``dilation_rate`` are not applied.
        conv_output_shape = operation_utils.compute_conv_output_shape(
            tuple(inputs_shape),
            self.filters,
            self.kernel_size,
            strides=1,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=1,
        )
        return conv_output_shape, [conv_output_shape]

    def get_initial_state(self, batch_size=None):
        '''Return a single all-zeros state for ``batch_size``.

        Needs the spatial dims seen in ``build`` to be concrete. When the
        model's H/W are ``None``, creating the zeros likely fails; pass
        ``initial_state`` to the layer instead.
        '''
        input_shape = (batch_size,) + tuple(self.spatial_dims) + (self.input_dim,)
        state_shape = self.compute_output_shape(input_shape)[0]
        return [K.ops.zeros(state_shape, dtype=self.compute_dtype)]

    def get_config(self):
        '''Return a config whose keys match the ``__init__`` arguments.'''
        config = dict(super().get_config())
        config.update({
            'filters': self.filters,
            'rank': self.rank,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'dilation_rate': self.dilation_rate,
            'use_bias': self.use_bias,
            'output_gate_activation': self.output_gate_activation,
            'data_format': self.data_format,
            'power_iterations': self.power_iterations,
            'use_spectral_normalization': self.use_spectral_normalization,
            'dropout': self.dropout,
        })
        return config
class ConvGRU2D(RNN):
    '''
    Convolutional GRU layer that scans ``ConvGRU2DCell`` over the time axis.

    Inputs have ``rank + 3`` dimensions, e.g. (batch, time, H, W, C) for
    ``rank=2``. The layer carries a single recurrent state whose shape is the
    cell's per-step output shape: the input's spatial size with ``filters``
    channels.

    Cell options (``filters``, ``kernel_size``, ``strides``, ``use_bias``,
    ``output_gate_activation``, ``use_spectral_normalization``,
    ``power_iterations``, ``rank``) are forwarded to ``ConvGRU2DCell``.
    ``return_sequences``, ``return_state``, ``go_backwards``, ``stateful`` and any
    extra ``**kwargs`` are forwarded to ``keras.layers.RNN``.
    '''
    def __init__(
            self,
            filters: int,
            kernel_size: int | list = 3,
            strides: int | list = 1,
            use_bias: bool = True,
            output_gate_activation: str = 'relu',
            use_spectral_normalization: bool = True,
            power_iterations: int = 1,
            return_sequences: bool = True,
            return_state: bool = False,
            go_backwards: bool = False,
            stateful: bool = False,
            rank: int = 2,
            **kwargs):
        cell = ConvGRU2DCell(
            filters,
            rank=rank,
            kernel_size=kernel_size,
            use_bias=use_bias,
            strides=strides,
            output_gate_activation=output_gate_activation,
            use_spectral_normalization=use_spectral_normalization,
            power_iterations=power_iterations)
        # RNN.__init__ stores return_sequences / return_state / go_backwards /
        # stateful itself.
        super().__init__(
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            **kwargs)
        self.input_spec = InputSpec(ndim=rank + 3)

    def call(self, sequences, initial_state=None, mask=None, training=False):
        return super().call(
            sequences, initial_state=initial_state, mask=mask, training=training)

    def compute_output_shape(self, sequences_shape, initial_state_shape=None):
        '''Return the output shape, plus the state shape when ``return_state``.'''
        sequences_shape = tuple(sequences_shape)
        batch_size = sequences_shape[0]
        steps = sequences_shape[1]
        step_shape = (batch_size,) + sequences_shape[2:]
        state_shape = tuple(self.cell.compute_output_shape(step_shape)[0][1:])
        if self.return_sequences:
            output_shape = (batch_size, steps) + state_shape
        else:
            output_shape = (batch_size,) + state_shape
        if self.return_state:
            # A GRU has one state tensor (unlike ConvLSTM's h and c), so the
            # layer returns [output, state] -- two shapes, not three.
            return output_shape, (batch_size,) + state_shape
        return output_shape

    def get_config(self):
        '''Serialise the constructor arguments instead of the cell object.'''
        config = {
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'use_bias': self.use_bias,
            'output_gate_activation': self.output_gate_activation,
            'use_spectral_normalization': self.use_spectral_normalization,
            'power_iterations': self.power_iterations,
            'rank': self.rank,
        }
        base_config = dict(super().get_config())
        # The cell is rebuilt from the arguments above; __init__ does not
        # accept a ``cell`` argument.
        base_config.pop('cell', None)
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        '''Rebuild the layer from ``get_config`` output (``custom_objects`` unused).'''
        return cls(**config)

    @property
    def filters(self):
        return self.cell.filters

    @property
    def kernel_size(self):
        return self.cell.kernel_size

    @property
    def strides(self):
        return self.cell.strides

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def rank(self):
        return self.cell.rank

    @property
    def output_gate_activation(self):
        return self.cell.output_gate_activation

    @property
    def data_format(self):
        return self.cell.data_format

    @property
    def use_spectral_normalization(self):
        return self.cell.use_spectral_normalization

    @property
    def power_iterations(self):
        return self.cell.power_iterations

    @property
    def dropout(self):
        return self.cell.dropout
#################################################
# Smoke test: build a small model around ConvGRU2D and fit it on random data.
# NOTE(review): this runs at import time; wrap it in
# ``if __name__ == "__main__":`` if the classes above are imported elsewhere.
BATCH = [3]
NUM_TIMES = [4]
IMAGE_SIZE = [64, 64, 6]          # H, W, C of each frame
SEQUENCE = NUM_TIMES + IMAGE_SIZE
INPUT_ARRAY = BATCH + SEQUENCE
print("Creating dummy ICs.")
seed_gen = K.random.SeedGenerator(seed=42)
values = K.random.normal(shape=INPUT_ARRAY, seed=seed_gen)
initial_state = K.random.normal(shape=BATCH + IMAGE_SIZE, seed=seed_gen)
# Single-channel targets at the input resolution, drawn from the same seed
# generator as the inputs so runs are reproducible.
y = K.random.normal(
    shape=BATCH + NUM_TIMES + IMAGE_SIZE[:2] + [1], seed=seed_gen) * 3.0
# H and W are left as None so the model accepts any spatial size.
NAN_sequence = NUM_TIMES + [None, None] + IMAGE_SIZE[2:]
NAN_image = [None, None] + IMAGE_SIZE[2:]

# Build dummy model
print("Building dummy model.")
seq_inputs = KL.Input(shape=NAN_sequence)
init_inputs = KL.Input(shape=NAN_image)
xs = KL.TimeDistributed(KL.Conv2D(32, kernel_size=3, padding='same'))(seq_inputs)
xi = KL.Conv2D(32, kernel_size=3, padding='same')(init_inputs)
# The initial state is passed explicitly: with H/W = None the cell presumably
# cannot allocate a zero state on its own.
x = ConvGRU2D(filters=32, kernel_size=3,
              use_spectral_normalization=True, power_iterations=2,
              return_sequences=True)(xs, initial_state=[xi])
x = KL.TimeDistributed(KL.Conv2D(1, kernel_size=3, padding='same'))(x)
outputs = x
model = K.models.Model(inputs=[seq_inputs, init_inputs], outputs=[outputs])
################################################
optimizer = K.optimizers.AdamW()
model.compile(optimizer=optimizer, loss='mse')
model.fit([values, initial_state], y, epochs=5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment