-
-
Save zou3519/7769506acc899d83ef1464e28f22e6cf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy

import torch
def make_functional(mod, disable_autograd_tracking=False):
    """Split ``mod`` into a stateless callable and a tuple of its parameters.

    A ``torch.func``-based stand-in for ``functorch.make_functional``.

    Args:
        mod: The ``torch.nn.Module`` to functionalize. It must not have
            buffers; use ``make_functional_with_buffers`` for such modules.
        disable_autograd_tracking: If True, the returned parameters are
            detached from the autograd graph.

    Returns:
        ``(fmodel, params_values)``. ``params_values`` is a tuple of the
        module's parameters in ``mod.named_parameters()`` order.
        ``fmodel(new_params_values, *args, **kwargs)`` runs ``mod``'s forward
        on ``*args, **kwargs`` with ``new_params_values`` (same order and
        length as ``params_values``) substituted for the parameters.

    Raises:
        RuntimeError: If ``mod`` has buffers. They would otherwise be left as
            meta tensors in the stateless copy that ``fmodel`` runs.
        ValueError: Raised by ``fmodel`` when ``new_params_values`` does not
            contain exactly one tensor per parameter.

    Note:
        ``fmodel`` runs a deep copy of ``mod`` taken when this function is
        called, so later changes to ``mod`` (e.g. ``mod.eval()``) do not
        affect ``fmodel``.
    """
    if next(mod.buffers(), None) is not None:
        raise RuntimeError(
            "make_functional(mod): `mod` has buffers. "
            "Please use make_functional_with_buffers(mod) instead."
        )

    params_dict = dict(mod.named_parameters())
    params_names = tuple(params_dict)
    params_values = tuple(params_dict.values())

    # Keep the module's structure (and forward) but move every tensor in the
    # copy to the meta device; real values are supplied on each call.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, *args, **kwargs):
        new_params_values = tuple(new_params_values)
        # zip() would silently drop extra values, or leave missing names as
        # the copy's meta tensors, so insist on an exact match.
        if len(new_params_values) != len(params_names):
            raise ValueError(
                f"fmodel expected {len(params_names)} parameter tensors, "
                f"got {len(new_params_values)}"
            )
        new_params_dict = dict(zip(params_names, new_params_values))
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = tuple(p.detach() for p in params_values)
    return fmodel, params_values
def make_functional_with_buffers(mod, disable_autograd_tracking=False):
    """Split ``mod`` into a stateless callable plus its parameters and buffers.

    A ``torch.func``-based stand-in for
    ``functorch.make_functional_with_buffers``.

    Args:
        mod: The ``torch.nn.Module`` to functionalize.
        disable_autograd_tracking: If True, the returned parameters are
            detached from the autograd graph (buffers are returned as-is).

    Returns:
        ``(fmodel, params_values, buffers_values)``: tuples of the module's
        parameters and buffers in ``named_parameters()`` /
        ``named_buffers()`` order, and
        ``fmodel(new_params_values, new_buffers_values, *args, **kwargs)``,
        which runs ``mod``'s forward on ``*args, **kwargs`` with the given
        values substituted for the parameters and buffers.

    Raises:
        ValueError: Raised by ``fmodel`` when either value sequence does not
            contain exactly one tensor per parameter / buffer.

    Note:
        ``fmodel`` runs a deep copy of ``mod`` taken when this function is
        called, so later changes to ``mod`` (e.g. ``mod.eval()``) do not
        affect ``fmodel``.
    """
    params_dict = dict(mod.named_parameters())
    params_names = tuple(params_dict)
    params_values = tuple(params_dict.values())
    buffers_dict = dict(mod.named_buffers())
    buffers_names = tuple(buffers_dict)
    buffers_values = tuple(buffers_dict.values())

    # Keep the module's structure (and forward) but move every tensor in the
    # copy to the meta device; real values are supplied on each call.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def _bind(names, values, kind):
        """Pair ``names`` with ``values``, rejecting a length mismatch."""
        values = tuple(values)
        # zip() would silently drop extra values, or leave missing names as
        # the copy's meta tensors.
        if len(values) != len(names):
            raise ValueError(
                f"fmodel expected {len(names)} {kind} tensors, got {len(values)}"
            )
        return dict(zip(names, values))

    def fmodel(new_params_values, new_buffers_values, *args, **kwargs):
        new_params_dict = _bind(params_names, new_params_values, "parameter")
        new_buffers_dict = _bind(buffers_names, new_buffers_values, "buffer")
        return torch.func.functional_call(
            stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs
        )

    if disable_autograd_tracking:
        params_values = tuple(p.detach() for p in params_values)
    return fmodel, params_values, buffers_values
I came to this gist from the pytorch.org post: https://pytorch.org/docs/master/func.migrating.html
At line 34, shouldn't it be `stateless_mod` instead of `mod`?
`return torch.func.functional_call(mod, (new_params_dict, new_buffers_dict), args, kwargs)`
You're right, thanks for pointing that out. I'll update the gist.
Feel free to open an issue on GitHub if you have other questions or issues!
Shouldn't there be a check for whether new params or buffers are passed in, and, if not, a fallback to the model's own ones? I am trying to fine-tune a pretrained LLM with a custom optimizer that initializes itself on the first batch of data, so I need the functional model before I have the params or buffers. That check makes more sense to me.
import copy
def make_functional(self, mod, new_params_values=None, disable_autograd_tracking=False):
    """Variant of ``make_functional`` whose ``fmodel`` can fall back to ``mod``'s own parameters.

    NOTE(review): ``self`` is never used; this looks like a method pasted out
    of a class. When calling it as a plain function, pass a placeholder
    (e.g. ``None``) for ``self``.

    Args:
        self: Unused.
        mod: The ``torch.nn.Module`` to functionalize.
        new_params_values: Default value for ``fmodel``'s first argument.
        disable_autograd_tracking: If True, the returned parameters, and the
            fallback parameters ``fmodel`` uses, are detached.

    Returns:
        ``(fmodel, params_values)``, where ``params_values`` is a tuple of
        ``mod``'s parameters in ``named_parameters()`` order.

        ``fmodel(new_params_values=<default>, *args, **kwargs)`` runs
        ``mod``'s forward with ``new_params_values`` substituted for the
        parameters. If that argument is ``None``, ``params_values`` is used
        instead. Because ``*args`` follows it, the default only applies when
        ``fmodel`` gets no positional arguments: ``fmodel(x)`` binds ``x`` as
        the parameters. To run on the fallback parameters, call
        ``fmodel(None, x)``.

    Raises:
        ValueError: From ``fmodel`` when the parameter values it would use do
            not contain exactly one tensor per parameter (e.g. an input passed
            in the parameters slot).
    """
    params_dict = dict(mod.named_parameters())
    params_names = tuple(params_dict)
    params_values = tuple(params_dict.values())
    if disable_autograd_tracking:
        # fmodel's fallback below reads these same (detached) values.
        params_values = tuple(p.detach() for p in params_values)

    # Keep the module's structure (and forward) but move every tensor in the
    # copy to the meta device; real values are supplied on each call.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values=new_params_values, *args, **kwargs):
        if new_params_values is None:
            # No values supplied yet: fall back to mod's own parameters.
            new_params_values = params_values
        new_params_values = tuple(new_params_values)
        # zip() would silently drop extra values, or leave missing names as
        # the copy's meta tensors, so insist on an exact match.
        if len(new_params_values) != len(params_names):
            raise ValueError(
                f"fmodel expected {len(params_names)} parameter tensors, "
                f"got {len(new_params_values)}; pass None as the first "
                f"argument to use the module's own parameters"
            )
        new_params_dict = dict(zip(params_names, new_params_values))
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    return fmodel, params_values
def make_functional_with_buffers(self, mod, new_params_values=None, new_buffers_values=None, disable_autograd_tracking=False):
    """Variant of ``make_functional_with_buffers`` whose ``fmodel`` can fall back to ``mod``'s own state.

    NOTE(review): ``self`` is never used; this looks like a method pasted out
    of a class. When calling it as a plain function, pass a placeholder
    (e.g. ``None``) for ``self``.

    Args:
        self: Unused.
        mod: The ``torch.nn.Module`` to functionalize.
        new_params_values: Default value for ``fmodel``'s first argument.
        new_buffers_values: Default value for ``fmodel``'s second argument.
        disable_autograd_tracking: If True, the returned parameters, and the
            fallback parameters ``fmodel`` uses, are detached.

    Returns:
        ``(fmodel, params_values, buffers_values)``: tuples of ``mod``'s
        parameters and buffers in ``named_parameters()`` /
        ``named_buffers()`` order, and
        ``fmodel(new_params_values=<default>, new_buffers_values=<default>,
        *args, **kwargs)``, which runs ``mod``'s forward with those values
        substituted. An argument that is ``None`` falls back to
        ``params_values`` / ``buffers_values``. Because ``*args`` follows
        them, the defaults only apply when ``fmodel`` gets too few positional
        arguments to reach them. To run on the fallback state with inputs,
        call ``fmodel(None, None, x)``.

    Raises:
        ValueError: From ``fmodel`` when the parameter or buffer values it
            would use do not contain exactly one tensor per parameter / buffer.
    """
    params_dict = dict(mod.named_parameters())
    params_names = tuple(params_dict)
    params_values = tuple(params_dict.values())
    if disable_autograd_tracking:
        # fmodel's fallback below reads these same (detached) values.
        params_values = tuple(p.detach() for p in params_values)
    buffers_dict = dict(mod.named_buffers())
    buffers_names = tuple(buffers_dict)
    buffers_values = tuple(buffers_dict.values())

    # Keep the module's structure (and forward) but move every tensor in the
    # copy to the meta device; real values are supplied on each call.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def _bind(names, values, kind):
        """Pair ``names`` with ``values``, rejecting a length mismatch."""
        values = tuple(values)
        # zip() would silently drop extra values, or leave missing names as
        # the copy's meta tensors.
        if len(values) != len(names):
            raise ValueError(
                f"fmodel expected {len(names)} {kind} tensors, got {len(values)}; "
                f"pass None to use the module's own {kind} tensors"
            )
        return dict(zip(names, values))

    def fmodel(new_params_values=new_params_values, new_buffers_values=new_buffers_values, *args, **kwargs):
        if new_params_values is None:
            # No values supplied yet: fall back to mod's own parameters.
            new_params_values = params_values
        if new_buffers_values is None:
            # No values supplied yet: fall back to mod's own buffers.
            new_buffers_values = buffers_values
        new_params_dict = _bind(params_names, new_params_values, "parameter")
        new_buffers_dict = _bind(buffers_names, new_buffers_values, "buffer")
        return torch.func.functional_call(
            stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs
        )

    return fmodel, params_values, buffers_values
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I came to this gist from the pytorch.org post: https://pytorch.org/docs/master/func.migrating.html
At line 34, shouldn't it be `stateless_mod` instead of `mod`?