@biphasic · Created October 18, 2022 15:45
import sinabs
import sinabs.layers as sl
import torch
import torch.nn as nn
ann = nn.Sequential(
    nn.Conv2d(1, 16, 5, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(16, 32, 5, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(32, 120, 4, bias=False),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(120, 10, bias=False),
)
# Create our SNN
num_timesteps = 100
snn = sinabs.from_torch.from_model(ann, num_timesteps=num_timesteps).spiking_model
# Create the forward hook
outputs = []
def save_outputs(module: nn.Module, input: tuple, output: torch.Tensor):
    outputs.append(output)
# Attach the forward hooks
handles = []
for module in snn.modules():
    if isinstance(module, sl.StatefulLayer):
        handle = module.register_forward_hook(save_outputs)
        handles.append(handle)
# Feed random input
batch_size, channels, height, width = 4, 1, 28, 28
rand_input = torch.rand((batch_size*num_timesteps, channels, height, width))
snn(rand_input)
# outputs is now populated with the intermediate activations, one tensor per spiking layer
len(outputs)
# optionally remove the hooks to prevent excessive memory consumption
for handle in handles:
    handle.remove()
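# The following lines are an optional add-on, not part of the original gist: a
# quick sanity check that prints the shape of each captured activation. It
# assumes one hook output per spiking layer and that, as in the forward pass
# above, the time dimension is folded into the batch dimension.
for layer_index, activation in enumerate(outputs):
    print(f"spiking layer {layer_index}: activation shape {tuple(activation.shape)}")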