Skip to content

Instantly share code, notes, and snippets.

@ha7ilm
Created May 16, 2023 15:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ha7ilm/d8b946b8acfa0bfb35303e63e043fe7c to your computer and use it in GitHub Desktop.
Save ha7ilm/d8b946b8acfa0bfb35303e63e043fe7c to your computer and use it in GitHub Desktop.
# We only need to normalize the inputs of the neural network, not necessarily the outputs.
# However, the choice of activation function matters: e.g. training does not work with
# `jnp.tanh` as the hidden-layer activation (softplus is used below).
from typing import Callable

import jax
import jax.numpy as jnp
import plotly.graph_objects as go
from flax import linen as nn
from flax.training import train_state
from jax import grad, jit, vmap
from jax import random
from optax import adam
# Generate the data: 200 evenly spaced points on [-pi, pi] laid out as a
# column vector of shape (200, 1), and a sine target scaled by 1000 and
# shifted up by 3000.
x = jnp.linspace(-jnp.pi, jnp.pi, 200)[:, None]
y = 3000 + 1000 * jnp.sin(x)
# Standardize the inputs to zero mean / unit variance. The targets are left
# on their original scale on purpose (see the note at the top of the file).
x = (x - x.mean()) / x.std()
# Create the neural network
class SineModel(nn.Module):
    """Two-hidden-layer MLP with a single linear output unit.

    Each Dense layer acts on the last axis of its input, so inputs of shape
    (..., in_features) produce outputs of shape (..., 1).

    Attributes:
        hidden_features: Width of each of the two hidden Dense layers
            (default 32, as in the original hard-coded model).
        activation: Elementwise nonlinearity applied after each hidden layer.
            Defaults to softplus; per the note at the top of this script the
            choice matters (e.g. tanh did not work for these targets).
    """

    hidden_features: int = 32
    activation: Callable = jax.nn.softplus

    def setup(self):
        # The attribute names fix the parameter-tree keys ('layer1', 'layer2',
        # 'output_layer'); keep them stable so existing params still apply.
        self.layer1 = nn.Dense(features=self.hidden_features)
        self.layer2 = nn.Dense(features=self.hidden_features)
        self.output_layer = nn.Dense(features=1)

    def __call__(self, x):
        x = self.activation(self.layer1(x))
        x = self.activation(self.layer2(x))
        # No activation on the output layer: predictions must be able to reach
        # the unnormalized target scale (roughly 2000-4000 for `y`).
        return self.output_layer(x)


model = SineModel()
# Define loss function
def loss_fn(params, batch):
    """Mean squared error of the module-level ``model`` on one batch.

    Args:
        params: Parameter pytree accepted by ``model.apply``.
        batch: ``(inputs, targets)`` pair.

    Returns:
        Scalar mean of ``(targets - preds) ** 2``.
    """
    inputs, targets = batch
    # Force predictions into the targets' shape. Otherwise e.g. (N, 1)
    # predictions against (N,) targets would silently broadcast to an (N, N)
    # error matrix and give a wrong loss; a real size mismatch now raises.
    preds = jnp.reshape(model.apply(params, inputs), jnp.shape(targets))
    return jnp.mean((targets - preds) ** 2)
# Initialize model and optimizer
optimizer = adam(learning_rate=0.01)  # Adam with a fixed step size
rng = random.PRNGKey(0)  # fixed seed so initialization is reproducible
# Run the model once on the training inputs to build its parameter pytree.
params = model.init(rng, x)
# Define training step
@jax.jit
def train_step(state, batch):
    """Perform one optimizer update on ``batch``.

    Differentiates ``loss_fn`` with respect to ``state.params`` and returns
    the updated state produced by ``state.apply_gradients``.
    """
    compute_grads = grad(loss_fn)
    return state.apply_gradients(grads=compute_grads(state.params, batch))
# Prepare the data: the whole dataset is used as one full batch.
data = (x, y)
# Training loop: 10000 full-batch updates, logging every 100th step.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
)
for i in range(10000):
    state = train_step(state, data)
    if i % 100:
        continue
    # The logged loss is measured after this step's update has been applied.
    print('Loss at step {}: {}'.format(i, loss_fn(state.params, data)))
# Plot data and model output. Note: `x` is the standardized input, so the
# horizontal axis is in normalized units.
predictions = model.apply(state.params, x)
fig = go.Figure(
    data=[
        go.Scatter(x=x.ravel(), y=y.ravel(), mode='markers', name='Data'),
        go.Scatter(x=x.ravel(), y=predictions.ravel(), mode='lines', name='Model Output'),
    ]
)
fig.update_layout(
    title='Sine Function and Model Output',
    xaxis_title='X',
    yaxis_title='Y',
)
fig.write_html('sineout.html')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment