import torch
from torch import nn


class ModuleWithLazySubmodules(nn.Module):
    def __init__(self, in_dim, middle_dim, submodules):
        super().__init__()
        self.first_layer = nn.Linear(in_dim, middle_dim)
        self.submodules = nn.Sequential(*submodules)
        # Run a dummy forward pass so that any lazy submodules
        # (e.g. nn.LazyLinear) infer their input shapes and
        # materialize their parameters right away.
        self.forward(torch.rand(1, in_dim))

    def forward(self, input):
        out = self.first_layer(input)
        return self.submodules(out)


# Mix lazy and non-lazy initialized modules
mod = ModuleWithLazySubmodules(
    5, 10,
    [nn.LazyLinear(20), nn.LazyLinear(30), nn.LazyLinear(40)],
)
print(mod)  # Initialized automatically, as long as the top-level module is non-lazy
"""
ModuleWithLazySubmodules(
(first_layer): Linear(in_features=5, out_features=10, bias=True)
(submodules): Sequential(
(0): Linear(in_features=10, out_features=20, bias=True)
(1): Linear(in_features=20, out_features=30, bias=True)
(2): Linear(in_features=30, out_features=40, bias=True)
)
)
"""