Created
June 26, 2025 16:26
-
-
Save MrRjxrby/e467320ae2d5a3078358b6f07595b931 to your computer and use it in GitHub Desktop.
Capsule
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import datasets, transforms | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
class CapsuleLayer(nn.Module):
    """A capsule layer: convolutional primary capsules or routed (fully-connected) capsules.

    Args:
        num_capsules: number of capsules produced by this layer. In primary mode this is
            the number of ``Conv2d`` modules (one per capsule-vector component); in
            routing mode it is the number of output capsules.
        num_route_nodes: number of input capsules to route from, or ``-1`` to build a
            primary-capsule layer made of ``num_capsules`` convolutions.
        in_channels: routing mode - dimension of each input capsule vector;
            primary mode - number of channels of the input feature map.
        out_channels: routing mode - dimension of each output capsule vector;
            primary mode - number of output channels of each convolution.
        kernel_size, stride: convolution parameters (primary mode only).
        num_iterations: number of dynamic-routing iterations (routing mode only).

    Raises:
        ValueError: if ``num_iterations`` is smaller than 1.
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels,
                 kernel_size=None, stride=None, num_iterations=3):
        super(CapsuleLayer, self).__init__()
        if num_iterations < 1:
            raise ValueError(f'num_iterations must be >= 1, got {num_iterations}')
        self.num_route_nodes = num_route_nodes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        if num_route_nodes != -1:
            # One (in_channels x out_channels) transformation matrix per
            # (output capsule, input capsule) pair.
            self.route_weights = nn.Parameter(
                torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            self.capsules = nn.ModuleList([
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0)
                for _ in range(num_capsules)
            ])

    def squash(self, tensor, dim=-1, eps=1e-8):
        """Rescale vectors along ``dim`` to length |s|^2 / (1 + |s|^2), keeping their direction.

        ``eps`` keeps the division finite for all-zero vectors, which would otherwise
        produce 0/0 = NaN and poison every gradient flowing through them.
        """
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / torch.sqrt(squared_norm + eps)

    def forward(self, x):
        """Compute capsule outputs.

        Routing mode: ``x`` is assumed to be (batch, num_route_nodes, in_channels) - this
        is what the indexing below expects. Returns (num_capsules, batch, out_channels).
        Primary mode: ``x`` is a 4-D conv input. Returns (batch, N, num_capsules), where
        N = out_channels * H_out * W_out.
        """
        if self.num_route_nodes != -1:
            # priors[j, b, i, 0, :] = u_i @ W_ij, i.e. the prediction of input capsule i
            # for output capsule j: (num_capsules, batch, num_route_nodes, 1, out_channels).
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
            # Routing logits b_ij. Each is a scalar, so the last axis has size 1.
            logits = torch.zeros_like(priors[..., :1])
            for i in range(self.num_iterations):  # dynamic routing by agreement
                # Coupling coefficients: softmax over the output capsules (dim 0), so each
                # input capsule splits its vote among the parents (Sabour et al., Procedure 1).
                probs = F.softmax(logits, dim=0)
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))
                if i != self.num_iterations - 1:
                    # Agreement = dot product between each prediction and the parent output.
                    logits = logits + (priors * outputs).sum(dim=-1, keepdim=True)
            # Drop only the two singleton routing axes, so a batch of size 1 keeps its
            # batch dimension: result is (num_capsules, batch, out_channels).
            return outputs.squeeze(3).squeeze(2)
        # Primary capsules: each convolution supplies one component of every capsule vector.
        outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
        outputs = torch.cat(outputs, dim=-1)
        return self.squash(outputs)
class CapsuleNet(nn.Module):
    """CapsNet for 28x28 single-channel images with 10 classes.

    Pipeline: conv1 (256 ch, 9x9) -> primary capsules (32*6*6 capsules of dim 8)
    -> digit capsules (10 capsules of dim 16, dynamic routing) -> masked decoder that
    reconstructs a 784-pixel image with values in (0, 1) (final Sigmoid).
    """

    def __init__(self):
        super(CapsuleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32, kernel_size=9, stride=2)
        self.digit_capsules = CapsuleLayer(num_capsules=10, num_route_nodes=32*6*6, in_channels=8, out_channels=16)
        self.decoder = nn.Sequential(
            nn.Linear(16*10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Classify ``x`` and reconstruct it from the selected digit capsule.

        Args:
            x: image batch accepted by ``conv1`` (1 input channel).
            y: optional (batch, 10) one-hot mask selecting which capsule feeds the
                decoder. When None, the longest (predicted) capsule is used.

        Returns:
            ``(classes, reconstructions)``: (batch, 10) capsule lengths in [0, 1),
            and (batch, 784) decoder outputs.
        """
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        x = self.digit_capsules(x)
        # Digit capsules are (10, batch, 16). Use reshape rather than squeeze so that a
        # batch of size 1 keeps its batch axis, then move the batch axis to the front.
        x = x.reshape(10, -1, 16).transpose(0, 1)
        # Class scores are the capsule lengths (already < 1 thanks to squash).
        # No softmax: the margin loss thresholds lengths at 0.9 / 0.1, and a softmax
        # over ten values in [0, 1) can never exceed ~0.23.
        classes = (x ** 2).sum(dim=-1) ** 0.5
        if y is None:
            # No labels given: mask with the predicted (longest) capsule.
            max_length_indices = classes.argmax(dim=1)
            y = torch.eye(10, device=x.device).index_select(dim=0, index=max_length_indices)
        reconstructions = self.decoder((x * y[:, :, None]).reshape(x.size(0), -1))
        return classes, reconstructions
def caps_loss(y_true, y_pred, x, x_recon, lam_recon=0.0005):
    """Capsule margin loss plus down-weighted reconstruction loss, averaged per sample.

    Args:
        y_true: (batch, num_classes) one-hot labels.
        y_pred: (batch, num_classes) capsule lengths.
        x: original inputs; reshaped to ``x_recon``'s shape for comparison.
        x_recon: (batch, ...) decoder reconstructions.
        lam_recon: weight of the reconstruction term.

    Returns:
        Scalar tensor: (sum of margin losses + lam_recon * sum of squared
        reconstruction errors) divided by the batch size.
    """
    margin_loss = (y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2
                   + 0.5 * (1. - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2)
    margin_loss = margin_loss.sum()
    # Sum of squared errors, matching the margin term's summation. A per-pixel mean
    # scaled by lam_recon=0.0005 would make the reconstruction term vanish.
    reconstruction_loss = F.mse_loss(x_recon, x.reshape(x_recon.size()), reduction='sum')
    batch_size = max(x_recon.size(0), 1)
    return (margin_loss + lam_recon * reconstruction_loss) / batch_size
def train(model, train_loader, optimizer, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: the capsule network; batches are moved to the device of its parameters.
        train_loader: iterable of ``(data, target)`` batches with integer class targets
            in [0, 10).
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in progress messages.

    Returns:
        Average of ``caps_loss`` over the epoch's batches (0.0 for an empty loader).
    """
    model.train()
    # Use the model's own device instead of the module-level ``device`` that only
    # exists when this file is executed as a script.
    device = next(model.parameters()).device
    train_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # One-hot labels: margin-loss targets and decoder mask (true capsule only).
        y = F.one_hot(target, num_classes=10).float()
        classes, reconstructions = model(data, y)
        loss = caps_loss(y, classes, data, reconstructions)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                  f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {batch_loss:.6f}')
    return train_loss / max(len(train_loader), 1)
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader``; print and return the results.

    The decoder is masked with the true labels (for the reconstruction term of the
    loss), while accuracy uses only the predicted capsule lengths.

    Returns:
        ``(mean per-batch loss, accuracy in percent)``.
    """
    model.eval()
    # Use the model's own device instead of the script-only global ``device``.
    device = next(model.parameters()).device
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            y = F.one_hot(target, num_classes=10).float()
            classes, reconstructions = model(data, y)
            test_loss += caps_loss(y, classes, data, reconstructions).item()
            pred = classes.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    test_loss /= max(len(test_loader), 1)
    accuracy = 100. * correct / max(len(test_loader.dataset), 1)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}'
          f' ({accuracy:.1f}%)\n')
    return test_loss, accuracy
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Training hyperparameters
    batch_size = 128
    epochs = 30
    learning_rate = 0.001
    # Load MNIST. ToTensor already scales pixels to [0, 1]. No mean/std Normalize:
    # the decoder ends in a Sigmoid, so the reconstruction target (the input image)
    # must also lie in [0, 1].
    mnist_transform = transforms.ToTensor()
    train_loader = DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
        batch_size=batch_size, shuffle=True)
    # Evaluation does not need shuffling.
    test_loader = DataLoader(
        datasets.MNIST('../data', train=False, download=True, transform=mnist_transform),
        batch_size=batch_size, shuffle=False)
    # Model and optimizer
    model = CapsuleNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Training loop
    for epoch in range(1, epochs + 1):
        train_loss = train(model, train_loader, optimizer, epoch)
        test_loss, test_acc = test(model, test_loader)
        print(f'Epoch {epoch}: train loss {train_loss:.4f}, '
              f'test loss {test_loss:.4f}, test accuracy {test_acc:.2f}%')
    # Save the trained weights
    torch.save(model.state_dict(), 'capsnet_model.pth')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment