# Source: GitHub gist by @shink (ff8e666f17dd6f7f115cae2fae8e075b), last active May 10, 2024.
if __name__ == '__main__':
    # Demo for pykan: fit a small Kolmogorov-Arnold Network (KAN) to
    # f(x, y) = exp(sin(pi * x) + y^2), prune it, replace the learned
    # activations with symbolic functions, and print the recovered formula.
    #
    # NOTE(review): this targets the pykan API of the version the script was
    # written against (`KAN.train`, `create_dataset` exported from `kan`);
    # newer pykan releases may rename these -- confirm against the installed version.
    import torch
    from kan import KAN, create_dataset

    # Create a KAN: 2D inputs, 1D output, and 5 hidden neurons;
    # cubic splines (k=3) on 5 grid intervals (grid=5).
    model = KAN(width=[2, 5, 1], grid=5, k=3, device='cpu', seed=0)

    def target_function(x):
        """Return exp(sin(pi * x0) + x1**2) for each row of the 2-D tensor *x*.

        Indexing with ``[:, [0]]`` / ``[:, [1]]`` keeps the column axis, so the
        result has one column per row of *x*.
        """
        return torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)

    dataset = create_dataset(target_function, n_var=2)
    print(dataset['train_input'].shape)
    print(dataset['train_label'].shape)

    # Plot the KAN at initialization. A forward pass is run first --
    # presumably plot() draws activations cached by that pass.
    model(dataset['train_input'])
    model.plot(beta=100)

    # Train with regularisation (lamb, lamb_entropy) so that pruning has
    # something to remove afterwards.
    model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.)
    model.plot()

    # The return value of this first prune() is discarded; it is presumably
    # called for its side effect so that plot(mask=True) can show which parts
    # would be pruned. The second call keeps the returned pruned model.
    model.prune()
    model.plot(mask=True)
    model = model.prune()
    model(dataset['train_input'])  # forward pass on the pruned model before plotting
    model.plot()
    model.train(dataset, opt="LBFGS", steps=50)

    mode = "auto"  # or "manual"
    if mode == "manual":
        # Manual mode: pin each activation to a chosen symbolic function.
        model.fix_symbolic(0, 0, 0, 'sin')
        model.fix_symbolic(0, 1, 0, 'x^2')
        model.fix_symbolic(1, 0, 0, 'exp')
    elif mode == "auto":
        # Automatic mode: let pykan choose each activation from this library.
        lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
        model.auto_symbolic(lib=lib)
    else:
        # Fail loudly instead of silently skipping the symbolic step.
        raise ValueError(f"unknown mode {mode!r}; expected 'manual' or 'auto'")

    # Train again after the symbolic substitution.
    # NOTE(review): the original author reported an error raised by this call;
    # its cause is not visible here -- reproduce against the pykan version in use.
    model.train(dataset, opt="LBFGS", steps=50)

    # A bare expression only displays in a notebook; print it in a script so
    # the recovered formula is actually shown.
    print(model.symbolic_formula()[0][0])
# Discussion of this script: see the comments on the GitHub gist page.