```python
"""Custom losses"""
import tensorflow as tf
# pylint: disable = attribute-defined-outside-init, no-name-in-module, unexpected-keyword-arg
# pylint: disable = no-value-for-parameter
from tensorflow.python.ops import state_ops as tf_state_ops


class SeeSawWeightCalculator(tf.keras.layers.Layer):
```
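The `SeeSawWeightCalculator` body is cut off here, so what follows is only a minimal NumPy sketch of the Seesaw loss factors from Wang et al. (2021), assuming the layer follows that paper. For a sample with ground-truth class `y` and another class `j`, the mitigation factor is `(N_j / N_y) ** p` when `N_y > N_j` (else 1), the compensation factor is `(sigma_j / sigma_y) ** q` when `sigma_j > sigma_y` (else 1), and their product rescales `exp(z_j)` in the softmax denominator. The function names `seesaw_weights` and `seesaw_loss`, and the defaults `p=0.8`, `q=2.0`, are assumptions taken from the paper rather than from this gist. The `state_ops` import hints that the layer keeps running class counts in a variable; here those counts are simply passed in as `class_counts`, which is also an assumption.

```python
import numpy as np


def seesaw_weights(logits, labels, class_counts, p=0.8, q=2.0, eps=1e-6):
    """Per-sample Seesaw factors S[b, j] = mitigation * compensation (sketch)."""
    counts = np.maximum(np.asarray(class_counts, dtype=float), 1.0)
    # ratio[y, j] = N_j / N_y; damp negatives only when the true class is more frequent.
    ratio = counts[None, :] / counts[:, None]
    mitigation = np.where(ratio < 1.0, ratio ** p, 1.0)[labels]  # (batch, C)

    # Softmax probabilities, computed stably.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    p_true = probs[np.arange(len(labels)), labels][:, None]
    # Boost the penalty for classes that currently outscore the true class.
    comp_ratio = probs / np.maximum(p_true, eps)
    compensation = np.where(comp_ratio > 1.0, comp_ratio ** q, 1.0)
    return mitigation * compensation


def seesaw_loss(logits, labels, class_counts, p=0.8, q=2.0):
    """Mean Seesaw cross-entropy; S[b, y_b] == 1, so the true-class logit is unchanged."""
    s = seesaw_weights(logits, labels, class_counts, p, q)
    adjusted = logits + np.log(s)  # exp(z_j + log S) == S * exp(z_j)
    adjusted = adjusted - adjusted.max(axis=-1, keepdims=True)
    log_probs = adjusted - np.log(np.exp(adjusted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
```

With balanced counts and a correct prediction every factor is 1 and the loss reduces to plain softmax cross-entropy. When a frequent class is the label, the logits of rarer classes are scaled down, which reduces the negative gradient pushed onto tail classes.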
```python
"""Module implementing Time2Vec (time-to-vector) encoding."""
import tensorflow as tf
import tensorflow.keras.layers as KL


class Time2Vec(KL.Layer):
    """Time2Vec encoding layer."""

    def __init__(self, kernel: int = 64, activation: str = "sin") -> None:
```
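The `Time2Vec` layer is truncated after its constructor. As a rough illustration, assuming it follows the original Time2Vec formulation (Kazemi et al., 2019), the encoding of a time value `tau` is `omega_0 * tau + phi_0` for the first component and `F(omega_i * tau + phi_i)` for the rest, with `F = sin` matching the `activation="sin"` default above. The NumPy function below, and the mapping of the gist's `kernel` argument to the output width `k`, are assumptions; in a Keras layer `omega` and `phi` would be trainable weights.

```python
import numpy as np


def time2vec(tau, omega, phi, activation=np.sin):
    """Time2Vec sketch: one linear component followed by k - 1 periodic ones.

    tau:        array of time values with shape (...,).
    omega, phi: length-k frequencies and phase shifts (learned in a real layer).
    Returns an array of shape (..., k).
    """
    tau = np.asarray(tau, dtype=float)[..., None]          # (..., 1)
    z = tau * np.asarray(omega, dtype=float) + np.asarray(phi, dtype=float)
    # Component 0 stays linear; the remaining components pass through F.
    return np.concatenate([z[..., :1], activation(z[..., 1:])], axis=-1)
```

For example, `time2vec([0.0, 1.0], omega=[2.0, np.pi / 2], phi=[1.0, 0.0])` gives `[[1, 0], [3, 1]]`: the first column grows linearly with time, the second oscillates.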
```python
import os
from typing import Tuple

import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.layers as KL
import tensorflow.keras.models as KM
import tensorflow_datasets as tfds
from tensorflow.keras.callbacks import ModelCheckpoint
```
```python
"""Module implementing a custom model.fit() training loop."""
import tensorflow as tf
import tensorflow.keras.models as KM
import tensorflow_probability as tfp


class Trainer(KM.Model):  # pylint: disable=too-many-ancestors
    """Custom model class for training with model.fit().

    Supports selective backpropagation to accelerate training.
```
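The selection rule used by `Trainer` is not visible here. One published formulation is Selective-Backprop (Jiang et al., 2019): each example is kept for the backward pass with probability `rank ** beta`, where `rank` is the percentile of its loss among recently seen losses. The NumPy sketch below assumes that rule; the `tfp` import hints at percentile statistics, but that is a guess. The names `selection_probabilities`, `select_for_backprop`, `loss_history`, and `beta` are hypothetical.

```python
import numpy as np


def selection_probabilities(losses, loss_history, beta=1.0):
    """P(keep) = (fraction of historical losses <= each loss) ** beta."""
    history = np.sort(np.asarray(loss_history, dtype=float))
    ranks = np.searchsorted(history, losses, side="right") / len(history)
    return ranks ** beta


def select_for_backprop(losses, loss_history, beta=1.0, rng=None):
    """Boolean mask of examples to backpropagate; high-loss examples are favored."""
    rng = np.random.default_rng() if rng is None else rng
    probs = selection_probabilities(losses, loss_history, beta)
    # rng.random draws from [0, 1), so prob 0 is never kept and prob 1 always is.
    return rng.random(len(probs)) < probs
```

Inside a custom `train_step`, the mask would pick the rows of the batch whose loss is then differentiated, while `loss_history` would typically be a sliding window of recent per-example losses. Raising `beta` concentrates the backward passes on the hardest examples.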
```python
import tensorflow as tf
import tensorflow.keras.layers as KL
import tensorflow.keras.models as KM


def channel_attention(features: int, reduction: int = 16, name: str = "") -> KM.Model:
    """Channel attention model.

    Args:
        features (int): Number of features (channels) in the incoming tensor.
        reduction (int, optional): Reduction ratio for the MLP that squeezes information across channels. Defaults to 16.
```
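The body of `channel_attention` is truncated, so here is a minimal NumPy sketch of the computation its docstring describes, assuming the CBAM-style variant (Woo et al., 2018). Spatially average-pooled and max-pooled channel descriptors pass through one shared bottleneck MLP of width `features // reduction`; the outputs are summed, squashed with a sigmoid, and used to rescale each channel. Squeeze-and-Excitation uses only the average-pooled branch, and which variant the gist implements is not visible. The explicit weight arguments `w1`, `b1`, `w2`, `b2` are hypothetical stand-ins for the Keras `Dense` layers.

```python
import numpy as np


def channel_attention(x, w1, b1, w2, b2):
    """CBAM-style channel attention on an (H, W, C) feature map.

    w1: (C, C // reduction), w2: (C // reduction, C), shared by both branches.
    Returns (rescaled feature map, per-channel weights in [0, 1]).
    """
    def mlp(v):
        hidden = np.maximum(v @ w1 + b1, 0.0)  # ReLU bottleneck
        return hidden @ w2 + b2

    avg = x.mean(axis=(0, 1))  # (C,) global average pooling
    mx = x.max(axis=(0, 1))    # (C,) global max pooling
    scale = 1.0 / (1.0 + np.exp(-(mlp(avg) + mlp(mx))))  # sigmoid
    return x * scale, scale
```

With `features=32` and the default `reduction=16`, the bottleneck has only 2 units, which keeps the attention branch cheap relative to the convolutions it modulates.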