Skip to content

Instantly share code, notes, and snippets.

@thierryherrmann
Created November 13, 2019 04:03
Show Gist options
  • Save thierryherrmann/550906c4f97ff685cf21b0418a485063 to your computer and use it in GitHub Desktop.
Save thierryherrmann/550906c4f97ff685cf21b0418a485063 to your computer and use it in GitHub Desktop.
Tensorflow_PR_33229
# This was tested from tf-nightly 2.1.0.dev20191111 (Linux Ubuntu 18.04)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
import os
import tensorflow as tf
from tensorflow import keras
# Report the TensorFlow and Keras versions so a run can be tied to a build.
for _lib in (tf, keras):
    print(_lib.__version__)
def get_housing_dataset():
    """Load the California housing data, split it three ways and scale it.

    The full dataset is first split into (train + valid) / test, then the
    first part is split again into train / valid, both times with
    ``train_test_split``'s default proportions. A ``StandardScaler`` is
    fitted on the training features only and then applied to the validation
    and test features. Every returned array is cast to ``float32``.

    Returns:
        The tuple ``(X_train, X_valid, X_test, y_train, y_valid, y_test)``.
    """
    # scikit-learn is imported inside the function, as in the original script.
    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    dataset = fetch_california_housing()
    X_rest, X_test, y_rest, y_test = train_test_split(
        dataset.data, dataset.target)
    X_train, X_valid, y_train, y_valid = train_test_split(X_rest, y_rest)

    # Fit the scaler on the training split only, then reuse its statistics.
    standardizer = StandardScaler()
    X_train = standardizer.fit_transform(X_train)
    X_valid = standardizer.transform(X_valid)
    X_test = standardizer.transform(X_test)

    features = tuple(a.astype(np.float32) for a in (X_train, X_valid, X_test))
    targets = tuple(a.astype(np.float32) for a in (y_train, y_valid, y_test))
    return (*features, *targets)
# Seed the TensorFlow and NumPy RNGs before loading the data and building
# the model.
tf.random.set_seed(1)
np.random.seed(2)
(X_train, X_valid, X_test,
 y_train, y_valid, y_test) = get_housing_dataset()
class HuberLoss(keras.losses.Loss):
    """Huber loss with a configurable threshold.

    Element-wise, the loss is ``error**2 / 2`` where the absolute error is
    at most ``threshold`` and ``threshold * |error| - threshold**2 / 2``
    elsewhere.

    Args:
        threshold: Absolute error at which the loss switches from quadratic
            to linear.
        **kwargs: Forwarded to ``keras.losses.Loss``.
    """

    def __init__(self, threshold=1.0, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold

    @tf.function
    def call(self, y_true, y_pred):
        """Return the element-wise Huber loss of ``y_pred`` against ``y_true``."""
        delta = self.threshold
        abs_error = tf.abs(y_true - y_pred)
        quadratic = tf.square(abs_error) / 2
        linear = abs_error * delta - 0.5 * delta**2
        return tf.where(abs_error <= delta, quadratic, linear)

    def get_config(self):
        """Return the base config extended with ``threshold``."""
        config = super().get_config()
        config.update(threshold=self.threshold)
        return config
@tf.function
def huber_fn(y_true, y_pred):
    """Return the element-wise Huber loss with a fixed threshold of 1.

    The loss is ``error**2 / 2`` where ``|error| < 1`` and ``|error| - 0.5``
    elsewhere.
    """
    residual = y_true - y_pred
    magnitude = tf.abs(residual)
    return tf.where(magnitude < 1, tf.square(residual) / 2, magnitude - 0.5)
class HuberMetric(keras.metrics.Metric):
    """Streaming mean of the element-wise Huber loss.

    The metric accumulates the sum of the Huber loss (``total``) and the
    number of elements of ``y_true`` seen (``count``); ``result()`` is their
    ratio. The ``print`` calls are intentional: they trace when each hook is
    invoked.

    Args:
        threshold: Absolute error at which the loss switches from quadratic
            to linear.
        **kwargs: Forwarded to ``keras.metrics.Metric`` (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, threshold=1.0, **kwargs):
        super().__init__(**kwargs)  # handles base args (e.g., dtype)
        self.threshold = threshold
        print('__init__() called')

        def huber_fn(y_true, y_pred):
            # Author's note: this will be called when calling model.compile():
            # the returned object will be incorporated in the graph of
            # objects and this function will never be called afterwards
            # (e.g. during training). It will also be called when reloading
            # a saved model.
            error = y_true - y_pred
            is_small_error = tf.abs(error) < threshold
            squared_loss = tf.square(error) / 2
            linear_loss = threshold * tf.abs(error) - threshold**2 / 2
            print('huber_fn() called')
            return tf.where(is_small_error, squared_loss, linear_loss)

        self.huber_fn = huber_fn

        # Variables needed by the streaming metric. Keras keeps track of any
        # tf.Variable set as an attribute (and, more generally, of any
        # "trackable" object such as layers or models).
        self.total = tf.Variable(0., name='total')
        self.count = tf.Variable(0., name='count')
        # Equivalent code:
        #   self.total = self.add_weight("total", initializer="zeros")
        #   self.count = self.add_weight("count", initializer="zeros")
        # Those variables ARE RESET automatically by Keras between epochs by
        # calling Metric.reset_states(), which sets all variables to 0.
        # See: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Metric

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the Huber loss of one batch.

        Args:
            y_true: Target values.
            y_pred: Predicted values.
            sample_weight: Accepted for API compatibility; this
                implementation does not apply it.
        """
        # Author's note: called when model.compile() is called. The result is
        # the set of operations, which is added to the graph of operations.
        print('update_state() called')
        metric = self.huber_fn(y_true, y_pred)
        self.total.assign_add(tf.reduce_sum(metric))
        self.count.assign_add(tf.cast(tf.size(y_true), tf.float32))

    def result(self):
        """Return the mean Huber loss accumulated so far (0 if nothing seen)."""
        # Author's note: also called when model.compile() is called. The
        # result is the set of operations, which is added to the graph.
        print('result() called')
        # A plain division would give NaN (0 / 0) while ``count`` is still 0,
        # e.g. right after the variables are reset; divide_no_nan gives 0.
        return tf.math.divide_no_nan(self.total, self.count)

    def get_config(self):
        """Return the base config extended with ``threshold``."""
        print('get_config()')
        base_config = super().get_config()
        return {**base_config, "threshold": self.threshold}
# Small regression network: one hidden ReLU layer of 30 units followed by a
# single linear output unit.
hidden_layer = keras.layers.Dense(
    30, activation='relu', input_shape=X_train.shape[1:])
output_layer = keras.layers.Dense(1)
model = keras.Sequential([hidden_layer, output_layer])
def _fit_save_reload(net, custom_objects, **compile_kwargs):
    """Compile and train ``net`` for one epoch, save it, and reload it.

    Args:
        net: The Keras model to compile with the "sgd" optimizer and train.
        custom_objects: Mapping passed to ``keras.models.load_model`` so the
            custom loss/metric can be resolved on reload.
        **compile_kwargs: Extra arguments for ``compile`` (``loss``,
            ``metrics``).

    Returns:
        The model reloaded from 'model.h5'.
    """
    net.compile(optimizer="sgd", **compile_kwargs)
    net.fit(X_train, y_train, epochs=1, validation_data=(X_valid, y_valid))
    net.save('model.h5')
    return keras.models.load_model('model.h5', custom_objects=custom_objects)


# Flip the literals below to choose which save/reload scenario(s) run.
if True:
    # LOSS custom class: FIXED by the PR (thanks @omalleyt12)
    model = _fit_save_reload(
        model, {'HuberLoss': HuberLoss}, loss=HuberLoss(2.0))

if False:
    # LOSS custom function: was working before the PR, still works with the PR
    model = _fit_save_reload(model, {'huber_fn': huber_fn}, loss=huber_fn)

if False:
    # METRICS custom class: was broken before the PR, still broken with the PR
    model = _fit_save_reload(
        model,
        {'HuberMetric': HuberMetric},
        loss=keras.losses.mean_squared_error,
        metrics=[HuberMetric(threshold=.2)],
    )

if False:
    # METRICS custom function: was working before the PR, still works with the PR
    model = _fit_save_reload(
        model,
        {'huber_fn': huber_fn},
        loss=keras.losses.mean_squared_error,
        metrics=[huber_fn],
    )

# Continue training with the reloaded model.
model.fit(X_train, y_train, epochs=1, validation_data=(X_valid, y_valid))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment