Skip to content

Instantly share code, notes, and snippets.

@apaszke
Last active February 28, 2023 14:28
Show Gist options
  • Star 22 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save apaszke/4c8ead6f17a781d589f6655692e7f6f0 to your computer and use it in GitHub Desktop.
Save apaszke/4c8ead6f17a781d589f6655692e7f6f0 to your computer and use it in GitHub Desktop.
import sys
from collections import OrderedDict
# True on a Python 2 interpreter.  There, ``type(m).forward`` is an unbound
# method, so ``_make_functional`` must unwrap it via ``__func__``.
PY2 = sys.version_info < (3,)
# ``nn.Module`` bookkeeping attributes that ``_make_functional`` must not copy
# verbatim from ``module.__dict__`` onto a ``Scope``.
_internal_attrs = set([
    '_backend',
    '_parameters',
    '_buffers',
    '_backward_hooks',
    '_forward_hooks',
    '_forward_pre_hooks',
    '_modules',
])
class Scope(object):
    """Plain attribute namespace that stands in for ``self`` in a functional forward.

    ``_make_functional`` fills a Scope with the wrapped module's ordinary
    attributes and functional child wrappers. It sets the parameter tensors
    on it right before calling the module class's ``forward`` with the Scope
    as ``self``.
    """

    def __init__(self):
        # Functional child wrappers keyed by child name, in registration order.
        self._modules = OrderedDict()

    def __iter__(self):
        # Container forwards (``nn.Sequential.forward`` in recent PyTorch
        # releases) iterate ``self`` instead of ``self._modules``.  Mirror
        # ``nn.Sequential.__iter__`` so such forwards also work on a Scope.
        return iter(self._modules.values())
def _make_functional(module, params_box, params_offset):
    """Recursively build a functional wrapper around ``module``.

    Args:
        module: the ``nn.Module`` to wrap.
        params_box: one-element list.  When a wrapper is called,
            ``params_box[0]`` must hold the flat, sliceable sequence of
            parameter tensors for the whole model.
        params_offset: index in ``params_box[0]`` of this module's first
            non-None parameter.

    Returns:
        ``(next_offset, fmodule)``:

        - ``next_offset`` is the index just past the parameters consumed by
          ``module`` and all of its descendants.
        - ``fmodule(*args, **kwargs)`` runs ``type(module).forward`` on a
          ``Scope``, with this module's parameters taken from
          ``params_box[0]``.

    NOTE(review): only ``__dict__`` attributes, buffers, children and
    parameters are placed on the scope.  Other methods defined on the
    module's class are not, so a ``forward`` that calls ``self.<helper>(...)``
    raises AttributeError.  Parameters shared between modules likely
    misalign the offsets, because ``parameters()`` de-duplicates shared
    tensors.  Verify before relying on either case.
    """
    scope = Scope()
    forward = type(module).forward.__func__ if PY2 else type(module).forward

    # Ordinary instance attributes (in nn.Module this includes ``training``)
    # are copied now, so the wrapper keeps the mode the module had at this
    # point.
    for name, attr in module.__dict__.items():
        if name not in _internal_attrs:
            setattr(scope, name, attr)

    # Buffers are stored in ``_buffers``, not ``__dict__``.  Expose them so
    # forwards that read them (e.g. running statistics) still find them.
    for name, buf in module._buffers.items():
        setattr(scope, name, buf)

    # ``Module.parameters()`` skips parameters registered as None (e.g. a
    # disabled bias).  Such parameters must not consume a slot in the flat
    # list, but the attribute must still exist on the scope.
    param_names = []
    for name, param in module._parameters.items():
        if param is None:
            setattr(scope, name, None)
        else:
            param_names.append(name)
    num_params = len(param_names)

    # This module's own parameters come first; children follow, depth-first.
    # That is the pre-order traversal ``Module.parameters()`` uses.
    child_params_offset = params_offset + num_params
    for name, child in module.named_children():
        child_params_offset, fchild = _make_functional(
            child, params_box, child_params_offset)
        scope._modules[name] = fchild
        setattr(scope, name, fchild)

    def fmodule(*args, **kwargs):
        own_params = params_box[0][params_offset:params_offset + num_params]
        for pname, param in zip(param_names, own_params):
            setattr(scope, pname, param)
        return forward(scope, *args, **kwargs)

    return child_params_offset, fmodule
def make_functional(module):
    """Return a stateless callable that runs ``module``'s forward pass.

    The returned ``fmodule(*args, params=..., **kwargs)`` requires a
    ``params`` keyword: an iterable of parameter tensors in the order
    produced by ``module.parameters()``.  All other arguments are forwarded
    to ``forward``.  The module's attributes, including its train/eval mode,
    are captured when ``make_functional`` is called, not when ``fmodule``
    runs.

    Raises (from the returned callable):
        KeyError: if ``params`` is not passed.
    """
    params_box = [None]
    _, fmodule_internal = _make_functional(module, params_box, 0)

    def fmodule(*args, **kwargs):
        # Materialize ``params`` because the per-module wrappers slice it by
        # offset.  This lets callers pass a generator such as
        # ``module.parameters()`` directly.
        params_box[0] = list(kwargs.pop('params'))
        return fmodule_internal(*args, **kwargs)

    return fmodule
################################################################################
import torch
from torch import nn
from torch.nn import functional as F
class Net(nn.Module):
    """Small convnet that maps single-channel images to 10-way log-probabilities.

    Two conv/pool/ReLU stages are followed by 2-D dropout and two linear
    layers.  The conv stack must produce 320 features per sample; 28x28
    inputs do (28 -> 24 -> 12 -> 8 -> 4, and 20 * 4 * 4 = 320).
    """

    def __init__(self):
        super(Net, self).__init__()
        # Construction order is kept fixed: it determines parameter
        # registration order and the RNG draws used for initialization.
        conv_stack = [
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Dropout2d(),
        ]
        self.layers = nn.Sequential(*conv_stack)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        features = self.layers(x).view(-1, 320)
        # Functional dropout follows the module's train/eval flag explicitly.
        hidden = F.dropout(F.relu(self.fc1(features)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
model = Net()

# make_functional copies the module's attributes (including its train/eval
# flag) when it is called, so build one wrapper per mode.
model.eval()
eval_fmodel = make_functional(model)
model.train()
train_fmodel = make_functional(model)

# Verify correctness in eval mode (because we have dropout).  When fed the
# module's own parameters, the eval-mode wrapper should reproduce the
# module's output.
model.eval()
params = list(model.parameters())
x = torch.randn(10, 1, 28, 28)
print(model(x).sum())
print(eval_fmodel(x, params=params).sum())
@Miaotxy
Copy link

Miaotxy commented Oct 30, 2019

@apaszke Thanks for your advice, it was really helpful.

@AdamCobb
Copy link

Thanks very much for the suggestions and quick responses!

@swathir01
Copy link

For the code I've attached at the end (as an image), I get the following error. Please help me get past it.
Error code

2021-03-05 13:38:45.311484: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Traceback (most recent call last):
File "train.py", line 109, in <module>
fire.Fire(train)
File "/usr/local/lib/python3.7/dist-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/usr/local/lib/python3.7/dist-packages/fire/core.py", line 471, in _Fire
target=component.__name__)
File "/usr/local/lib/python3.7/dist-packages/fire/core.py", line 681, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "train.py", line 103, in train
n_iter = model.train(dataloader, epochs, lr, wass_target, mse_weight, ttur)
File "/content/LAG-Pytorch/model.py", line 395, in train
fake = self.G(lores, eps=torch.zeros_like(eps))
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/content/LAG-Pytorch/model.py", line 173, in forward
x = self.conv1(torch.cat([x, eps], dim=1))
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/content/LAG-Pytorch/model.py", line 68, in forward
return self.activation(self.conv2d_forward(x, self.scale * self.weight))
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 779, in __getattr__
type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'ScaledConv2dWithAct' object has no attribute 'conv2d_forward'

Screenshot (2)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment