Created November 13, 2019 04:03
Save thierryherrmann/550906c4f97ff685cf21b0418a485063 to your computer and use it in GitHub Desktop.
Tensorflow_PR_33229
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tested with tf-nightly 2.1.0.dev20191111 (Linux, Ubuntu 18.04)
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
from scipy import stats | |
import os | |
import tensorflow as tf | |
from tensorflow import keras | |
# Report the TensorFlow and Keras versions this reproduction runs against,
# one per line.
print(tf.__version__, keras.__version__, sep='\n')
def get_housing_dataset(random_state=None):
    """Load, split and standardize the California housing regression dataset.

    The data is split twice with ``train_test_split`` using its default sizes.
    The first split produces train-full/test, and the second splits train-full
    into train/validation. A ``StandardScaler`` is fitted on the training
    features only and then applied unchanged to the validation and test
    features. All returned arrays are ``np.float32``.

    Args:
        random_state: Optional seed (or ``np.random.RandomState``) forwarded to
            both ``train_test_split`` calls. The default ``None`` keeps the
            previous behaviour: the splits draw from NumPy's global RNG. That is
            why the script seeds ``np.random`` before calling this function.

    Returns:
        Tuple ``(X_train, X_valid, X_test, y_train, y_valid, y_test)``.
    """
    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    housing = fetch_california_housing()
    X_train_full, X_test, y_train_full, y_test = train_test_split(
        housing.data, housing.target, random_state=random_state)
    X_train, X_valid, y_train, y_valid = train_test_split(
        X_train_full, y_train_full, random_state=random_state)

    scaler = StandardScaler()
    # Fit on the training split only, then reuse those statistics so nothing
    # leaks from the held-out sets.
    X_train = scaler.fit_transform(X_train)
    X_valid = scaler.transform(X_valid)
    X_test = scaler.transform(X_test)

    # float32 matches Keras' default float type used by the model below.
    return tuple(arr.astype(np.float32)
                 for arr in (X_train, X_valid, X_test, y_train, y_valid, y_test))
# Seed TensorFlow's and NumPy's global RNGs before loading the data, so the
# weight initialisation and the dataset splits are repeatable.
tf.random.set_seed(1)
np.random.seed(2)

X_train, X_valid, X_test, y_train, y_valid, y_test = get_housing_dataset()
class HuberLoss(keras.losses.Loss):
    """Huber loss as a ``keras.losses.Loss`` subclass with a configurable threshold.

    Absolute errors up to ``threshold`` are penalised quadratically
    (``error**2 / 2``). Larger errors are penalised linearly
    (``threshold * error - threshold**2 / 2``). The two pieces are equal at the
    threshold.
    """

    def __init__(self, threshold=1.0, **kwargs):
        # Any remaining keyword arguments are forwarded to keras.losses.Loss.
        super().__init__(**kwargs)
        self.threshold = threshold

    @tf.function
    def call(self, y_true, y_pred):
        """Return the element-wise Huber loss between ``y_true`` and ``y_pred``."""
        delta = self.threshold
        abs_err = tf.abs(y_true - y_pred)
        quadratic = tf.square(abs_err) / 2
        linear = abs_err * delta - 0.5 * delta**2
        return tf.where(abs_err <= delta, quadratic, linear)

    def get_config(self):
        """Return the base Loss config extended with ``threshold``.

        This lets the loss be re-created from its config, e.g. when a saved
        model is reloaded.
        """
        return {**super().get_config(), 'threshold': self.threshold}
@tf.function
def huber_fn(y_true, y_pred):
    """Element-wise Huber loss with a fixed threshold of 1.

    The loss is quadratic (``error**2 / 2``) where ``|error| < 1`` and linear
    (``|error| - 0.5``) elsewhere. It is usable both as a Keras loss and as a
    metric function.
    """
    diff = y_true - y_pred
    magnitude = tf.abs(diff)
    return tf.where(magnitude < 1, tf.square(diff) / 2, magnitude - 0.5)
class HuberMetric(keras.metrics.Metric):
    """Streaming mean of the element-wise Huber loss, as a Keras Metric.

    ``total`` accumulates the sum of per-element Huber values and ``count`` the
    number of elements seen. ``result()`` returns their ratio. The ``print``
    calls are deliberate tracing aids for this experiment: they show when Keras
    invokes each method.

    Args:
        threshold: Error magnitude at which the loss switches from quadratic
            to linear.
        **kwargs: Forwarded to ``keras.metrics.Metric`` (e.g. ``dtype``).
    """

    def __init__(self, threshold=1.0, **kwargs):
        super().__init__(**kwargs)  # handles base args (e.g., dtype)
        self.threshold = threshold
        print('__init__() called')

        def huber_fn(y_true, y_pred):
            # Closes over the constructor argument `threshold`. Per the
            # author's tracing, this runs when model.compile() is called and
            # again when a saved model is reloaded. The returned ops become
            # part of the graph, so it is not re-run on every training step.
            error = y_true - y_pred
            is_small_error = tf.abs(error) < threshold
            squared_loss = tf.square(error) / 2
            linear_loss = threshold * tf.abs(error) - threshold**2 / 2
            print('huber_fn() called')
            return tf.where(is_small_error, squared_loss, linear_loss)

        self.huber_fn = huber_fn
        # Streaming state. Keras tracks any tf.Variable set as an attribute
        # (like any other trackable object such as layers or models), and
        # Metric.reset_states() sets all of them back to 0 between epochs.
        # See: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Metric
        # Equivalent alternative:
        #   self.total = self.add_weight("total", initializer="zeros")
        #   self.count = self.add_weight("count", initializer="zeros")
        self.total = tf.Variable(0., name='total')
        self.count = tf.Variable(0., name='count')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch's Huber values into ``total`` and ``count``.

        NOTE(review): ``sample_weight`` is accepted for API compatibility but
        is ignored. Every element counts with weight 1.
        """
        # Called when model.compile() is called; the resulting operations are
        # added to the graph.
        print('update_state() called')
        # Align dtypes (as Keras' built-in losses do) so that e.g. float64
        # targets do not break the subtraction against float32 predictions.
        y_true = tf.cast(y_true, y_pred.dtype)
        metric = self.huber_fn(y_true, y_pred)
        self.total.assign_add(tf.cast(tf.reduce_sum(metric), self.total.dtype))
        self.count.assign_add(tf.cast(tf.size(y_true), self.count.dtype))

    def result(self):
        """Return the mean Huber value accumulated so far (0 before any update)."""
        # Also called when model.compile() is called; the resulting operations
        # are added to the graph.
        print('result() called')
        # Plain division would yield NaN while count is still 0 (right after
        # construction or a reset); divide_no_nan returns 0 instead.
        return tf.math.divide_no_nan(self.total, self.count)

    def get_config(self):
        """Return the base Metric config extended with ``threshold``."""
        print('get_config()')
        base_config = super().get_config()
        return {**base_config, "threshold": self.threshold}
# Scenario to reproduce; exactly one runs per execution (previously toggled by
# hand with `if True:` / `if False:`):
#   'loss_class'      - LOSS custom class: FIXED by the PR (thanks @omalleyt12)
#   'loss_function'   - LOSS custom function: was working before the PR, still works with the PR
#   'metric_class'    - METRICS custom class: was broken before the PR, still broken with the PR
#   'metric_function' - METRICS custom function: was working before the PR, still works with the PR
EXPERIMENT = 'loss_class'


def _fit_save_reload(model, custom_objects, path='model.h5', **compile_kwargs):
    """Compile ``model`` with SGD, train it for one epoch, save it to ``path``,
    and return the model reloaded from that file with ``custom_objects``."""
    model.compile(optimizer="sgd", **compile_kwargs)
    model.fit(X_train, y_train, epochs=1, validation_data=(X_valid, y_valid))
    model.save(path)
    return keras.models.load_model(path, custom_objects=custom_objects)


model = keras.Sequential([
    keras.layers.Dense(30, activation='relu', input_shape=X_train.shape[1:]),
    keras.layers.Dense(1),
])

if EXPERIMENT == 'loss_class':
    model = _fit_save_reload(model, {'HuberLoss': HuberLoss},
                             loss=HuberLoss(2.0))
elif EXPERIMENT == 'loss_function':
    model = _fit_save_reload(model, {'huber_fn': huber_fn},
                             loss=huber_fn)
elif EXPERIMENT == 'metric_class':
    model = _fit_save_reload(model, {'HuberMetric': HuberMetric},
                             loss=keras.losses.mean_squared_error,
                             metrics=[HuberMetric(threshold=.2)])
elif EXPERIMENT == 'metric_function':
    model = _fit_save_reload(model, {'huber_fn': huber_fn},
                             loss=keras.losses.mean_squared_error,
                             metrics=[huber_fn])
else:
    raise ValueError('Unknown EXPERIMENT: {!r}'.format(EXPERIMENT))

# continue training with the reloaded model
model.fit(X_train, y_train, epochs=1, validation_data=(X_valid, y_valid))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment