@zou3519
Last active April 3, 2024 20:36
import copy
import torch

def make_functional(mod, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    # Copy the module and move the copy to the meta device so it holds no real storage.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, *args, **kwargs):
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values


def make_functional_with_buffers(mod, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    buffers_dict = dict(mod.named_buffers())
    buffers_names = buffers_dict.keys()
    buffers_values = tuple(buffers_dict.values())

    # Copy the module and move the copy to the meta device so it holds no real storage.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, new_buffers_values, *args, **kwargs):
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        new_buffers_dict = {name: value for name, value in zip(buffers_names, new_buffers_values)}
        return torch.func.functional_call(stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values, buffers_values
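
A minimal usage sketch (not part of the gist), assuming a toy nn.Linear model, showing the returned functional model composed with torch.func.grad:

import torch
import torch.nn as nn

model = nn.Linear(3, 3)
x = torch.randn(2, 3)

fmodel, params = make_functional(model)

def compute_loss(params, x):
    # fmodel takes the parameter values first, then the module's usual inputs.
    return fmodel(params, x).sum()

# Gradients come back as a tuple matching the structure of `params`.
grads = torch.func.grad(compute_loss)(params, x)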
PavlosPo commented Apr 3, 2024

Shouldn't there be a check for whether new params or buffers are actually being passed in? If they aren't, fall back to the ones from the module itself. I'm trying to fine-tune a pretrained LLM with a custom optimizer that initializes itself on the first batch of data, so I need the functional model before I have any params or buffers available yet. That check makes more sense for my use case.

import copy
import torch

def make_functional(mod, new_params_values=None, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values=new_params_values, *args, **kwargs):
        if new_params_values is None:
            # No params supplied yet (e.g. the first call): fall back to the module's own parameters.
            new_params_values = params_values
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values

def make_functional_with_buffers(mod, new_params_values=None, new_buffers_values=None, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    buffers_dict = dict(mod.named_buffers())
    buffers_names = buffers_dict.keys()
    buffers_values = tuple(buffers_dict.values())

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values=new_params_values, new_buffers_values=new_buffers_values, *args, **kwargs):
        if new_params_values is None:
            # No params supplied yet: fall back to the module's own parameters.
            new_params_values = params_values
        if new_buffers_values is None:
            # No buffers supplied yet: fall back to the module's own buffers.
            new_buffers_values = buffers_values
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        new_buffers_dict = {name: value for name, value in zip(buffers_names, new_buffers_values)}
        return torch.func.functional_call(stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values, buffers_values
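
A minimal sketch (not part of the comment) of how this fallback might be exercised, assuming a toy nn.Linear model; the params slot is still the first positional argument, so pass None to use the module's own parameters:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(8, 4)

fmodel, params = make_functional(model)

# First call: no updated params yet, fall back to the module's own parameters.
out_first = fmodel(None, x)

# Later calls: pass updated parameter values explicitly, as in the original gist.
out_later = fmodel(params, x)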
