#!/usr/bin/env python3
# Trying a new approach: modelling this system after the Keras ConvLSTM implementation:
# https://github.com/keras-team/keras/blob/master/keras/layers/rnn/conv_lstm.py
import os
import torch
os.environ['KERAS_BACKEND'] = 'torch'
#import tensorflow as tf
#os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras as K
import keras.layers as KL
from keras import ops
from keras.layers import RNN, InputSpec
from keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell
from keras.src.ops import operation_utils
class ConvGRU2DCell(KL.Layer):
    '''
    Simple cell for a 2-D convolutional GRU.
    Built from the PyTorch example found here:
    https://github.com/openclimatefix/skillful_nowcasting/blob/main/dgmr/layers/ConvGRU.py
    '''
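    # A short reference for what call() computes each step (added for clarity;
    # not in the original gist). With * denoting convolution and . denoting the
    # elementwise product, for input x and previous state h:
    #   r  = sigmoid(W_r * concat([x, h]))      # read (reset) gate
    #   z  = sigmoid(W_z * concat([x, h]))      # update gate
    #   c  = act(W_o * concat([x, r . h]))      # candidate state
    #   h' = (1 - z) . h + z . c                # new state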
    def __init__(self,
                 filters: int,
                 rank: int = 2,
                 kernel_size: int | list = 3,
                 strides: int = 1,
                 dilation_rate: int = 1,
                 use_bias: bool = True,
                 output_gate_activation: str = 'relu',
                 data_format: str = 'channels_last',
                 power_iterations: int = 1,
                 use_spectral_normalization: bool = True,
                 dropout: float = 0.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.rank = rank
        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size,) * self.rank
        else:
            self.kernel_size = kernel_size
        if isinstance(strides, int):
            self.strides = (strides,) * self.rank
        else:
            self.strides = strides
        self.dilation_rate = dilation_rate
        self.data_format = data_format
        self.output_gate_activation = output_gate_activation
        self.use_bias = use_bias
        self.use_spectral_normalization = use_spectral_normalization
        self.power_iterations = power_iterations
        self.dropout = min(1.0, max(0.0, dropout))
        self.initializer = K.initializers.GlorotUniform()
        if self.use_spectral_normalization:
            self.read_gate_conv = KL.SpectralNormalization(
                KL.Conv2D(filters, kernel_size = kernel_size, strides = 1, padding = 'same',
                          kernel_initializer = self.initializer, use_bias = self.use_bias),
                power_iterations = self.power_iterations)
            self.update_gate_conv = KL.SpectralNormalization(
                KL.Conv2D(filters, kernel_size = kernel_size, strides = 1, padding = 'same',
                          kernel_initializer = self.initializer, use_bias = self.use_bias,
                          bias_initializer = K.initializers.Constant(1.0)),
                power_iterations = self.power_iterations)
            self.output_conv = KL.SpectralNormalization(
                KL.Conv2D(filters, kernel_size = kernel_size, strides = 1, padding = 'same',
                          kernel_initializer = self.initializer, use_bias = self.use_bias),
                power_iterations = self.power_iterations)
        else:
            self.read_gate_conv = KL.Conv2D(filters, kernel_size = kernel_size, strides = 1, padding = 'same',
                                            kernel_initializer = self.initializer, use_bias = self.use_bias)
            self.update_gate_conv = KL.Conv2D(filters, kernel_size = kernel_size, strides = 1, padding = 'same',
                                              kernel_initializer = self.initializer, use_bias = self.use_bias,
                                              bias_initializer = K.initializers.Constant(1.0))
            self.output_conv = KL.Conv2D(filters, kernel_size = kernel_size, strides = 1, padding = 'same',
                                         kernel_initializer = self.initializer, use_bias = self.use_bias)
        self.sigmoid = KL.Activation('sigmoid')
        self.oga = KL.Activation(self.output_gate_activation)
        self.padding = 'same'
        self.input_spec = InputSpec(ndim = self.rank + 2)
        self.state_size = -1  # Custom; the actual state shape is defined in the methods below
    def build(self, inputs_shape, states_shape = None):
        self.spatial_dims = inputs_shape[1:-1]
        # The RNN wrapper builds the cell on the per-step shape (it strips the
        # time axis first), so the cell only sees [B, H, W, C] here, not the
        # whole 5-D [B, T, H, W, C] series.
        self.input_spec = InputSpec(
            ndim = self.rank + 2, shape = (None,) + inputs_shape[1:])
        self.input_dim = inputs_shape[-1]
        self.kernel_shape = self.kernel_size + (self.input_dim, self.filters)
        self.built = True

    # For use with the old ConvRNN API; deprecated as of Keras 3:
    # def input_conv(self, x, w, b=None, padding = 'same', **kwargs):
    #     return tf.nn.conv2d(x, w, strides = 1, padding = 'SAME', data_format = 'NHWC')
    def call(self, inputs, hidden_state):
        x = inputs
        prev_state = hidden_state[0]
        xh = KL.Concatenate()([x, prev_state])
        read_gate = self.sigmoid(self.read_gate_conv(xh))
        update_gate = self.sigmoid(self.update_gate_conv(xh))
        gated_input = KL.Concatenate(name = 'GRU_concat')([x, read_gate * prev_state])
        c = self.oga(self.output_conv(gated_input))
        # Update-gate convention: z gates the candidate, (1 - z) carries the old state.
        # The flipped convention would be:
        #   out = update_gate * prev_state + (1.0 - update_gate) * c
        out = (1.0 - update_gate) * prev_state + (update_gate * c)
        new_state = out
        return out, [new_state]
    def compute_output_shape(self, inputs_shape, states_shape = None):
        # Old TF version:
        # inputs_shape = tf.TensorShape(inputs_shape).as_list()
        # return tf.TensorShape(inputs_shape[:-1] + [self.filters]), [tf.TensorShape(inputs_shape[:-1] + [self.filters])]
        conv_output_shape = operation_utils.compute_conv_output_shape(
            inputs_shape,
            self.filters,
            self.kernel_size,
            strides = self.strides,
            padding = self.padding,
            data_format = self.data_format,
            dilation_rate = self.dilation_rate,
        )
        return conv_output_shape, [conv_output_shape]
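    # For example (hypothetical shapes, matching the dummy model below): a step
    # of shape (B, 64, 64, 6) with filters = 32, stride 1, and 'same' padding
    # maps to ((B, 64, 64, 32), [(B, 64, 64, 32)]).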
    # ABSOLUTELY DO NOT USE: always pass initial_state explicitly.
    # The old version worked on the inputs tensor, which allowed tf.zeros_like;
    # the Keras-3 signature only receives batch_size.
    def get_initial_state(self, batch_size = None):
        input_shape = (batch_size,) + self.spatial_dims + (self.input_dim,)
        state_shape = self.compute_output_shape(input_shape)[0]
        return [ops.zeros(state_shape, dtype = self.compute_dtype)]
    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'filters' : self.filters,
            'rank' : self.rank,
            'kernel_size' : self.kernel_size,
            'strides' : self.strides,
            'dilation_rate' : self.dilation_rate,
            'use_bias' : self.use_bias,
            'power_iterations' : self.power_iterations,
            'use_spectral_normalization' : self.use_spectral_normalization,
            'output_gate_activation' : self.output_gate_activation,
            'data_format' : self.data_format,
            'dropout' : self.dropout,
        })
        return config
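
# A minimal sketch (not from the original gist) of driving the cell directly with
# the generic keras.layers.RNN wrapper, which is essentially all ConvGRU2D below
# does. Shapes here are hypothetical; per the warning above, an explicit initial
# state is passed rather than relying on get_initial_state.
#   _cell = ConvGRU2DCell(filters = 8, kernel_size = 3)
#   _rnn = RNN(_cell, return_sequences = True)
#   _x = K.random.normal(shape = (2, 4, 16, 16, 3))    # [B, T, H, W, C]
#   _h0 = K.random.normal(shape = (2, 16, 16, 8))      # [B, H, W, filters]
#   _y = _rnn(_x, initial_state = [_h0])               # -> [B, T, 16, 16, 8]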
class ConvGRU2D(RNN):
    def __init__(
            self,
            filters: int,
            kernel_size: int | list = 3,
            strides: int | list = 1,
            use_bias: bool = True,
            output_gate_activation: str = 'relu',
            use_spectral_normalization: bool = True,
            power_iterations: int = 1,
            return_sequences: bool = True,
            return_state: bool = False,
            go_backwards: bool = False,
            stateful: bool = False,
            rank: int = 2,
            **kwargs):
        cell = ConvGRU2DCell(
            filters,
            rank = rank,
            kernel_size = kernel_size,
            use_bias = use_bias,
            strides = strides,
            output_gate_activation = output_gate_activation,
            use_spectral_normalization = use_spectral_normalization,
            power_iterations = power_iterations)
        super().__init__(
            cell,
            return_sequences = return_sequences,
            return_state = return_state,
            go_backwards = go_backwards,
            stateful = stateful,
            **kwargs)
        self.input_spec = InputSpec(ndim = rank + 3)
    def call(self, sequences, initial_state = None, mask = None, training = False):
        return super().call(
            sequences, initial_state = initial_state, mask = mask, training = training)

    def compute_output_shape(self, sequences_shape, initial_state_shape = None):
        batch_size = sequences_shape[0]
        steps = sequences_shape[1]
        step_shape = (batch_size,) + sequences_shape[2:]
        state_shape = self.cell.compute_output_shape(step_shape)[0][1:]
        if self.return_sequences:
            output_shape = (batch_size, steps) + state_shape
        else:
            output_shape = (batch_size,) + state_shape
        if self.return_state:
            batched_state_shape = (batch_size,) + state_shape
            return output_shape, batched_state_shape, batched_state_shape
        return output_shape

    @property
    def filters(self):
        return self.cell.filters

    @property
    def kernel_size(self):
        return self.cell.kernel_size

    @property
    def strides(self):
        return self.cell.strides

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def rank(self):
        return self.cell.rank

    @property
    def output_gate_activation(self):
        return self.cell.output_gate_activation

    @property
    def data_format(self):
        return self.cell.data_format

    @property
    def use_spectral_normalization(self):
        return self.cell.use_spectral_normalization

    @property
    def power_iterations(self):
        return self.cell.power_iterations

    @property
    def dropout(self):
        return self.cell.dropout
#################################################
BATCH = [3]
NUM_TIMES = [4]
IMAGE_SIZE = [64, 64, 6]
SEQUENCE = NUM_TIMES + IMAGE_SIZE
INPUT_ARRAY = BATCH + SEQUENCE
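# Layout (added note): values is [B, T, H, W, C] = [3, 4, 64, 64, 6]; the initial
# state is a single image [B, H, W, C] projected to the GRU's channel count below.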
print("Creating dummy ICs.")
seed_gen = K.random.SeedGenerator(seed = 42)
values = K.random.normal(shape = INPUT_ARRAY, seed = seed_gen)
initial_state = K.random.normal(shape = BATCH + IMAGE_SIZE, seed = seed_gen)
y = K.random.normal(shape = BATCH + NUM_TIMES + [64,64,1]) * 3.0
NAN_sequence = NUM_TIMES + [None, None] + [6]
NAN_image = [None, None] + [6]
# Build dummy model
print("Building dummy model.")
seq_inputs = KL.Input(shape = NAN_sequence)
init_inputs = KL.Input(shape = NAN_image)
xs = KL.TimeDistributed(KL.Conv2D(32, kernel_size = 3, padding = 'same'))(seq_inputs)
xi = KL.Conv2D(32, kernel_size = 3, padding = 'same')(init_inputs)
x = ConvGRU2D(filters = 32, kernel_size = 3,
              use_spectral_normalization = True, power_iterations = 2,
              return_sequences = True)(xs, initial_state = [xi])
x = KL.TimeDistributed(KL.Conv2D(1, kernel_size = 3, padding = 'same'))(x)
outputs = x
model = K.models.Model(inputs = [seq_inputs, init_inputs], outputs = [outputs])
################################################
optimizer = K.optimizers.AdamW()
model.compile(optimizer = optimizer, loss = 'mse')
model.fit([values, initial_state], y, epochs = 5)
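# A small addition (not in the original gist): sanity-check the output shape.
# The model maps values [B, T, H, W, C] plus initial_state [B, H, W, C] to a
# one-channel sequence, so we expect (3, 4, 64, 64, 1) here.
preds = model.predict([values, initial_state])
print("Prediction shape:", preds.shape)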