Skip to content

Instantly share code, notes, and snippets.

@avivajpeyi
Created July 13, 2023 04:14
Show Gist options
  • Save avivajpeyi/3977a91b25c6d30b75f9b9568623b0bc to your computer and use it in GitHub Desktop.
Save avivajpeyi/3977a91b25c6d30b75f9b9568623b0bc to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
def model(x, a2, a3):
    """Evaluate the quadratic ``a2 * x + a3 * x**2`` (it has no constant term).

    Operates element-wise when ``x`` is a NumPy array, returning an array of
    the same shape.
    """
    linear_term = a2 * x
    quadratic_term = a3 * x ** 2
    return linear_term + quadratic_term
def generate_training_data(x, a2_range, a3_range, n_samples=100, rng=None):
    """Build a training set of (parameters -> model curve) pairs.

    ``n_samples`` values of ``a2`` and ``n_samples`` values of ``a3`` are drawn
    uniformly from their ranges. ``model`` is then evaluated on ``x`` for every
    combination of those draws: a full grid of ``n_samples ** 2`` pairs, with
    ``a2`` as the outer (slowest-varying) index and ``a3`` as the inner one.

    Args:
        x: Points at which ``model`` is evaluated.
        a2_range: ``(low, high)`` bounds for the ``a2`` draws.
        a3_range: ``(low, high)`` bounds for the ``a3`` draws.
        n_samples: Number of draws per parameter (default 100, the value that
            was previously hard-coded).
        rng: Optional NumPy random generator (anything with a ``uniform``
            method, e.g. ``np.random.default_rng(seed)``) for reproducible
            draws. When ``None``, the global ``np.random`` state is used.

    Returns:
        A tuple ``(params, curves)``.
        - ``params`` has shape ``(2, n_samples ** 2)``: row 0 holds the ``a2``
          values and row 1 the ``a3`` values.
        - ``curves`` stacks one ``model(x, a2, a3)`` result per grid point, in
          the same order as the columns of ``params``.
    """
    sampler = np.random if rng is None else rng
    # Draw a2 before a3 so the global random stream is consumed exactly as before.
    a2_samples = sampler.uniform(a2_range[0], a2_range[1], n_samples)
    a3_samples = sampler.uniform(a3_range[0], a3_range[1], n_samples)

    # Cartesian grid: repeat gives a2 as the outer loop, tile gives a3 as the
    # inner loop (a2_0 with every a3, then a2_1 with every a3, ...).
    a2_draw = np.repeat(a2_samples, n_samples)
    a3_draw = np.tile(a3_samples, n_samples)

    curves = [model(x, a2, a3) for a2, a3 in zip(a2_draw, a3_draw)]
    return np.array([a2_draw, a3_draw]), np.array(curves)
def generate_ml_model(in_params, out_timeseries, epochs=20, learning_rate=0.001,
                      loss_plot_path='loss.png'):
    """Fit a small dense network that maps parameter vectors to model curves.

    Args:
        in_params: Array whose rows are parameters and whose columns are
            training examples, i.e. shape ``(n_params, n_examples)``. This is
            the layout returned by ``generate_training_data``. It is transposed
            before fitting so Keras sees one example per row.
        out_timeseries: Array of shape ``(n_examples, n_points)``; row ``i`` is
            the target curve for column ``i`` of ``in_params``.
        epochs: Number of training epochs (default 20).
        learning_rate: Adam learning rate (default 0.001).
        loss_plot_path: File the per-epoch training-loss plot is saved to.

    Returns:
        The trained ``tf.keras.Sequential`` network.

    Side effects:
        Prints Keras training progress. Draws the loss curve on a new
        matplotlib figure and saves it to ``loss_plot_path``. The figure is
        left open, so a later ``plt.show()`` still displays it.
    """
    n_params = in_params.shape[0]
    n_outputs = out_timeseries.shape[1]

    net = tf.keras.Sequential([
        # Declare the input width explicitly. A Dense(n_params, 'relu') layer
        # in this slot would not be an input declaration. It would be a hidden
        # layer of only n_params ReLU units, which zeroes every negative
        # activation and bottlenecks the whole network.
        tf.keras.Input(shape=(n_params,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(n_outputs),
    ])
    net.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                loss='mse',
                metrics=['mae', 'mse'])
    history = net.fit(in_params.T, out_timeseries, epochs=epochs, verbose=1)

    # Use a dedicated figure so the loss curve is never drawn onto whatever
    # pyplot figure happens to be active when this function is called.
    fig, ax = plt.subplots()
    ax.plot(history.history['loss'])
    ax.set_title('model loss')
    ax.set_ylabel('loss')
    ax.set_xlabel('epoch')
    fig.savefig(loss_plot_path)
    return net
if __name__ == '__main__':
    # Evaluation points and the ranges the training parameters are drawn from.
    x = np.linspace(0, 10, 100)
    a2_range = (-5, 5)
    a3_range = (-5, 5)

    # Simulate the training set, then fit the network to it.
    in_data, out_data = generate_training_data(x, a2_range, a3_range)
    ml_model = generate_ml_model(in_data, out_data)

    # Compare the network against the exact curve at a single parameter point.
    a2_check, a3_check = 0.5, 0.5
    batch_prediction = ml_model.predict(np.array([[a2_check, a3_check]]))
    model_eval = batch_prediction[0]
    exact_curve = model(x, a2_check, a3_check)

    _, ax = plt.subplots()
    ax.plot(x, model_eval, label='ML')
    ax.plot(x, exact_curve, label='True', ls='--')
    ax.legend()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment