# Capsule
#
# GitHub gist MrRjxrby/e467320ae2d5a3078358b6f07595b931
# Author: @MrRjxrby, created June 26, 2025 16:26.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
class CapsuleLayer(nn.Module):
    """Capsule layer that is either convolutional or dynamically routed.

    The mode is selected by ``num_route_nodes``:

    * ``num_route_nodes == -1``: the layer holds ``num_capsules`` independent
      ``nn.Conv2d`` modules; their flattened outputs are concatenated along
      the last axis and squashed.
    * otherwise: the layer holds a ``route_weights`` parameter of shape
      ``(num_capsules, num_route_nodes, in_channels, out_channels)`` and
      performs routing-by-agreement over the ``num_route_nodes`` axis.

    Args:
        num_capsules: Number of capsules produced by the layer.
        num_route_nodes: Number of input capsules to route from, or ``-1``
            for the convolutional mode.
        in_channels: Input channels (conv mode) or input capsule size
            (routing mode).
        out_channels: Output channels (conv mode) or output capsule size
            (routing mode).
        kernel_size: Convolution kernel size (conv mode only).
        stride: Convolution stride (conv mode only).
        num_iterations: Number of routing iterations (routing mode only).

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1.
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels,
                 kernel_size=None, stride=None, num_iterations=3):
        super().__init__()
        if num_iterations < 1:
            raise ValueError("num_iterations must be at least 1")
        self.num_route_nodes = num_route_nodes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        if num_route_nodes != -1:
            self.route_weights = nn.Parameter(
                torch.randn(num_capsules, num_route_nodes, in_channels, out_channels)
            )
        else:
            self.capsules = nn.ModuleList([
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                          stride=stride, padding=0)
                for _ in range(num_capsules)
            ])

    def squash(self, tensor, dim=-1, eps=1e-8):
        """Scale vectors along ``dim`` so their length lies in ``[0, 1)``.

        Computes ``(|v|^2 / (1 + |v|^2)) * v / |v|``. ``eps`` is added under
        the square root so that an all-zero vector maps to zero instead of
        producing NaN from a 0/0 division.
        """
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / torch.sqrt(squared_norm + eps)

    def forward(self, x):
        """Run the layer.

        Routing mode expects ``x`` of shape
        ``(batch, num_route_nodes, in_channels)`` and returns
        ``(num_capsules, batch, out_channels)``.

        Conv mode passes ``x`` through every convolution and returns
        ``(batch, flattened_conv_output, num_capsules)``, squashed along the
        last axis.
        """
        if self.num_route_nodes != -1:
            # (1, B, N, 1, in) @ (C, 1, N, in, out) -> (C, B, N, 1, out)
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
            # Same shape, dtype and device as the prediction vectors.
            logits = torch.zeros_like(priors)
            last_iteration = self.num_iterations - 1
            for i in range(self.num_iterations):
                # Coupling coefficients are normalised over the route nodes.
                probs = F.softmax(logits, dim=2)
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))
                if i != last_iteration:
                    # Agreement between each prediction and the current output.
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
            # Drop only the two singleton routing axes; a bare squeeze() would
            # also remove the batch axis when the batch holds a single sample.
            return outputs.squeeze(3).squeeze(2)

        outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
        outputs = torch.cat(outputs, dim=-1)
        return self.squash(outputs)
class CapsuleNet(nn.Module):
    """Capsule network: conv layer, primary capsules, digit capsules, decoder.

    The layer sizes are fixed in ``__init__``: a 9x9 convolution with 256
    channels on single-channel input, 8 primary capsule convolutions
    (9x9, stride 2, 32 channels each), 10 digit capsules of size 16 routed
    from ``32 * 6 * 6`` nodes, and a fully connected decoder producing 784
    values through a final Sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(
            num_capsules=8, num_route_nodes=-1, in_channels=256,
            out_channels=32, kernel_size=9, stride=2)
        self.digit_capsules = CapsuleLayer(
            num_capsules=10, num_route_nodes=32 * 6 * 6, in_channels=8,
            out_channels=16)
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Classify ``x`` and reconstruct it from the selected digit capsule.

        Args:
            x: Image batch; the first axis is the batch axis.
            y: Optional one-hot mask of shape ``(batch, 10)`` choosing which
                digit capsule feeds the decoder. When ``None``, the capsule
                with the highest class score is used.

        Returns:
            ``(classes, reconstructions)`` where ``classes`` has shape
            ``(batch, 10)`` and ``reconstructions`` has shape ``(batch, 784)``.
        """
        batch_size = x.size(0)
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        # The digit layer output is capsule-major. Reshaping to an explicit
        # (capsules, batch, features) layout keeps the batch axis even for a
        # single-sample batch, where a bare squeeze() would drop it.
        x = self.digit_capsules(x).reshape(
            self.digit_capsules.num_capsules, batch_size, -1).transpose(0, 1)
        # Capsule length per class.
        classes = (x ** 2).sum(dim=-1) ** 0.5
        # NOTE(review): the lengths are passed through a softmax before being
        # returned, so downstream losses see normalised scores rather than raw
        # capsule lengths - confirm this is intended.
        classes = F.softmax(classes, dim=-1)
        if y is None:
            _, max_length_indices = classes.max(dim=1)
            y = torch.eye(x.size(1), device=x.device).index_select(
                dim=0, index=max_length_indices)
        # x is a transposed (non-contiguous) tensor, so flatten with reshape();
        # view() can raise when the layout does not allow a zero-copy view.
        masked = (x * y[:, :, None]).reshape(batch_size, -1)
        reconstructions = self.decoder(masked)
        return classes, reconstructions
def caps_loss(y_true, y_pred, x, x_recon, lam_recon=0.0005):
    """Return the margin loss plus a weighted reconstruction loss.

    Args:
        y_true: One-hot targets, same shape as ``y_pred``.
        y_pred: Predicted class scores.
        x: Original inputs; must hold as many elements as ``x_recon``.
        x_recon: Reconstructed inputs.
        lam_recon: Weight applied to the reconstruction term.

    Returns:
        A scalar tensor: the margin loss summed over every element of
        ``y_pred`` plus ``lam_recon`` times the mean squared reconstruction
        error.
    """
    # Present classes are penalised below 0.9, absent classes above 0.1
    # (the latter down-weighted by 0.5).
    present_term = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2
    absent_term = 0.5 * (1. - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    margin_loss = (present_term + absent_term).sum()
    # reshape() instead of view(): the input batch may be non-contiguous
    # (e.g. a transposed or sliced tensor), in which case view() raises.
    reconstruction_loss = F.mse_loss(x_recon, x.reshape(x_recon.size()))
    return margin_loss + lam_recon * reconstruction_loss
def train(model, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch and return the mean loss per batch.

    Args:
        model: Network called as ``model(data, y)`` and returning
            ``(classes, reconstructions)``.
        train_loader: Iterable of ``(data, target)`` batches that also
            supports ``len()`` and exposes ``.dataset``.
        optimizer: Optimizer stepping the model parameters.
        epoch: Epoch number, used only in the progress message.

    Returns:
        The summed batch losses divided by the number of batches, or ``0.0``
        when the loader yields no batches.
    """
    model.train()
    # Take the device from the model rather than from a module-level global,
    # so the function also works when this file is imported, not just run.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # One-hot lookup table, built once instead of once per batch.
    one_hot_table = torch.eye(10, device=device)
    num_batches = len(train_loader)
    train_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        y = one_hot_table.index_select(dim=0, index=target)
        classes, reconstructions = model(data, y)
        loss = caps_loss(y, classes, data, reconstructions)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                  f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    return train_loss / num_batches if num_batches else 0.0
def test(model, test_loader):
    """Evaluate ``model`` and return ``(mean batch loss, accuracy percent)``.

    Args:
        model: Network called as ``model(data, y)`` and returning
            ``(classes, reconstructions)``.
        test_loader: Iterable of ``(data, target)`` batches that also
            supports ``len()`` and exposes ``.dataset``.

    Returns:
        ``(test_loss, accuracy)``: the summed batch losses divided by the
        number of batches, and the percentage of correct predictions over
        ``len(test_loader.dataset)``. Both are ``0.0`` for an empty loader or
        dataset.
    """
    model.eval()
    # Take the device from the model rather than from a module-level global,
    # so the function also works when this file is imported, not just run.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # One-hot lookup table, built once instead of once per batch.
    one_hot_table = torch.eye(10, device=device)
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # The true label selects the capsule used for reconstruction.
            y = one_hot_table.index_select(dim=0, index=target)
            classes, reconstructions = model(data, y)
            test_loss += caps_loss(y, classes, data, reconstructions).item()
            pred = classes.max(1)[1]
            correct += pred.eq(target).sum().item()
    num_batches = len(test_loader)
    num_samples = len(test_loader.dataset)
    test_loss = test_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / num_samples if num_samples else 0.0
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}'
          f' ({accuracy:.1f}%)\n')
    return test_loss, accuracy
if __name__ == "__main__":
    # Script entry point: train a CapsuleNet on MNIST, evaluate after every
    # epoch, then save the final weights to 'capsnet_model.pth'.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Training hyperparameters
    batch_size = 128
    epochs = 30
    learning_rate = 0.001

    # MNIST data loading; both splits share one preprocessing pipeline.
    # NOTE(review): inputs are standardised with Normalize, while the model's
    # decoder ends in a Sigmoid and the loss compares its output against these
    # inputs, so the reconstruction error cannot reach zero - confirm intended.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=mnist_transform),
        batch_size=batch_size, shuffle=True)
    # Evaluation does not depend on sample order, so the test split is not
    # shuffled.
    test_loader = DataLoader(
        datasets.MNIST('../data', train=False, transform=mnist_transform),
        batch_size=batch_size, shuffle=False)

    # Model initialisation
    model = CapsuleNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(1, epochs + 1):
        train_loss = train(model, train_loader, optimizer, epoch)
        test_loss, test_acc = test(model, test_loader)

    # Save the trained weights
    torch.save(model.state_dict(), 'capsnet_model.pth')
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.