"""
Residual U-net implementation based on [1].
Usage:
>>> from chen2019 import network
>>> # input_data.shape == (num_slices, slice_height, slice_width, num_channels)
>>> model = network(input_data.shape[1:])
>>> model.fit(input_data, output_date, epochs=100, batch_size=input_data.shape[0] // 4, ...)
TODO:
- do we need more epochs?
- do we use triplets of convolutional blocks rather than pairs?
- do we use a different batch size?
[1] K. T. Chen et al. 2019 Radiol. 290(3) 649-656
"Ultra-Low-Dose 18F-Florbetaben Amyloid PET Imaging Using Deep Learning with Multi-Contrast MRI Inputs"
"""
from tensorflow import keras
__author__ = "Casper da Costa-Luis <casper.dcl@kcl.ac.uk>"


def network(input_shape, residual_input_channel=1, lr=2e-4, dtype="float32"):
    """
    input_shape : (slice_height, slice_width, num_channels), excluding the batch dimension
    residual_input_channel : input channel index to use for residual addition
    lr : learning rate for the Adam optimiser
    """
    x = inputs = keras.layers.Input(input_shape, dtype=dtype)

    def block(x, filters):
        """Conv2D => BatchNormalization => ReLU"""
        x = keras.layers.Conv2D(filters, 3, padding="same", use_bias=False, dtype=dtype)(x)
        x = keras.layers.BatchNormalization(dtype=dtype)(x)
        x = keras.layers.ReLU(dtype=dtype)(x)
        return x

    # U-net
    filters = [16, 32, 64, 128]
    ## encode
    convs = []  # skip connections for the decoder
    for i in filters[:-1]:
        x = block(x, i)
        x = block(x, i)
        # TODO: do we need a third block?
        convs.append(x)
        x = keras.layers.MaxPool2D(dtype=dtype, padding="same")(x)
    ## bottleneck
    x = block(x, filters[-1])
    x = block(x, filters[-1])
    ## decode
    for i in filters[:-1][::-1]:
        x = keras.layers.UpSampling2D(interpolation="bilinear", dtype=dtype)(x)
        x = keras.layers.Concatenate()([x, convs.pop()])
        x = block(x, i)
        x = block(x, i)
    # 1x1 conv produces a single-channel residual, added to the selected input channel
    x = keras.layers.Conv2D(1, 1, padding="same", dtype=dtype, name="residual")(x)
    x = keras.layers.Add(name="generated")(
        [inputs[..., residual_input_channel : residual_input_channel + 1], x]
    )
    model = keras.Model(inputs=inputs, outputs=x)
    opt = keras.optimizers.Adam(lr)
    model.compile(opt, loss="mae")
    model.summary()
    return model
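

# Minimal smoke-test sketch (not part of the original gist): builds the model on
# random data to check that shapes line up and a single training step runs.
# The slice count, 64x64 spatial size and 3 input channels are assumptions
# chosen purely for illustration (spatial dims must be divisible by 2**3 so the
# decoder's upsampled feature maps match the encoder skip connections).
if __name__ == "__main__":
    import numpy as np

    num_slices, height, width, channels = 8, 64, 64, 3
    input_data = np.random.rand(num_slices, height, width, channels).astype("float32")
    # target: one synthesised (e.g. full-dose) channel per slice
    output_data = np.random.rand(num_slices, height, width, 1).astype("float32")

    model = network(input_data.shape[1:])
    model.fit(input_data, output_data, epochs=1, batch_size=max(1, num_slices // 4))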