Created
October 31, 2018 11:01
-
-
Save IFeelBloated/6e85b251f66941b7eb50f987c95ce69b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from keras import layers, initializers, models, optimizers, backend, utils | |
import tensorflow as tf | |
def Mirror(x, Padding=1):
    """Reflect-pad a 4-D tensor along axes 2 and 3.

    Axes 0 and 1 are left untouched; axes 2 and 3 each receive `Padding`
    reflected elements on both sides. The padding is wrapped in a Keras
    `Lambda` layer so it can be used inside a functional model.

    Args:
        x: Keras tensor with four axes.
        Padding: Number of elements to add on each side of axes 2 and 3.

    Returns:
        The padded Keras tensor.
    """
    Border = [Padding, Padding]
    Paddings = [[0, 0], [0, 0], Border, Border]
    ReflectPad = layers.Lambda(lambda Tensor: tf.pad(Tensor, Paddings, mode="Reflect"))
    return ReflectPad(x)
def BilinearUpsample(x, Scale=2, Permute=False):
    """Upsample `x` by `Scale` with a frozen bilinear transposed convolution.

    The input is reflect-padded, run through a non-trainable
    `Convolution2DTranspose` whose kernel holds bilinear interpolation
    weights (one independent copy per channel), and then cropped so that
    only the region corresponding to the original input remains.

    Args:
        x: Keras tensor; axis 1 is read as the channel count.
        Scale: Integer upsampling factor (also the stride of the transposed
            convolution).
        Permute: When true, the upsampled result is folded back with
            `tf.space_to_depth(..., Scale, data_format='NCHW')`.

    Returns:
        The upsampled Keras tensor (space-to-depth permuted if `Permute`).
    """
    ChannelCount = int(x.shape[1])
    KernelSize = Scale * 2
    # Width of the reflected border added before upsampling so that the
    # transposed convolution never sees the real image edge.
    # NOTE(review): reflect padding presumably requires the spatial size of
    # `x` to exceed this value - confirm against tf.pad.
    MirrorPadding = 16
    Settings = dict(filters=ChannelCount, kernel_size=KernelSize, strides=Scale, padding='same', use_bias=False, trainable=False)

    class BilinearKernel(initializers.Initializer):
        def __call__(self, shape, dtype=None):
            # Odd numerators 1, 3, ..., 2*Scale-1 followed by their mirror
            # image form the 1-D interpolation profile; the 2-D kernel is
            # its outer product, normalised by (2*Scale)**2.
            OddSteps = np.arange(1, KernelSize, 2)
            Profile = np.concatenate([OddSteps, OddSteps[::-1]])
            Kernel2D = np.outer(Profile, Profile) / (4 * Scale ** 2)
            Kernel = np.zeros(shape)
            # Only the channel diagonal is filled, so channels never mix.
            for Channel in range(ChannelCount):
                Kernel[:, :, Channel, Channel] = Kernel2D
            return Kernel

    Padded = Mirror(x, MirrorPadding)
    Upsampled = layers.Convolution2DTranspose(kernel_initializer=BilinearKernel(), **Settings)(Padded)
    # The border grows by the same factor as the image, so the crop has to
    # scale with it (the previous fixed crop of 32 was only right for Scale=2).
    Cropped = layers.Cropping2D(MirrorPadding * Scale)(Upsampled)
    if not Permute:
        return Cropped
    return layers.Lambda(lambda Tensor: tf.space_to_depth(Tensor, Scale, data_format='NCHW'))(Cropped)
def DepthToSpace(x, Scale=2):
    """Apply `tf.depth_to_space` to `x` inside a Keras `Lambda` layer.

    Args:
        x: Keras tensor, passed to TensorFlow with `data_format='NCHW'`.
        Scale: Block size handed to `tf.depth_to_space`.

    Returns:
        The rearranged Keras tensor.
    """
    Shuffle = layers.Lambda(lambda Tensor: tf.depth_to_space(Tensor, Scale, data_format='NCHW'))
    return Shuffle(x)
def MSRInitializer(Alpha=0, WeightScale=1, Seed=0):
    """Build a fan-in, normal `VarianceScaling` initializer.

    The variance scale passed to Keras is
    `2 * WeightScale**2 / (1 + Alpha**2)`.

    Args:
        Alpha: Negative-slope coefficient of the activation that follows
            the layer being initialised.
        WeightScale: Extra multiplier; it enters the variance scale squared.
        Seed: Seed forwarded to the Keras initializer.

    Returns:
        A `keras.initializers.VarianceScaling` instance.
    """
    Numerator = WeightScale ** 2 * 2
    Denominator = 1 + Alpha ** 2
    return initializers.VarianceScaling(
        scale=Numerator / Denominator,
        mode='fan_in',
        distribution='normal',
        seed=Seed,
    )
def Convolution(x, ConstructedLayers, FilterCount, Name, Shortcut=None, ReceptiveField=3, Activation=True, WeightScale=0.05, PReLuAlpha=0.25, Faucet=None):
    """Apply a mirror-padded convolution, optionally gated, residual and activated.

    Layers are looked up in (and, when missing, added to) `ConstructedLayers`
    under the keys `'Convolution' + Name` and `'PReLu' + Name`, so calling
    this function again with the same `Name` and dictionary reuses the same
    layer objects.

    Args:
        x: Input Keras tensor.
        ConstructedLayers: Dict acting as the layer cache; mutated in place.
        FilterCount: Number of convolution filters.
        Name: Suffix identifying the layers in the cache and in the model.
        Shortcut: Optional tensor added to the convolution output.
        ReceptiveField: Kernel size; values above 1 trigger mirror padding
            of `(ReceptiveField - 1) // 2`.
        Activation: Whether a PReLU follows the convolution.
        WeightScale: Initializer weight scale; only used when `Shortcut` is
            given, otherwise 1 is used.
        PReLuAlpha: Initial PReLU slope; replaced by 1 when `Activation` is
            false.
        Faucet: Optional tensor multiplied into the convolution output
            before the shortcut is added.

    Returns:
        `[pre_activation, activated]` when `Activation` is true, otherwise
        just the pre-activation tensor.
    """
    HasShortcut = Shortcut is not None
    if not HasShortcut:
        WeightScale = 1
    if not Activation:
        PReLuAlpha = 1

    ConvKey = 'Convolution' + Name
    ActKey = 'PReLu' + Name

    # Create each layer at most once per cache so weights can be shared.
    if ConvKey not in ConstructedLayers:
        ConstructedLayers[ConvKey] = layers.Conv2D(
            kernel_initializer=MSRInitializer(PReLuAlpha, WeightScale),
            filters=FilterCount,
            kernel_size=ReceptiveField,
            strides=1,
            activation=None,
            padding='valid',
            name=ConvKey,
        )
    if Activation and ActKey not in ConstructedLayers:
        ConstructedLayers[ActKey] = layers.advanced_activations.PReLU(
            alpha_initializer=initializers.constant(PReLuAlpha),
            shared_axes=[2, 3],
            name=ActKey,
        )

    # The convolution itself is 'valid'; the mirror padding supplies the border.
    if ReceptiveField > 1:
        x = Mirror(x, (ReceptiveField - 1) // 2)
    x = ConstructedLayers[ConvKey](x)
    if Faucet is not None:
        x = layers.multiply([x, Faucet])
    if HasShortcut:
        x = layers.add([x, Shortcut])

    if not Activation:
        return x
    return [x, ConstructedLayers[ActKey](x)]
def SequentialConvolution(x, ConstructedLayers, Name, Shortcut=None, ReceptiveField=3, Granularity=1, Windowed=True, SubsequenceCount=None, WeightScale=0.05, PReLuAlpha=0.25, Faucet=None):
    """Generate feature maps step by step, each step seeing the previous outputs.

    On every step a `Convolution` producing `Granularity` maps is applied to
    the current view; its activated output is concatenated onto the view
    along axis 1 for the next step. With `Windowed` set, the first
    `Granularity` maps of the view are dropped before concatenating, so the
    view keeps a constant channel count.

    Args:
        x: Input Keras tensor (channels on axis 1).
        ConstructedLayers: Layer cache forwarded to `Convolution`.
        Name: Suffix appended to every per-step layer name.
        Shortcut: Optional tensor; step `i` receives channels
            `[i * Granularity, (i + 1) * Granularity)` of it as its shortcut.
        ReceptiveField: Kernel size of every step.
        Granularity: Number of maps produced per step.
        Windowed: Whether to slide the view instead of letting it grow.
        SubsequenceCount: Number of steps; defaults to the input channel
            count divided by `Granularity`.
        WeightScale: Forwarded to `Convolution`.
        PReLuAlpha: Forwarded to `Convolution`.
        Faucet: Forwarded to `Convolution`.

    Returns:
        Tuple `(raw, activated)`: the per-step pre-activation outputs and
        the per-step activated outputs, each concatenated along axis 1.
    """
    def DropOldestMaps(Tensor):
        return layers.Lambda(lambda t: t[:, Granularity:, :, :])(Tensor)

    def SliceShortcut(Step):
        return layers.Lambda(lambda t: t[:, Step * Granularity:(Step + 1) * Granularity, :, :])(Shortcut)

    SharedSettings = dict(
        ConstructedLayers=ConstructedLayers,
        FilterCount=Granularity,
        ReceptiveField=ReceptiveField,
        WeightScale=WeightScale,
        PReLuAlpha=PReLuAlpha,
        Faucet=Faucet,
    )
    if SubsequenceCount is None:
        SubsequenceCount = int(x.shape[1]) // Granularity

    View = x
    RawParts = []
    ActivatedParts = []
    for Step in range(SubsequenceCount):
        ShortcutSlice = None if Shortcut is None else SliceShortcut(Step)
        RawPart, ActivatedPart = Convolution(View, Name='Subsequence' + str(Step) + Name, Shortcut=ShortcutSlice, **SharedSettings)
        RawParts.append(RawPart)
        ActivatedParts.append(ActivatedPart)
        if Windowed:
            View = DropOldestMaps(View)
        # The freshly activated maps become visible to the next step.
        View = layers.concatenate([View, ActivatedPart], axis=1)

    RawSequence = layers.concatenate(RawParts, axis=1)
    ActivatedSequence = layers.concatenate(ActivatedParts, axis=1)
    return RawSequence, ActivatedSequence
def SequentialResidualBlock(x, ConstructedLayers, Name, Shortcut=None, ReceptiveField=3, Granularity=1, WeightScale=0.05, PReLuAlpha=0.25, Faucet=None):
    """Chain two `SequentialConvolution` stages into one residual block.

    Stage A (`'SequenceA' + Name`) runs without shortcut or faucet and only
    its activated output is kept. Stage B (`'SequenceB' + Name`) consumes
    that output and receives `Shortcut` and `Faucet`.

    Args:
        x: Input Keras tensor.
        ConstructedLayers: Layer cache shared by both stages.
        Name: Suffix appended to both stage names.
        Shortcut: Tensor forwarded to stage B as its shortcut.
        ReceptiveField: Kernel size for both stages.
        Granularity: Maps produced per step in both stages.
        WeightScale: Forwarded to both stages.
        PReLuAlpha: Forwarded to both stages.
        Faucet: Tensor forwarded to stage B only.

    Returns:
        The `(raw, activated)` tuple returned by stage B.
    """
    Shared = dict(
        ConstructedLayers=ConstructedLayers,
        ReceptiveField=ReceptiveField,
        Granularity=Granularity,
        WeightScale=WeightScale,
        PReLuAlpha=PReLuAlpha,
    )
    StageA = SequentialConvolution(x, Name='SequenceA' + Name, **Shared)
    Intermediate = StageA[1]
    return SequentialConvolution(Intermediate, Name='SequenceB' + Name, Shortcut=Shortcut, Faucet=Faucet, **Shared)
def SuperResolutionReconstruct(LowResolutionPatch, ConstructedLayers, ExistingBlocks, AdditionalBlocks, WindowSize=64, Granularity=1, WeightScale=0.05, PReLuAlpha=0.25, Faucet=None):
    """Build the reconstruction graph from a low-resolution patch.

    The graph is: an un-windowed `SequentialConvolution` that generates
    `WindowSize // Granularity` steps of initial features, then
    `ExistingBlocks` residual blocks without faucet, then `AdditionalBlocks`
    residual blocks gated by `Faucet`, then a 4-filter fusion convolution
    whose shortcut is the bilinearly upsampled (and space-to-depth
    permuted) input, and finally `DepthToSpace`.

    Args:
        LowResolutionPatch: Input Keras tensor.
        ConstructedLayers: Layer cache shared by every convolution.
        ExistingBlocks: Number of residual blocks built without faucet.
        AdditionalBlocks: Number of further residual blocks built with
            `Faucet`.
        WindowSize: Number of feature maps generated initially.
        Granularity: Maps produced per sequential step.
        WeightScale: Forwarded to the convolutions.
        PReLuAlpha: Forwarded to the convolutions.
        Faucet: Gate tensor for the additional blocks.

    Returns:
        The reconstructed Keras tensor.
    """
    Sequential = dict(ConstructedLayers=ConstructedLayers, Granularity=Granularity, WeightScale=WeightScale, PReLuAlpha=PReLuAlpha)
    Fusion = dict(ConstructedLayers=ConstructedLayers, FilterCount=4, WeightScale=WeightScale, PReLuAlpha=PReLuAlpha, Activation=False)

    def AppendBlock(Raw, Activated, Index, BlockFaucet):
        # Each block takes the previous raw features as its shortcut.
        return SequentialResidualBlock(Activated, Name='FeatureExtraction' + str(Index), Shortcut=Raw, Faucet=BlockFaucet, **Sequential)

    Raw, Activated = SequentialConvolution(
        LowResolutionPatch,
        Name='InitialFeatureGeneration',
        Windowed=False,
        SubsequenceCount=WindowSize // Granularity,
        **Sequential
    )
    for Index in range(ExistingBlocks):
        Raw, Activated = AppendBlock(Raw, Activated, Index, None)
    for Index in range(ExistingBlocks, ExistingBlocks + AdditionalBlocks):
        Raw, Activated = AppendBlock(Raw, Activated, Index, Faucet)

    UpsampledInput = BilinearUpsample(LowResolutionPatch, Permute=True)
    Fused = Convolution(Activated, Name='FinalFusion', Shortcut=UpsampledInput, **Fusion)
    return DepthToSpace(Fused)
def GetModel2x(Height, Width, ExistingBlocks, AdditionalBlocks, ConstructedLayers=None, WindowSize=64, Granularity=1, WeightScale=0.05, PReLuAlpha=0.25):
    """Create the Keras model around `SuperResolutionReconstruct`.

    Args:
        Height: Height of the single-channel input patch.
        Width: Width of the single-channel input patch.
        ExistingBlocks: Number of residual blocks built without faucet.
        AdditionalBlocks: Number of faucet-gated residual blocks; when
            positive the model gets a second input of shape `(1,)`.
        ConstructedLayers: Optional layer cache to reuse; a new dict is
            created when omitted.
        WindowSize: Forwarded to `SuperResolutionReconstruct`.
        Granularity: Forwarded to `SuperResolutionReconstruct`.
        WeightScale: Forwarded to `SuperResolutionReconstruct`.
        PReLuAlpha: Forwarded to `SuperResolutionReconstruct`.

    Returns:
        Tuple `(model, ConstructedLayers)`.
    """
    LowResolutionPatch = layers.Input((1, Height, Width))
    UsesFaucet = AdditionalBlocks > 0
    Faucet = layers.Input((1,)) if UsesFaucet else None
    if ConstructedLayers is None:
        ConstructedLayers = {}

    ReconstructedPatch = SuperResolutionReconstruct(
        LowResolutionPatch,
        ConstructedLayers,
        ExistingBlocks,
        AdditionalBlocks,
        WindowSize,
        Granularity,
        WeightScale,
        PReLuAlpha,
        Faucet,
    )

    if UsesFaucet:
        ModelInputs = [LowResolutionPatch, Faucet]
    else:
        ModelInputs = LowResolutionPatch
    return models.Model(inputs=ModelInputs, outputs=ReconstructedPatch), ConstructedLayers
# The layers above index channels on axis 1, so Keras must be switched to
# channels-first before any layer is created.
backend.set_image_data_format('channels_first')

# Training configuration.
PatternSize = 96
TargetSize = 2 * PatternSize
BatchSize = 8
Length = 1000
BlockCount = 16
Optimizer = optimizers.Adam(lr=1e-4, epsilon=1e-8, decay=0)
GPU = 8

# All-zero arrays with the input and target shapes of the 2x model.
LowRes = np.zeros(shape=(Length, 1, PatternSize, PatternSize))
Target = np.zeros(shape=(Length, 1, TargetSize, TargetSize))

# Build the single-device model, then replicate it across the GPUs.
Model, ConstructedLayers = GetModel2x(PatternSize, PatternSize, BlockCount, 0)
ParallelModel = utils.multi_gpu_model(Model, gpus=GPU, cpu_merge=False)  # merging on GPU for nvlink
ParallelModel.compile(optimizer=Optimizer, loss='mae')
Model.summary()
ParallelModel.fit(x=LowRes, y=Target, batch_size=BatchSize, epochs=800)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment