Implementation of Kolmogorov-Arnold Network (KAN) with Dynamic Spline Activation Functions in PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
class BSplineActivation(nn.Module):
    """
    B-spline activation function for the KAN model.

    This class implements a learnable B-spline activation using the Cox-de Boor
    recursion formula. The B-spline is defined by its degree and a knot vector,
    and is parameterized by coefficients that are learned during training.
    """

    def __init__(self, num_knots=15, degree=3):
        """
        Parameters:
            num_knots (int): The number of knots in the B-spline basis.
            degree (int): The degree of the B-spline.
        """
        super(BSplineActivation, self).__init__()
        self.degree = degree
        self.num_knots = num_knots
        # Initialize to small random values for better convergence.
        self.coefficients = nn.Parameter(torch.randn(num_knots + degree - 1) * 0.01)
        self.register_buffer("knots", self.create_knot_vector())
    def create_knot_vector(self):
        """
        Create the knot vector for the B-spline basis functions.

        The knot vector is "clamped": 'degree' repeated knots are added at the
        start and end so that the B-spline curve starts and ends at the first
        and last control points.
        """
        knots = [0.0] * self.degree
        knots += torch.linspace(0, 1, self.num_knots).tolist()
        knots += [1.0] * self.degree
        return torch.tensor(knots, dtype=torch.float32)
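
    # For intuition (illustrative values, not part of the original gist): with
    # num_knots=5 and degree=2, the clamped knot vector above comes out as
    #   [0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1]
    # i.e. five evenly spaced knots from linspace plus 'degree' repeats at each end.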
    def bspline_basis(self, x, degree, knots, i):
        """
        Recursively compute the i-th B-spline basis function of the given degree.

        Uses the Cox-de Boor recursion:
            B_{i,0}(x) = 1 if knots[i] <= x < knots[i+1], else 0
            B_{i,p}(x) = (x - knots[i]) / (knots[i+p] - knots[i]) * B_{i,p-1}(x)
                       + (knots[i+p+1] - x) / (knots[i+p+1] - knots[i+1]) * B_{i+1,p-1}(x)
        with the convention that any term with a zero denominator is dropped.

        Parameters:
            x (torch.Tensor): The input tensor at which to evaluate the basis function.
            degree (int): The degree of the B-spline.
            knots (torch.Tensor): A 1D tensor representing the knot vector.
            i (int): The index of the basis function to compute.

        Returns:
            torch.Tensor: The B-spline basis function values at the input points.
        """
        if degree == 0:
            # Zero-degree B-spline: piecewise-constant indicator of one knot span.
            return ((x >= knots[i]) & (x < knots[i + 1])).float()
        else:
            denom1 = knots[i + degree] - knots[i]
            denom2 = knots[i + degree + 1] - knots[i + 1]
            term1 = 0
            if denom1 != 0:
                term1 = (
                    (x - knots[i])
                    / denom1
                    * self.bspline_basis(x, degree - 1, knots, i)
                )
            term2 = 0
            if denom2 != 0:
                term2 = (
                    (knots[i + degree + 1] - x)
                    / denom2
                    * self.bspline_basis(x, degree - 1, knots, i + 1)
                )
            return term1 + term2
    def forward(self, x):
        """
        Clamps the input tensor to the spline's domain and computes the weighted
        sum of B-spline basis functions, each weighted by its learned coefficient.
        """
        # Clamp to just below 1: the half-open intervals in the degree-0 base
        # case would otherwise make every basis function vanish at x == 1.
        x = torch.clamp(x, 0.0, 1.0 - 1e-6)
        y = torch.zeros_like(x)
        n = len(self.coefficients)
        for i in range(n):
            basis = self.bspline_basis(x, self.degree, self.knots, i)
            y += self.coefficients[i] * basis
        return y
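
# Sanity-check sketch (not part of the original gist): clamped B-spline bases
# form a partition of unity on [0, 1), so the basis values at any interior
# point should sum to ~1. Uncomment to verify the recursion above:
#
#   act = BSplineActivation(num_knots=15, degree=3)
#   xs = torch.linspace(0.0, 0.99, 50)
#   total = sum(
#       act.bspline_basis(xs, act.degree, act.knots, i)
#       for i in range(len(act.coefficients))
#   )
#   assert torch.allclose(total, torch.ones_like(xs), atol=1e-5)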
class KANLayer(nn.Module):
    """
    KAN layer with a B-spline activation function on each edge.

    This class represents a fully connected layer where each connection (edge)
    between an input neuron and an output neuron has its own learnable
    B-spline activation function.
    """

    def __init__(self, input_dim, output_dim, num_knots=15, degree=3):
        """
        Parameters:
            input_dim (int): The number of input neurons.
            output_dim (int): The number of output neurons.
            num_knots (int): The number of knots in each edge's B-spline basis.
            degree (int): The degree of each edge's B-spline.
        """
        super(KANLayer, self).__init__()
        self.activations = nn.ModuleList(
            [
                BSplineActivation(num_knots, degree)
                for _ in range(input_dim * output_dim)
            ]
        )
        self.input_dim = input_dim
        self.output_dim = output_dim
    def forward(self, x):
        """
        Applies the per-edge B-spline activations and sums the contributions
        of all incoming edges at each output neuron.
        """
        batch_size = x.size(0)
        result = torch.zeros(batch_size, self.output_dim, device=x.device)
        for i in range(self.input_dim):
            for j in range(self.output_dim):
                idx = i * self.output_dim + j
                # Apply the activation function for the edge (i -> j).
                result[:, j] += self.activations[idx](x[:, i].unsqueeze(1)).squeeze(1)
        return result
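
# Shape sketch (illustrative only): a KANLayer maps (batch, input_dim) to
# (batch, output_dim) via input_dim * output_dim independent edge splines.
#
#   layer = KANLayer(input_dim=2, output_dim=3)
#   out = layer(torch.rand(4, 2))  # -> shape (4, 3), using 6 edge splines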
class KANModel(nn.Module):
    """
    Full KAN model composed of multiple KAN layers.

    The model is a sequence of KAN layers, each of which uses B-spline
    activation functions on its edges.
    """

    def __init__(self, layer_dims, num_knots=15, degree=3):
        """
        Parameters:
            layer_dims (list): A list of integers giving the number of neurons
                in each layer of the model.
            num_knots (int): The number of knots in the B-spline basis.
            degree (int): The degree of the B-spline.
        """
        super(KANModel, self).__init__()
        layers = []
        for i in range(len(layer_dims) - 1):
            layers.append(KANLayer(layer_dims[i], layer_dims[i + 1], num_knots, degree))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """
        Applies each KAN layer in sequence to produce the output.
        """
        for layer in self.layers:
            x = layer(x)
        return x
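
# Capacity sketch (numbers derived from the defaults above, for illustration):
# each edge owns num_knots + degree - 1 = 17 coefficients, so KANModel([1, 50,
# 50, 1]) has 1*50 + 50*50 + 50*1 = 2600 edges and 2600 * 17 = 44,200
# learnable parameters:
#
#   model = KANModel([1, 50, 50, 1], num_knots=15, degree=3)
#   n_params = sum(p.numel() for p in model.parameters())  # 44200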
if __name__ == "__main__":
    torch.manual_seed(0)

    # Model architecture with increased capacity.
    model = KANModel([1, 50, 50, 1], num_knots=15, degree=3)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    loss_func = nn.MSELoss()

    x_input = torch.linspace(0, 1, 200).unsqueeze(1)
    target = torch.sin(2 * np.pi * x_input)  # Sinusoidal target function

    # Train/val split (80/20). Note the split is contiguous over sorted inputs,
    # so the validation set measures extrapolation beyond x = 0.8 rather than
    # interpolation on held-out interior points.
    train_size = int(0.8 * len(x_input))
    x_train, x_val = x_input[:train_size], x_input[train_size:]
    y_train, y_val = target[:train_size], target[train_size:]

    losses = []
    val_losses = []
    EPOCHS = 100
    for epoch in range(EPOCHS):
        # Training (full batch).
        model.train()
        optimizer.zero_grad()
        output = model(x_train)
        loss = loss_func(output, y_train)
        loss.backward()
        # Gradient clipping to prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        losses.append(loss.item())

        # Validation.
        model.eval()
        with torch.no_grad():
            val_output = model(x_val)
            val_loss = loss_func(val_output, y_val)
            val_losses.append(val_loss.item())

        print(
            f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item():.6f}, Val Loss: {val_loss.item():.6f}"
        )
    # Training and validation loss over time.
    plt.figure()
    plt.plot(losses, label="Training Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training and Validation Loss Over Time")
    plt.show()

    # Target function vs. model output.
    with torch.no_grad():
        output = model(x_input)
    plt.figure()
    plt.plot(x_input.numpy(), target.numpy(), label="Target Function")
    plt.plot(x_input.numpy(), output.numpy(), label="Model Output")
    plt.xlabel("Input (x)")
    plt.ylabel("Output (y)")
    plt.legend()
    plt.title("KAN Model Approximation of sin(2πx)")
    plt.show()