# [WIP] mean_teacher_in_keras
# Source: GitHub gist TeraBytesMemory/4b0bd82c0c2c304e1011e4e047094000
# by @TeraBytesMemory (last active July 29, 2020).
# https://arxiv.org/pdf/1703.01780.pdf
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
class TeacherOutputGetter:
    """Hold a fixed set of inputs and return a teacher model's predictions on them.

    The teacher is not known at construction time; it is attached later with
    ``register_teacher``. Presumably it is the teacher of a mean-teacher setup,
    used to produce targets for the unlabeled inputs (TODO: confirm with callers).

    Args:
        semi_supervised_X: Inputs passed unchanged to ``teacher.predict``.
    """

    def __init__(self, semi_supervised_X):
        self._X = semi_supervised_X
        self._teacher = None

    def register_teacher(self, teacher):
        """Attach (or replace) the object whose ``predict`` method produces targets."""
        self._teacher = teacher

    @property
    def X(self):
        """The inputs supplied at construction (read-only)."""
        return self._X

    def get_y(self):
        """Return ``teacher.predict(X)`` using the registered teacher.

        Raises:
            RuntimeError: If ``register_teacher`` has not been called yet.
        """
        if self._teacher is None:
            # Without this check the caller gets an opaque
            # "'NoneType' object has no attribute 'predict'".
            raise RuntimeError("no teacher registered; call register_teacher() first")
        return self._teacher.predict(self._X)
class MeanTeacherCallback(tf.keras.callbacks.Callback):
    """Maintain an exponential-moving-average (EMA) "teacher" copy of a student model.

    At the end of every epoch the teacher's weights are updated as::

        teacher = alpha * teacher + (1 - alpha) * student

    This is the weight-averaging step of Mean Teacher (arXiv:1703.01780).
    Note that the paper applies this update after every optimizer step; this
    callback applies it once per epoch.

    Args:
        model: The student model being trained. The teacher is created with
            ``tf.keras.models.clone_model(model)`` and starts with a copy of
            the student's current weights. It is marked non-trainable.
        alpha: EMA decay, in ``[0, 1]``. Larger values make the teacher change
            more slowly.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.
    """

    def __init__(self, model, alpha=0.99):
        # Let the Keras base class initialise the attributes it manages.
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")
        self.alpha = alpha
        self.student = model
        self.teacher = tf.keras.models.clone_model(model)
        # clone_model creates new, freshly initialised weights. Without this
        # copy the teacher would start from unrelated random weights, and with
        # alpha near 1 it would take many updates to forget them.
        self.teacher.set_weights(model.get_weights())
        self.teacher.trainable = False

    def _update_teacher(self):
        """Blend the student's current weights into the teacher (one EMA step)."""
        teacher_weights = self.teacher.get_weights()
        student_weights = self.student.get_weights()
        # zip() silently truncates. If the teacher has no weights (e.g. an
        # unbuilt clone), the update would become a silent no-op.
        if len(teacher_weights) != len(student_weights):
            raise ValueError(
                "teacher and student have different numbers of weight arrays "
                f"({len(teacher_weights)} vs {len(student_weights)}); "
                "is the teacher model built?"
            )
        self.teacher.set_weights([
            self.alpha * t + (1 - self.alpha) * s
            for t, s in zip(teacher_weights, student_weights)
        ])

    def on_epoch_end(self, epoch, logs=None):
        """Apply one EMA update to the teacher after each training epoch."""
        self._update_teacher()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment