Skip to content

Instantly share code, notes, and snippets.

@adiell
Created September 29, 2020 08:33
Show Gist options
  • Save adiell/ebcfbe0164ce7167d45a43d8fdba715f to your computer and use it in GitHub Desktop.
Save adiell/ebcfbe0164ce7167d45a43d8fdba715f to your computer and use it in GitHub Desktop.
Example of using TensorFlow Lattice on 1-D data ("curve fitting") to enforce shape constraints: monotonicity and convexity/concavity.
import tensorflow_lattice as tfl
import tensorflow
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tensorflow_estimator.python.estimator.inputs.numpy_io import numpy_input_fn
from tensorflow.python.feature_column.feature_column_v2 import numeric_column
from tensorflow_estimator.python.estimator.canned.dnn import DNNRegressor
# Shared graph-level seed so that BOTH estimators (not just the lattice one)
# train reproducibly and the side-by-side comparison is stable across runs.
SEED = 42

# --- Synthetic 1-D data -----------------------------------------------------
# y is a random degree-4 polynomial in (1 - X), sum_k coef[k] * (1 - X)**k,
# plus Gaussian noise with standard deviation 0.1.
n = 100
np.random.seed(231)
X = np.linspace(0, 1, n)
coef = np.random.randn(5)
# The power index is called `k` so it cannot be confused with the sample
# count `n`, which is still needed for the noise term on the same line.
y = (
    np.sum(np.array([c * (1 - X) ** k for k, c in enumerate(coef)]), axis=0)
    + 0.1 * np.random.randn(n)
)

# --- Input functions --------------------------------------------------------
# Training: 100 shuffled epochs in mini-batches of 4.
train_input_fn = numpy_input_fn(
    x={'x': X}, y=y, batch_size=4, num_epochs=100, shuffle=True)
# One ordered pass over the training data; CannedRegressor uses it to
# collect feature statistics (e.g. for placing calibration keypoints).
feature_analysis_input_fn = numpy_input_fn(
    x={'x': X}, y=y, batch_size=4, num_epochs=1, shuffle=False)
feature_columns = [numeric_column(key='x', shape=(1,))]

# Evenly spaced evaluation grid over the training range (np.linspace default
# of 50 points), used only to draw each model's fitted curve.
X_pred = np.linspace(X.min(), X.max())
eval_input_fn = numpy_input_fn(
    x={'x': X_pred}, batch_size=1, num_epochs=1, shuffle=False)

# --- Shape-constrained lattice model ----------------------------------------
# The piecewise-linear (PWL) calibrator on `x` carries the shape constraints:
# it is forced to be increasing and concave. The "calib_wrinkle" regularizer
# is a smoothness penalty on that calibrator. These constraints are imposed
# regardless of the randomly drawn coefficients above, so the fit may
# deliberately deviate from the noisy data where the data violate them.
model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=[
        tfl.configs.FeatureConfig(
            name="x",
            lattice_size=2,
            monotonicity="increasing",
            pwl_calibration_convexity="concave",
            pwl_calibration_num_keypoints=20,
            regularizer_configs=[
                tfl.configs.RegularizerConfig(name="calib_wrinkle", l2=1.0),
            ],
        )
    ])

tfl_estimator = tfl.estimators.CannedRegressor(
    feature_columns=feature_columns,
    model_config=model_config,
    feature_analysis_input_fn=feature_analysis_input_fn,
    optimizer=tensorflow.keras.optimizers.Adam(learning_rate=0.001),
    config=tensorflow.estimator.RunConfig(tf_random_seed=SEED),
)
tfl_estimator.train(input_fn=train_input_fn)

# --- Unconstrained baseline: 10 hidden layers of 20 units -------------------
dnn_model = DNNRegressor(
    feature_columns=feature_columns,
    hidden_units=10 * [20],
    # Seeded like the lattice model so the baseline is reproducible too.
    config=tensorflow.estimator.RunConfig(tf_random_seed=SEED),
)
dnn_model.train(train_input_fn)


def _predict_curve(estimator, input_fn):
    """Return one scalar prediction per example produced by *input_fn*.

    ``estimator.predict`` yields one dict per example; the regression output
    is stored under ``'predictions'`` and indexed with ``[0]`` to get the
    scalar value.
    """
    # `pred` (not `y`) so the training targets `y` are never shadowed.
    return [pred['predictions'][0] for pred in estimator.predict(input_fn)]


y_pred_nn = _predict_curve(dnn_model, eval_input_fn)
y_pred_lat = _predict_curve(tfl_estimator, eval_input_fn)

# --- Plot: noisy data plus each model's fitted curve, stacked vertically ----
fig, axes = plt.subplots(2, 1, figsize=(8, 10))
for ax, curve, title in zip(axes,
                            (y_pred_nn, y_pred_lat),
                            ("Vanilla NN", "Lattice")):
    ax.scatter(X, y)
    ax.plot(X_pred, curve, 'k')
    ax.set_title(title)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment