Skip to content

Instantly share code, notes, and snippets.

@mwitiderrick
Last active December 18, 2023 13:56
Show Gist options
  • Save mwitiderrick/83cf60b4818bdf65521569fa9ab92741 to your computer and use it in GitHub Desktop.
import flax
from flax import linen as nn
class CNN(nn.Module):
    """Small two-stage convolutional classifier.

    Architecture: two (3x3 conv -> ReLU -> 2x2 average pool, stride 2)
    stages with 32 and 64 feature maps, then a 256-unit ReLU dense layer
    and a final dense layer producing ``num_classes`` outputs.

    Attributes:
        num_classes: Width of the final dense layer. Defaults to 2, the
            value that was previously hard-coded, so ``CNN()`` behaves
            exactly as before.
    """

    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        """Apply the network to a batch of inputs.

        Args:
            x: Input batch. Assumed to be channels-last
                ``(batch, height, width, channels)``, which is Flax
                ``nn.Conv``'s layout for a 2-D kernel (TODO confirm against
                callers). The leading axis is kept as the batch axis when
                flattening below.

        Returns:
            Unnormalized scores (logits) of shape ``(batch, num_classes)``.
            No softmax is applied here.
        """
        # Feature extractor: each stage is conv -> ReLU -> 2x2 average pool
        # with stride 2, which halves each spatial dimension.
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the leading (batch) axis for the dense head.
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        # Submodules are created in the same order as before, so the
        # auto-generated parameter names (Conv_0, Conv_1, Dense_0, Dense_1)
        # still match previously saved parameters.
        x = nn.Dense(features=self.num_classes)(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment