@zou3519
Last active April 3, 2024
import copy
import torch

def make_functional(mod, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    # Deep-copy the module and move the copy to the meta device so it holds no real storage.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, *args, **kwargs):
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values
def make_functional_with_buffers(mod, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    buffers_dict = dict(mod.named_buffers())
    buffers_names = buffers_dict.keys()
    buffers_values = tuple(buffers_dict.values())

    # Deep-copy the module and move the copy to the meta device so it holds no real storage.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, new_buffers_values, *args, **kwargs):
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        new_buffers_dict = {name: value for name, value in zip(buffers_names, new_buffers_values)}
        return torch.func.functional_call(stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values, buffers_values
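A minimal usage sketch (my addition, not part of the original gist): make_functional applied to a small nn.Linear model, with the resulting fmodel fed to torch.func.grad. The model, data, and loss below are illustrative assumptions.

import torch
import torch.nn as nn

# Hypothetical example model and data, for illustration only.
model = nn.Linear(3, 1)
x = torch.randn(4, 3)
y = torch.randn(4, 1)

fmodel, params = make_functional(model)

def loss_fn(params, x, y):
    pred = fmodel(params, x)
    return torch.nn.functional.mse_loss(pred, y)

# torch.func.grad differentiates with respect to the first argument (the params tuple).
grads = torch.func.grad(loss_fn)(params, x, y)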
@tranvansang

I came to this gist from the pytorch.org post https://pytorch.org/docs/master/func.migrating.html

At line 34, shouldn't it be stateless_mod instead of mod?

        return torch.func.functional_call(mod, (new_params_dict, new_buffers_dict), args, kwargs)

@zou3519
Author

zou3519 commented Mar 20, 2023

> I came to this gist from the pytorch.org post https://pytorch.org/docs/master/func.migrating.html
>
> At line 34, shouldn't it be stateless_mod instead of mod?
>
>         return torch.func.functional_call(mod, (new_params_dict, new_buffers_dict), args, kwargs)

You're right, thanks for pointing that out. I'll update the gist.

Feel free to open an issue on GitHub if you have other questions or issues!

@PavlosPo

PavlosPo commented Apr 3, 2024

Shouldn't there be a check for whether new params or buffers were actually passed in, falling back to the model's own if not? I'm trying to fine-tune a pretrained LLM with a custom optimizer that initializes itself on the first batch of data, so I need the functional model before any params or buffers are available yet. A version with that check makes more sense to me:

import copy
import torch

def make_functional(mod, new_params_values=None, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())
    
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    # Call as fmodel(None, x) if no updated parameters are available yet.
    def fmodel(new_params_values=new_params_values, *args, **kwargs):
        if new_params_values is None:
            # This is the first call to the functional model
            new_params_values = params_values
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values

def make_functional_with_buffers(mod, new_params_values=None, new_buffers_values=None, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    buffers_dict = dict(mod.named_buffers())
    buffers_names = buffers_dict.keys()
    buffers_values = tuple(buffers_dict.values())
    
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    # Call as fmodel(None, None, x) if no updated parameters/buffers are available yet.
    def fmodel(new_params_values=new_params_values, new_buffers_values=new_buffers_values, *args, **kwargs):
        if new_params_values is None:
            # This is the first call to the functional model
            new_params_values = params_values
        if new_buffers_values is None:
            # This is the first call to the functional model
            new_buffers_values = buffers_values
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        new_buffers_dict = {name: value for name, value in zip(buffers_names, new_buffers_values)}
        return torch.func.functional_call(stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values, buffers_values
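A minimal sketch of how this variant might be used (my illustration, assuming a toy nn.Linear model): the first call passes None so the module's own parameters and buffers are used; later calls pass updated values.

import torch
import torch.nn as nn

# Hypothetical model, for illustration only.
model = nn.Linear(3, 1)
fmodel, params, buffers = make_functional_with_buffers(model)

x = torch.randn(4, 3)

# First call: no updated params/buffers yet, fall back to the module's own.
out = fmodel(None, None, x)

# Later calls: pass updated values (here, simply the ones returned above).
out = fmodel(params, buffers, x)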
