Skip to content

Instantly share code, notes, and snippets.

@InnovArul
Created November 18, 2021 21:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save InnovArul/7a3133925df1091aee462148204bb9c9 to your computer and use it in GitHub Desktop.
Save InnovArul/7a3133925df1091aee462148204bb9c9 to your computer and use it in GitHub Desktop.
import torch, torch.nn as nn
class LowlevelModule(nn.Module):
    """Submodule that keeps a reference to a tensor supplied by its caller.

    The tensor is stored by reference (never copied), so in-place edits made
    through any other reference to the same tensor object are visible here.
    """

    def __init__(self, custom_val: torch.Tensor) -> None:
        """Store ``custom_val`` by reference.

        Args:
            custom_val: tensor whose value this module reports. ``TopModel``
                passes a 1-element float tensor that it shares with siblings.
        """
        super().__init__()
        # NOTE: this is a plain tensor attribute, not a parameter or a
        # registered buffer, so nn.Module does not track it: it is absent
        # from state_dict() and is not converted by .to()/.cuda().
        self.custom_val = custom_val

    def print_custom_val(self) -> None:
        """Print the current value of ``custom_val``.

        A single-element tensor prints as a Python scalar (via ``.item()``).
        A tensor with any other element count prints as a (nested) list (via
        ``.tolist()``) instead of raising, as ``.item()`` would.
        """
        val = self.custom_val
        # .item() is only defined for exactly one element; fall back otherwise.
        print(val.item() if val.numel() == 1 else val.tolist())
class TopModel(nn.Module):
    """Parent module whose two children share one tensor object.

    ``custom_val`` is created here, and that very same tensor is handed to both
    ``m1`` and ``m2``. Mutating it in place (e.g. ``model.custom_val[0] = 10.``)
    is therefore observed by both children.
    """

    def __init__(self):
        super().__init__()
        shared = torch.tensor([5.])
        self.custom_val = shared
        # Both children receive the same tensor object, not copies.
        self.m1 = LowlevelModule(shared)
        self.m2 = LowlevelModule(shared)

    def print_lowlevelvals(self):
        """Have each child print its value, ``m1`` first and then ``m2``."""
        for child in (self.m1, self.m2):
            child.print_custom_val()

    def forward(self):
        """No-op forward pass; always returns ``None``."""
        return None
if __name__ == '__main__':
    model = TopModel()
    # Both children print the initial shared value.
    model.print_lowlevelvals()
    # Overwrite the shared tensor in place. m1 and m2 hold the same object,
    # so the next print shows the new value through both of them.
    model.custom_val[0] = 10.
    model.print_lowlevelvals()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment