Skip to content

Instantly share code, notes, and snippets.

@apaszke
Last active February 28, 2023 14:28
  • Star 22 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save apaszke/4c8ead6f17a781d589f6655692e7f6f0 to your computer and use it in GitHub Desktop.
import sys
from collections import OrderedDict
PY2 = sys.version_info[0] == 2
# nn.Module bookkeeping attributes that are not copied onto a scope.
# Parameters and buffers are installed explicitly by _make_functional; the
# hook dictionaries are skipped on purpose, so hooks never run on the
# functional path (fmodule calls forward directly).
_internal_attrs = {'_backend', '_parameters', '_buffers', '_backward_hooks', '_forward_hooks', '_forward_pre_hooks', '_modules'}
# Container protocol methods. Python looks these up on the *type*, never via
# __getattr__, so they are copied onto the per-module-class Scope subclasses.
_protocol_methods = ('__iter__', '__len__', '__getitem__', '__contains__')
# Cache: module class -> Scope subclass built by _scope_class_for.
_scope_classes = {}


def _lookup_class_attr(cls, name):
    """Return the raw class-dict entry for ``name`` found along ``cls``'s MRO.

    The class dicts are read directly (instead of ``getattr(cls, name)``) so
    that functions come back as plain functions on both Python 2 and 3.

    Raises:
        KeyError: if no class in the MRO defines ``name``.
    """
    for klass in cls.__mro__:
        namespace = vars(klass)
        if name in namespace:
            return namespace[name]
    raise KeyError(name)


class Scope(object):
    """Stand-in for ``self`` when a module's ``forward`` is called functionally.

    Instance attributes hold copies of the module's plain attributes, its
    buffers and (set on every call) its parameters. Any other attribute that
    is not found falls back to the class stored in ``_module_cls``. Helper
    methods and properties used by ``forward`` (e.g. ``conv2d_forward`` on
    ``nn.Conv2d``) therefore resolve, bound to the scope rather than to the
    original module.
    """

    # Set by _scope_class_for on per-module-class subclasses; None on a bare Scope.
    _module_cls = None

    def __init__(self):
        self._modules = OrderedDict()

    def __getattr__(self, name):
        # Only reached after normal attribute lookup has failed. Dunder names
        # are never forwarded so that copy/pickle/hasattr probes behave normally.
        module_cls = type(self)._module_cls
        if module_cls is not None and not (name.startswith('__') and name.endswith('__')):
            try:
                attr = _lookup_class_attr(module_cls, name)
            except KeyError:
                pass
            else:
                # Functions become methods bound to this scope; properties are
                # evaluated against it; plain class values are returned as-is.
                if hasattr(attr, '__get__'):
                    return attr.__get__(self, module_cls)
                return attr
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))


def _scope_class_for(module_cls):
    """Return (creating and caching on first use) the Scope subclass for ``module_cls``."""
    scope_cls = _scope_classes.get(module_cls)
    if scope_cls is None:
        namespace = {'_module_cls': module_cls}
        for name in _protocol_methods:
            try:
                namespace[name] = _lookup_class_attr(module_cls, name)
            except KeyError:
                continue
        scope_cls = type('Functional' + module_cls.__name__, (Scope,), namespace)
        _scope_classes[module_cls] = scope_cls
    return scope_cls


def _make_functional(module, params_box, params_offset):
    """Build a functional wrapper for ``module`` and, recursively, its children.

    Args:
        module: the module to wrap. Its plain attributes (including the
            train/eval flag) are copied now, so later changes to ``module``
            are not seen by the wrapper.
        params_box: one-element list whose item is the flat sequence of
            parameter tensors; ``make_functional`` replaces it before each call.
        params_offset: index in that sequence of this module's first parameter.

    Returns:
        ``(next_offset, fmodule)``. ``next_offset`` is the index just past the
        last parameter used by ``module``'s subtree. ``fmodule(*args, **kwargs)``
        installs this module's parameters on its scope and runs
        ``type(module).forward`` on that scope.

    NOTE(review): parameters or submodules shared between several places in
    the tree are counted once per occurrence here; confirm this matches the
    ``params`` ordering you pass when weights are tied.
    """
    scope = _scope_class_for(type(module))()
    module_name = type(module).__name__
    # module.parameters() skips parameters registered as None (e.g. bias=False),
    # so they must not occupy a slot in the flat parameter list either.
    param_names = [name for name, param in module._parameters.items() if param is not None]
    num_params = len(param_names)
    forward = type(module).forward.__func__ if PY2 else type(module).forward
    for name, attr in module.__dict__.items():
        if name in _internal_attrs:
            continue
        setattr(scope, name, attr)
    # Expose None-valued parameters as None so code like `if self.bias is None` works.
    for name, param in module._parameters.items():
        if param is None:
            setattr(scope, name, None)
    # Buffers are the very same tensor objects as in the original module, so
    # in-place updates made by forward are visible on the original module.
    for name, buf in module._buffers.items():
        setattr(scope, name, buf)
    child_params_offset = params_offset + num_params
    for name, child in module.named_children():
        child_params_offset, fchild = _make_functional(child, params_box, child_params_offset)
        scope._modules[name] = fchild
        setattr(scope, name, fchild)

    def fmodule(*args, **kwargs):
        params = params_box[0][params_offset:params_offset + num_params]
        # zip() would silently drop missing entries and leave stale tensors behind.
        if len(params) != num_params:
            raise ValueError('expected {} parameter(s) for {} starting at index {}, got {}'.format(
                num_params, module_name, params_offset, len(params)))
        for name, param in zip(param_names, params):
            setattr(scope, name, param)
        return forward(scope, *args, **kwargs)

    return child_params_offset, fmodule


def make_functional(module):
    """Return a function that evaluates ``module`` with externally supplied parameters.

    The returned callable accepts the same arguments as ``module.forward`` plus a
    required ``params`` keyword: an iterable of tensors in the same order as
    ``module.parameters()``. The module's train/eval flag is captured when
    ``make_functional`` is called.

    Raises (from the returned callable):
        KeyError: if ``params`` is not passed.
        ValueError: if ``params`` does not contain exactly one tensor per parameter.
    """
    params_box = [None]
    num_params, fmodule_internal = _make_functional(module, params_box, 0)

    def fmodule(*args, **kwargs):
        # tuple() accepts any iterable, e.g. a module.parameters() generator.
        params = tuple(kwargs.pop('params'))
        if len(params) != num_params:
            raise ValueError('expected {} parameter(s), got {}'.format(num_params, len(params)))
        params_box[0] = params
        return fmodule_internal(*args, **kwargs)

    return fmodule
################################################################################
import torch
from torch import nn
from torch.nn import functional as F
class Net(nn.Module):
    """Small CNN classifier used to exercise ``make_functional``.

    Sized for 1x28x28 inputs: the convolutional stack reduces each sample to
    20 channels of 4x4, i.e. 320 features, before the two linear layers.
    """

    def __init__(self):
        super(Net, self).__init__()
        conv_stack = [
            nn.Conv2d(1, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
            nn.Dropout2d(),
        ]
        self.layers = nn.Sequential(*conv_stack)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # Returns per-class log-probabilities, shape (batch, 10).
        features = self.layers(x).view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(features)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
# make_functional snapshots the module's attributes (including its train/eval
# flag) when it is called, so build one functional wrapper per mode.
model = Net()
model.eval()
eval_fmodel = make_functional(model)
model.train()
train_fmodel = make_functional(model)
# Verify correctness in eval mode (because we have dropout): with dropout
# disabled, both code paths are deterministic and must agree.
model.eval()
params = list(model.parameters())
x = torch.randn(10, 1, 28, 28)
expected = model(x)
# eval_fmodel was built while the model was in eval mode, so it must match model(x).
actual = eval_fmodel(x, params=params)
print(expected.sum())
print(actual.sum())
print('outputs match:', torch.allclose(expected, actual))
@Miaotxy
Copy link

Miaotxy commented Sep 11, 2019

Error message: 'Scope' object has no attribute 'conv2d_forward'

It is raised when executing the forward(self, ...) function of Conv2d.

My guess is that forward(self, ...) calls self.conv2d_forward.

@AdamCobb
Copy link

I get the same error as above as well.

Is there a recent change to PyTorch that results in this and do you know how to fix it?

See error below:


AttributeError Traceback (most recent call last)
in
85 x = torch.randn(10, 1, 28, 28)
86 print(model(x).sum())
---> 87 print(train_fmodel(x, params=params).sum())

in fmodule(*args, **kwargs)
41 def fmodule(*args, **kwargs):
42 params_box[0] = kwargs.pop('params')
---> 43 return fmodule_internal(*args, **kwargs)
44
45 return fmodule

in fmodule(*args, **kwargs)
30 for name, param in zip(param_names, params_box[0][params_offset:params_offset + num_params]):
31 setattr(self, name, param)
---> 32 return forward(self, *args, **kwargs)
33
34 return child_params_offset, fmodule

in forward(self, x)
66
67 def forward(self, x):
---> 68 x = self.layers(x)
69 x = x.view(-1, 320)
70 x = F.relu(self.fc1(x))

in fmodule(*args, **kwargs)
30 for name, param in zip(param_names, params_box[0][params_offset:params_offset + num_params]):
31 setattr(self, name, param)
---> 32 return forward(self, *args, **kwargs)
33
34 return child_params_offset, fmodule

~/miniconda3/envs/hips/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
90 def forward(self, input):
91 for module in self._modules.values():
---> 92 input = module(input)
93 return input
94

in fmodule(*args, **kwargs)
30 for name, param in zip(param_names, params_box[0][params_offset:params_offset + num_params]):
31 setattr(self, name, param)
---> 32 return forward(self, *args, **kwargs)
33
34 return child_params_offset, fmodule

~/miniconda3/envs/hips/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
343
344 def forward(self, input):
--> 345 return self.conv2d_forward(input, self.weight)
346
347 class Conv3d(_ConvNd):

AttributeError: 'Scope' object has no attribute 'conv2d_forward'

@AdamCobb
Copy link

AdamCobb commented Oct 28, 2019

If anyone else has the same issue, my current fix is to add a few lines to the code above, as follows:
(I added comments to highlight my changes)

import sys
import types
from collections import OrderedDict

PY2 = sys.version_info[0] == 2
# nn.Module bookkeeping attributes that are not copied onto a Scope.
_internal_attrs = {'_backend', '_parameters', '_buffers', '_backward_hooks', '_forward_hooks', '_forward_pre_hooks', '_modules'}


### Had to add this for conv net
# Helper methods that forward() calls on self but that are defined on the
# module class, so they must be bound onto the Scope explicitly.
# 'conv2d_forward' is the name from the traceback above.
# NOTE(review): '_conv_forward' is assumed to be the same helper's name in
# later PyTorch releases; confirm against your installed version.
_new_methods = {'conv2d_forward', '_conv_forward'}


class Scope(object):
    """Stand-in for ``self`` when calling a module's ``forward`` functionally."""

    def __init__(self):
        self._modules = OrderedDict()


def _make_functional(module, params_box, params_offset):
    """Wrap ``module`` (and its children) so parameters come from ``params_box[0]``.

    Returns ``(next_offset, fmodule)``, where ``next_offset`` is the index just
    past the last parameter consumed by this subtree.
    """
    self = Scope()
    # Parameters registered as None (e.g. bias=False) are not yielded by
    # module.parameters(), so they must not consume a slot in the flat list.
    param_names = [name for name, param in module._parameters.items() if param is not None]
    num_params = len(param_names)
    module_name = type(module).__name__
    forward = type(module).forward.__func__ if PY2 else type(module).forward
    for name, attr in module.__dict__.items():
        if name in _internal_attrs:
            continue
        setattr(self, name, attr)
    for name, param in module._parameters.items():
        if param is None:
            setattr(self, name, None)
    # Buffers are shared by reference with the original module.
    for name, buf in module._buffers.items():
        setattr(self, name, buf)
    ### Had to add this for conv net (MY ADDITION)
    # Bind each helper under its own name. The previous version always bound
    # conv2d_forward, whichever name matched.
    for name in _new_methods:
        method = getattr(type(module), name, None)
        if method is None:
            continue
        if PY2:
            method = method.__func__  # unwrap the unbound method, as done for forward
        setattr(self, name, types.MethodType(method, self))
    child_params_offset = params_offset + num_params
    for name, child in module.named_children():
        child_params_offset, fchild = _make_functional(child, params_box, child_params_offset)
        self._modules[name] = fchild
        setattr(self, name, fchild)

    def fmodule(*args, **kwargs):
        params = params_box[0][params_offset:params_offset + num_params]
        if len(params) != num_params:
            raise ValueError('expected {} parameter(s) for {} starting at index {}, got {}'.format(
                num_params, module_name, params_offset, len(params)))
        for name, param in zip(param_names, params):
            setattr(self, name, param)
        return forward(self, *args, **kwargs)

    return child_params_offset, fmodule


def make_functional(module):
    """Return ``fmodule(*args, params=..., **kwargs)`` evaluating ``module`` with ``params``.

    ``params`` is any iterable of tensors in ``module.parameters()`` order; a
    missing ``params`` raises KeyError and a wrong count raises ValueError.
    """
    params_box = [None]
    num_params, fmodule_internal = _make_functional(module, params_box, 0)

    def fmodule(*args, **kwargs):
        params = tuple(kwargs.pop('params'))
        if len(params) != num_params:
            raise ValueError('expected {} parameter(s), got {}'.format(num_params, len(params)))
        params_box[0] = params
        return fmodule_internal(*args, **kwargs)

    return fmodule

@Miaotxy
Copy link

Miaotxy commented Oct 29, 2019

@AdamCobb

You can add conv2d_forward to the scope of the corresponding module:

if isinstance(module, nn.Conv2d): setattr(self, "conv2d_forward", module.conv2d_forward)

@apaszke
Copy link
Author

apaszke commented Oct 29, 2019

At this point it's probably better to use something like higher than this hacky script. I'm not planning to issue any more fixes here.

@Miaotxy
Copy link

Miaotxy commented Oct 30, 2019

@apaszke thanks for your advice, really helpful.

@AdamCobb
Copy link

Thanks very much for the suggestions and quick responses!

@swathir01
Copy link

For the code I've attached at the end (as an image), I get the following error. Please help me get past it.
Error code

2021-03-05 13:38:45.311484: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Traceback (most recent call last):
File "train.py", line 109, in <module>
fire.Fire(train)
File "/usr/local/lib/python3.7/dist-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/usr/local/lib/python3.7/dist-packages/fire/core.py", line 471, in _Fire
target=component.name)
File "/usr/local/lib/python3.7/dist-packages/fire/core.py", line 681, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "train.py", line 103, in train
n_iter = model.train(dataloader, epochs, lr, wass_target, mse_weight, ttur)
File "/content/LAG-Pytorch/model.py", line 395, in train
fake = self.G(lores, eps=torch.zeros_like(eps))
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/content/LAG-Pytorch/model.py", line 173, in forward
x = self.conv1(torch.cat([x, eps], dim=1))
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/content/LAG-Pytorch/model.py", line 68, in forward
return self.activation(self.conv2d_forward(x, self.scale * self.weight))
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 779, in __getattr__
type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'ScaledConv2dWithAct' object has no attribute 'conv2d_forward'

Screenshot (2)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment