@zou3519
Last active June 17, 2025 08:42
import copy
import torch

def make_functional(mod, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    # Deepcopy the module and move the copy to the meta device so the
    # stateless version carries no real parameter storage.
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, *args, **kwargs):
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values

def make_functional_with_buffers(mod, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    buffers_dict = dict(mod.named_buffers())
    buffers_names = buffers_dict.keys()
    buffers_values = tuple(buffers_dict.values())

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, new_buffers_values, *args, **kwargs):
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        new_buffers_dict = {name: value for name, value in zip(buffers_names, new_buffers_values)}
        return torch.func.functional_call(stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values, buffers_values
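For context, here is a minimal usage sketch (not part of the gist itself) showing how the returned fmodel composes with torch.func.grad; the nn.Linear model, inputs, and MSE loss are made up for illustration:

import torch
import torch.nn as nn

model = nn.Linear(3, 3)
x = torch.randn(3)
targets = torch.randn(3)

fmodel, params = make_functional(model)

def compute_loss(params, x, targets):
    # Parameters are passed explicitly instead of living on the module.
    y = fmodel(params, x)
    return nn.functional.mse_loss(y, targets)

# Differentiates with respect to the first argument: the parameter tuple.
grads = torch.func.grad(compute_loss)(params, x, targets)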
@tranvansang

I came to this gist from the pytorch.org post: https://pytorch.org/docs/master/func.migrating.html

At line 34 of the original gist (the functional_call inside make_functional_with_buffers), isn't it stateless_mod instead of mod?

        return torch.func.functional_call(mod, (new_params_dict, new_buffers_dict), args, kwargs)

@zou3519 (Author) commented Mar 20, 2023

> I came to this gist from the pytorch.org post: https://pytorch.org/docs/master/func.migrating.html
>
> At line 34 of the original gist (the functional_call inside make_functional_with_buffers), isn't it stateless_mod instead of mod?
>
>         return torch.func.functional_call(mod, (new_params_dict, new_buffers_dict), args, kwargs)

You're right, thanks for pointing that out. I'll update the gist.

Feel free to open an issue on GitHub if you have other questions or issues!
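As a side note on why the meta-device copy matters here: torch.func.functional_call substitutes the supplied tensors either way, so calling it with mod would still compute the right answer, but the closure would then keep the original module's real weights alive. A small illustration (not from the thread; the toy nn.Linear is made up):

import copy
import torch
import torch.nn as nn

model = nn.Linear(3, 3)
params = dict(model.named_parameters())
x = torch.randn(3)

stateless = copy.deepcopy(model).to('meta')

# Both calls substitute the same real tensors, so the outputs match...
y_real = torch.func.functional_call(model, params, (x,))
y_meta = torch.func.functional_call(stateless, params, (x,))
assert torch.allclose(y_real, y_meta)

# ...but the meta copy itself holds no real parameter storage.
print(next(stateless.parameters()).device)  # meta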

@PavlosPo commented Apr 3, 2024

Shouldn't there be a check for whether new params or buffers are actually passed in, falling back to the ones from the given model if not? I am trying to fine-tune a pretrained LLM with a custom optimizer that initializes itself on the first batch of data, so I need the functional model before any params or buffers are available yet. That check makes more sense to me.

import copy
import torch

def make_functional(mod, new_params_values=None, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())
    
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values=new_params_values, *args, **kwargs):
        if new_params_values is None:
            # No values passed (e.g. the first call): fall back to the
            # module's own parameters.
            new_params_values = params_values
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values

def make_functional_with_buffers(mod, new_params_values=None, new_buffers_values=None, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    buffers_dict = dict(mod.named_buffers())
    buffers_names = buffers_dict.keys()
    buffers_values = tuple(buffers_dict.values())
    
    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values=new_params_values, new_buffers_values=new_buffers_values, *args, **kwargs):
        if new_params_values is None:
            # No parameter values passed: fall back to the module's own.
            new_params_values = params_values
        if new_buffers_values is None:
            # No buffer values passed: fall back to the module's own.
            new_buffers_values = buffers_values
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        new_buffers_dict = {name: value for name, value in zip(buffers_names, new_buffers_values)}
        return torch.func.functional_call(stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values, buffers_values
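A hedged usage sketch of this variant (assuming the code above, with a toy nn.Linear standing in for the real model): note that because new_params_values is the first positional parameter of fmodel, you must pass None explicitly to get the fallback; omitting it would bind the first input tensor to new_params_values instead.

import torch
import torch.nn as nn

model = nn.Linear(3, 3)
x = torch.randn(3)

fmodel, params = make_functional(model)

# First call, before the optimizer has produced anything: pass None
# explicitly so fmodel falls back to the module's own parameters.
y0 = fmodel(None, x)

# Later calls: pass the (possibly optimizer-updated) parameter values.
updated = tuple(p - 0.1 * torch.ones_like(p) for p in params)
y1 = fmodel(updated, x)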
