"""Implementation of a Kolmogorov-Arnold Network (KAN) with dynamic (learnable)
B-spline activation functions in PyTorch.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np


class BSplineActivation(nn.Module):
"""
B-spline activation function for KAN model.
This class implements the B-spline activation function using the Cox-de Boor
recursion formula. The B-spline is defined by its degree and a set of knots.
The activation function is parameterized by coefficients which are learned
during training.
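
    A minimal usage sketch (values are illustrative):

        >>> act = BSplineActivation(num_knots=5, degree=2)
        >>> act(torch.linspace(0, 1, 8)).shape
        torch.Size([8])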
"""
def __init__(self, num_knots=15, degree=3):
"""
Parameters:
num_knots (int): The number of knots in the B-spline basis.
degree (int): The degree of the B-spline.
"""
super(BSplineActivation, self).__init__()
self.degree = degree
self.num_knots = num_knots
# init to small random values for better convergence
self.coefficients = nn.Parameter(torch.randn(num_knots + degree - 1) * 0.01)
self.register_buffer("knots", self.create_knot_vector())
def create_knot_vector(self):
"""
Create the knot vector for the B-spline basis functions.
The knot vector is constructed with 'degree' repeated knots at the start and end
so that the B-spline curve starts and ends at the first and last control points.
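
        For example, with num_knots=5 and degree=2 the result is
        [0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1]: a clamped uniform knot vector
        whose end knots have multiplicity degree + 1.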
"""
knots = [0] * self.degree
knots += torch.linspace(0, 1, self.num_knots).tolist()
knots += [1] * self.degree
return torch.tensor(knots, dtype=torch.float32)
def bspline_basis(self, x, degree, knots, i):
"""
Recursively compute the B-spline basis function for a given degree and knot vector.
Uses the Cox-de Boor recursion formula to compute the basis functions.
Parameters:
x (torch.Tensor): The input tensor for which to evaluate the B-spline basis function.
degree (int): The degree of the B-spline.
knots (torch.Tensor): A 1D tensor representing the knot vector.
i (int): The index of the basis function to compute.
Returns:
torch.Tensor: The computed B-spline basis function values for the input tensor.
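
        The recursion being implemented is
            B_{i,0}(x) = 1 if knots[i] <= x < knots[i+1], else 0
            B_{i,d}(x) = (x - knots[i]) / (knots[i+d] - knots[i]) * B_{i,d-1}(x)
                       + (knots[i+d+1] - x) / (knots[i+d+1] - knots[i+1]) * B_{i+1,d-1}(x)
        where any term with a zero denominator is taken to be 0 (the
        denom1/denom2 guards below).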
"""
if degree == 0:
            # Degree-0 B-spline: indicator of the half-open span [knots[i], knots[i+1])
            return ((x >= knots[i]) & (x < knots[i + 1])).float()
else:
denom1 = knots[i + degree] - knots[i]
denom2 = knots[i + degree + 1] - knots[i + 1]
term1 = 0
if denom1 != 0:
term1 = (
(x - knots[i])
/ denom1
* self.bspline_basis(x, degree - 1, knots, i)
)
term2 = 0
if denom2 != 0:
term2 = (
(knots[i + degree + 1] - x)
/ denom2
* self.bspline_basis(x, degree - 1, knots, i + 1)
)
return term1 + term2
def forward(self, x):
"""
        Clamps the input tensor to the B-spline domain and computes the weighted sum
of B-spline basis functions to produce the output tensor. Each basis
function is weighted by the corresponding coefficient.
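
        In symbols: f(x) = sum_i c_i * B_{i,degree}(x), with c_i the learned
        coefficients.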
"""
        # Clamp to just below 1 so the half-open degree-0 spans
        # [knots[i], knots[i+1]) also cover x == 1; with a plain clamp to
        # [0, 1], every basis function (and hence the output) is zero at x == 1.
        x = torch.clamp(x, 0.0, 1.0 - 1e-6)
        y = torch.zeros_like(x)
        n = len(self.coefficients)
        # Naive evaluation: one recursive Cox-de Boor pass per basis function.
        for i in range(n):
            basis = self.bspline_basis(x, self.degree, self.knots, i)
            y += self.coefficients[i] * basis
return y
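

# Quick sanity check (a sketch, not part of the model): with a clamped knot
# vector the basis functions form a partition of unity, so setting every
# coefficient to 1 should make the activation output ~1 across the domain:
#
#     act = BSplineActivation(num_knots=15, degree=3)
#     with torch.no_grad():
#         act.coefficients.fill_(1.0)
#     print(act(torch.linspace(0, 1, 5)))  # approximately all ones
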
class KANLayer(nn.Module):
"""
KAN layer with B-spline activation functions applied to each edge.
This class represents a fully connected layer where each connection (edge)
has its own unique B-spline activation function.
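
    Output unit j computes y_j = sum_i phi_ij(x_i), where each phi_ij is an
    independent BSplineActivation; for example, a KANLayer(2, 3) holds
    2 * 3 = 6 activation modules.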
"""
def __init__(self, input_dim, output_dim, num_knots=15, degree=3):
"""
Parameters:
input_dim (int): The number of input neurons.
output_dim (int): The number of output neurons.
num_knots (int): The number of knots in the B-spline basis for the activations.
degree (int): The degree of the B-spline for the activations.
"""
super(KANLayer, self).__init__()
self.activations = nn.ModuleList(
[
BSplineActivation(num_knots, degree)
for _ in range(input_dim * output_dim)
]
)
self.input_dim = input_dim
self.output_dim = output_dim
def forward(self, x):
"""
Applies the B-spline activation functions on each edge and sums the
contributions from each edge to compute the output.
"""
batch_size = x.size(0)
        result = torch.zeros(batch_size, self.output_dim, device=x.device, dtype=x.dtype)
for i in range(self.input_dim):
for j in range(self.output_dim):
idx = i * self.output_dim + j
# Apply the activation function for each edge
result[:, j] += self.activations[idx](x[:, i].unsqueeze(1)).squeeze(1)
return result


class KANModel(nn.Module):
"""
Full KAN model with multiple KAN layers.
The class consists of a sequence of KAN layers, each of which uses
B-spline activation functions on its edges.
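
    For example, KANModel([1, 50, 50, 1]) stacks three KANLayers that map
    1 -> 50 -> 50 -> 1.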
"""
def __init__(self, layer_dims, num_knots=15, degree=3):
"""
Parameters:
layer_dims (list): A list of integers, where each integer is the number
of neurons in each layer of the model.
num_knots (int): The number of knots in the B-spline basis.
degree (int): The degree of the B-spline.
"""
super(KANModel, self).__init__()
layers = []
for i in range(len(layer_dims) - 1):
layers.append(KANLayer(layer_dims[i], layer_dims[i + 1], num_knots, degree))
self.layers = nn.ModuleList(layers)
def forward(self, x):
"""
Applies each KAN layer sequentially to the input to produce the output.
"""
for layer in self.layers:
x = layer(x)
return x


if __name__ == "__main__":
torch.manual_seed(0)
    # Two hidden layers of width 50; with one spline activation per edge, this
    # model holds 1*50 + 50*50 + 50*1 = 2600 activation modules, so the
    # loop-based implementation trains slowly.
    model = KANModel([1, 50, 50, 1], num_knots=15, degree=3)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
loss_func = nn.MSELoss()
x_input = torch.linspace(0, 1, 200).unsqueeze(1)
target = torch.sin(2 * np.pi * x_input) # Sinusoidal target function
    # Train/val split (80/20). The split is sequential, so validation measures
    # fit on the unseen right end of the domain rather than interpolation;
    # shuffle the indices for an interpolation-style split.
    train_size = int(0.8 * len(x_input))
    x_train, x_val = x_input[:train_size], x_input[train_size:]
    y_train, y_val = target[:train_size], target[train_size:]
losses = []
val_losses = []
EPOCHS = 100
for epoch in range(EPOCHS):
# training
model.train()
optimizer.zero_grad()
output = model(x_train)
loss = loss_func(output, y_train)
loss.backward()
# gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
losses.append(loss.item())
# validation
model.eval()
with torch.no_grad():
val_output = model(x_val)
val_loss = loss_func(val_output, y_val)
val_losses.append(val_loss.item())
        # Log progress every 10 epochs
        if (epoch + 1) % 10 == 0:
            print(
                f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item():.6f}, Val Loss: {val_loss.item():.6f}"
            )
# Training and validation loss over time
plt.figure()
plt.plot(losses, label="Training Loss")
plt.plot(val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training and Validation Loss Over Time")
plt.show()
# Target function vs. model output
with torch.no_grad():
output = model(x_input)
plt.figure()
plt.plot(x_input.numpy(), target.numpy(), label="Target Function")
plt.plot(x_input.numpy(), output.numpy(), label="Model Output")
plt.xlabel("Input (x)")
plt.ylabel("Output (y)")
plt.legend()
plt.title("KAN Model Approximation of sin(2πx)")
plt.show()