"""Implementation of a Kolmogorov-Arnold Network (KAN) with dynamic (learnable) B-spline activation functions in PyTorch."""
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
class BSplineActivation(nn.Module):
    """B-spline activation function for KAN model.

    This class implements the B-spline activation function using the Cox-de Boor
    recursion formula. The B-spline is defined by its degree and a set of knots.
    The activation function is parameterized by coefficients which are learned
    during training.
    """

    def __init__(self, num_knots=15, degree=3):
        """
        Args:
            num_knots (int): The number of knots in the B-spline basis.
            degree (int): The degree of the B-spline.

        Raises:
            ValueError: If ``degree`` is negative (the Cox-de Boor recursion
                would never reach its degree-0 base case).
        """
        super(BSplineActivation, self).__init__()
        if degree < 0:
            raise ValueError(f"degree must be non-negative, got {degree}")
        self.degree = degree
        self.num_knots = num_knots
        # The knot vector has num_knots + 2 * degree entries, which supports
        # (num_knots + 2 * degree) - degree - 1 = num_knots + degree - 1 basis
        # functions: one learned coefficient per basis function.
        # init to small random values for better convergence
        self.coefficients = nn.Parameter(torch.randn(num_knots + degree - 1) * 0.01)
        self.register_buffer("knots", self.create_knot_vector())

    def create_knot_vector(self):
        """Create the clamped knot vector for the B-spline basis functions.

        ``degree`` extra knots are placed at 0 before, and at 1 after, the
        ``num_knots`` uniformly spaced knots on [0, 1]. This makes the curve
        start and end at the first and last control points.

        Returns:
            torch.Tensor: 1D float32 tensor of length ``num_knots + 2 * degree``.
        """
        knots = [0.0] * self.degree
        knots += torch.linspace(0, 1, self.num_knots).tolist()
        knots += [1.0] * self.degree
        return torch.tensor(knots, dtype=torch.float32)

    def bspline_basis(self, x, degree, knots, i):
        """Recursively compute the B-spline basis function for a given degree and knot vector.

        Uses the Cox-de Boor recursion formula. Terms whose denominator is zero
        (repeated knots) are treated as 0.

        Args:
            x (torch.Tensor): The input tensor at which to evaluate the basis function.
            degree (int): The degree of the B-spline.
            knots (torch.Tensor): A 1D tensor representing the knot vector.
            i (int): The index of the basis function to compute.

        Returns:
            torch.Tensor: The basis function values, with the same shape as ``x``.
        """
        if degree == 0:
            # zero-degree B-spline --> piecewise constant on [t_i, t_{i+1})
            lower, upper = knots[i], knots[i + 1]
            mask = (x >= lower) & (x < upper)
            # Close the last non-degenerate interval on the right. Otherwise
            # every basis function vanishes at the final knot, and the spline
            # (whose input is clamped to that knot) would output 0 there.
            if upper == knots[-1] and lower < upper:
                mask = mask | (x == upper)
            return mask.float()

        denom1 = knots[i + degree] - knots[i]
        denom2 = knots[i + degree + 1] - knots[i + 1]

        term1 = 0
        if denom1 != 0:
            term1 = (
                (x - knots[i])
                / denom1
                * self.bspline_basis(x, degree - 1, knots, i)
            )

        term2 = 0
        if denom2 != 0:
            term2 = (
                (knots[i + degree + 1] - x)
                / denom2
                * self.bspline_basis(x, degree - 1, knots, i + 1)
            )

        return term1 + term2

    def forward(self, x):
        """Evaluate the learned spline element-wise.

        Clamps the input tensor to the range [0, 1] and computes the weighted sum
        of B-spline basis functions to produce the output tensor. Each basis
        function is weighted by the corresponding coefficient.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: Spline values with the same shape as ``x``.
        """
        x = torch.clamp(x, 0, 1)
        y = torch.zeros_like(x)
        n = len(self.coefficients)
        for i in range(n):
            basis = self.bspline_basis(x, self.degree, self.knots, i)
            y += self.coefficients[i] * basis
        return y
class KANLayer(nn.Module):
    """KAN layer with B-spline activation functions applied to each edge.

    This class represents a fully connected layer where each connection (edge)
    has its own unique B-spline activation function. Output neuron ``j`` is the
    sum over input neurons ``i`` of ``phi_ij(x_i)``.
    """

    def __init__(self, input_dim, output_dim, num_knots=15, degree=3):
        """
        Args:
            input_dim (int): The number of input neurons.
            output_dim (int): The number of output neurons.
            num_knots (int): The number of knots in the B-spline basis for the activations.
            degree (int): The degree of the B-spline for the activations.
        """
        super(KANLayer, self).__init__()
        # Edge (i, j) lives at flat index i * output_dim + j (row-major by input).
        self.activations = nn.ModuleList(
            BSplineActivation(num_knots, degree)
            for _ in range(input_dim * output_dim)
        )
        self.input_dim = input_dim
        self.output_dim = output_dim

    def forward(self, x):
        """Apply every edge activation and sum the contributions per output neuron.

        Args:
            x (torch.Tensor): Input of shape ``(batch_size, input_dim)``.

        Returns:
            torch.Tensor: Output of shape ``(batch_size, output_dim)`` on ``x``'s
            device and with ``x``'s dtype.

        Raises:
            ValueError: If ``x`` is not 2D with ``input_dim`` columns.
        """
        if x.dim() != 2 or x.size(1) != self.input_dim:
            raise ValueError(
                f"expected input of shape (batch, {self.input_dim}), got {tuple(x.shape)}"
            )
        batch_size = x.size(0)
        # Match x's dtype so float64 inputs are not silently truncated to float32.
        result = torch.zeros(
            batch_size, self.output_dim, device=x.device, dtype=x.dtype
        )
        for i in range(self.input_dim):
            # Keep a (batch, 1) column so each activation sees a 2D tensor.
            column = x[:, i].unsqueeze(1)
            for j in range(self.output_dim):
                idx = i * self.output_dim + j
                # Apply the activation function for each edge
                result[:, j] += self.activations[idx](column).squeeze(1)
        return result
class KANModel(nn.Module):
    """Full KAN model with multiple KAN layers.

    The class consists of a sequence of KAN layers, each of which uses
    B-spline activation functions on its edges.

    NOTE(review): each edge activation clamps its input to [0, 1], so hidden
    outputs of one layer that fall outside that range saturate in the next
    layer. Consider normalizing between layers if training stalls.
    """

    def __init__(self, layer_dims, num_knots=15, degree=3):
        """
        Args:
            layer_dims (list): A list of integers, where each integer is the number
                of neurons in each layer of the model (input first, output last).
            num_knots (int): The number of knots in the B-spline basis.
            degree (int): The degree of the B-spline.

        Raises:
            ValueError: If fewer than two layer sizes are given (no layer could be built).
        """
        super(KANModel, self).__init__()
        if len(layer_dims) < 2:
            raise ValueError(
                f"layer_dims needs at least input and output sizes, got {list(layer_dims)}"
            )
        # One KANLayer per consecutive (in_dim, out_dim) pair.
        self.layers = nn.ModuleList(
            KANLayer(in_dim, out_dim, num_knots, degree)
            for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:])
        )

    def forward(self, x):
        """Apply each KAN layer sequentially to the input to produce the output.

        Args:
            x (torch.Tensor): Input of shape ``(batch_size, layer_dims[0])``.

        Returns:
            torch.Tensor: Output of shape ``(batch_size, layer_dims[-1])``.
        """
        for layer in self.layers:
            x = layer(x)
        return x
if __name__ == "__main__":
    # Fix the seed so initialization and the train/val split are reproducible.
    torch.manual_seed(0)

    # model architecture w/ increased capacity
    model = KANModel([1, 50, 50, 1], num_knots=15, degree=3)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    loss_func = nn.MSELoss()

    x_input = torch.linspace(0, 1, 200).unsqueeze(1)
    target = torch.sin(2 * np.pi * x_input)  # Sinusoidal target function

    # train/val split (80/20). Shuffle before splitting: a contiguous split
    # would hold out x in [0.8, 1], where first-layer spline bases supported
    # only on that region never receive a training gradient, so the validation
    # loss would measure untrained coefficients rather than generalization.
    perm = torch.randperm(len(x_input))
    train_size = int(0.8 * len(x_input))
    train_idx, val_idx = perm[:train_size], perm[train_size:]
    x_train, x_val = x_input[train_idx], x_input[val_idx]
    y_train, y_val = target[train_idx], target[val_idx]

    losses = []
    val_losses = []

    EPOCHS = 100
    LOG_EVERY = 10

    for epoch in range(EPOCHS):
        # training
        model.train()
        optimizer.zero_grad()
        output = model(x_train)
        loss = loss_func(output, y_train)
        loss.backward()
        # gradient clipping to prevent exploding gradients; it must run after
        # backward() fills .grad and before step() consumes it
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        losses.append(loss.item())

        # validation
        model.eval()
        with torch.no_grad():
            val_output = model(x_val)
            val_loss = loss_func(val_output, y_val)
        val_losses.append(val_loss.item())

        if epoch == 0 or (epoch + 1) % LOG_EVERY == 0:
            print(
                f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item():.6f}, Val Loss: {val_loss.item():.6f}"
            )

    # Training and validation loss over time
    plt.figure()
    plt.plot(losses, label="Training Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title("Training and Validation Loss Over Time")
    plt.legend()

    # Target function vs. model output
    model.eval()
    with torch.no_grad():
        output = model(x_input)
    plt.figure()
    plt.plot(x_input.numpy(), target.numpy(), label="Target Function")
    plt.plot(x_input.numpy(), output.numpy(), label="Model Output")
    plt.xlabel("Input (x)")
    plt.ylabel("Output (y)")
    plt.title("KAN Model Approximation of sin(2πx)")
    plt.legend()
    plt.show()
