Skip to content

Instantly share code, notes, and snippets.

@MrRjxrby
Last active June 29, 2025 09:10
Show Gist options
  • Select an option

  • Save MrRjxrby/4765a656aeb889850aec05323765977c to your computer and use it in GitHub Desktop.

Select an option

Save MrRjxrby/4765a656aeb889850aec05323765977c to your computer and use it in GitHub Desktop.
KAN
import torch
import torch.nn as nn
import numpy as np
# Определяем инвариантные функции
class UniVariateFunction(nn.Module):
def __init__(self, output_size):
super(UniVariateFunction, self).__init__()
self.linear = nn.Linear(1, output_size)
def forward(self, x):
x = self.linear(x)
return torch.sin(x) # Используем синусоиду как функцию активации
# Определяем модель KAN
class KAN(nn.Module):
def __init__(self):
super(KAN, self).__init__()
self.phi = nn.ModuleList([UniVariateFunction(1) for _ in range(2)]) #Phi-функции для переменных x и y
self.Phi = nn.Linear(2, 1) # Phi-функция для комбинации вывода
def forward(self, x):
x1, x2 = x[:, 0], x[:, 1]
x1 = self.phi[0](x1.view(-1, 1))
x2 = self.phi[1](x2.view(-1, 1))
out = torch.cat((x1, x2), dim=1)
out = self.Phi(out)
return out
# Генерируем простой набор данных
x = torch.linspace(-np.pi, np.pi, 200)
y = torch.linspace(-np.pi, np.pi, 200)
X, Y = torch.meshgrid(x, y)
Z = torch.sin(X) + torch.cos(Y)
# Достаем "вход" модели
inputs = torch.stack([X.flatten(), Y.flatten()], dim=1)
model = KAN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Тренируем
for epoch in range(1000):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, Z.flatten())
loss.backward()
optimizer.step()
if epoch % 100 == 0:
print(f'Epoch {epoch}, Loss: {loss.item()}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment