Shortcuts

Source code for torch.ao.quantization.quantize

# mypy: allow-untyped-defs
import copy
import inspect
import itertools
import warnings

import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import (
    _activation_is_memoryless,
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
)
from torch.ao.quantization.quantization_mappings import (
    _get_special_act_post_process,
    _has_special_act_post_process,
    get_default_dynamic_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    no_observer_set,
)
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations


# Public API of this module. Only these names are exported by
# ``from torch.ao.quantization.quantize import *``; the ``_``-prefixed helpers
# and the ``is_activation_post_process`` alias below are not listed.
__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]


# TODO remove this once BC is no longer required to avoid a SEV
# Backward-compatibility alias: re-exposes the private
# ``_is_activation_post_process`` (imported from
# ``torch.ao.quantization.observer``) under its old public name.
is_activation_post_process = _is_activation_post_process


# Default custom-module configuration for eager mode. ``prepare`` reads the
# "float_to_observed_custom_module_class" entry and ``_convert`` reads the
# "observed_to_quantized_custom_module_class" entry when the caller does not
# supply its own config dict. Treat this object as read-only; hand out copies
# via ``get_default_custom_config_dict``.
_DEFAULT_CUSTOM_CONFIG_DICT = {
    "float_to_observed_custom_module_class": {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    "observed_to_quantized_custom_module_class": {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    },
}


def get_default_custom_config_dict():
    r"""Defines the default custom config dict.

    Return:
        A fresh deep copy of the default custom config dict. Mutating the
        returned dict (or its nested mappings) does not affect the module-level
        defaults used by later ``prepare``/``convert`` calls.
    """
    # Deep copy so nested mappings are independent too; class objects used as
    # keys/values are treated as atomic by ``copy.deepcopy`` and stay identical.
    return copy.deepcopy(_DEFAULT_CUSTOM_CONFIG_DICT)


def _propagate_qconfig_helper(
    module,
    qconfig_dict,
    qconfig_parent=None,
    prefix="",
    prepare_custom_config_dict=None,
):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                     configuration
        qconfig_parent: quantization config of parent module, we will fallback to
                       this config when there is no specified config for current
                       module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules
                                    see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """

    module_qconfig = qconfig_dict.get(
        type_before_parametrizations(module), qconfig_parent
    )
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, "qconfig", module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + "." + name if prefix else name
        #  do no not propagate qconfig to child if child is non traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child)
            in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )


def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Walk the module hierarchy and attach a `qconfig` attribute to each module.

    A qconfig found for a module applies to all of its submodules, unless a
    submodule has its own entry or already carries a `qconfig` attribute.

    Args:
        module: root module to annotate in place
        qconfig_dict: optional mapping from submodule name or type to a
            quantization configuration; ``None`` is treated as an empty mapping
        prepare_custom_config_dict: optional dictionary for custom handling of
            modules (see :func:`~torch.ao.quantization.prepare_fx`); ``None`` is
            treated as an empty mapping

    Return:
        None; `module` is modified in place.
    """
    resolved_qconfig_dict = {} if qconfig_dict is None else qconfig_dict
    resolved_custom_config = (
        {} if prepare_custom_config_dict is None else prepare_custom_config_dict
    )
    _propagate_qconfig_helper(
        module,
        resolved_qconfig_dict,
        prepare_custom_config_dict=resolved_custom_config,
    )
def _observer_forward_hook(self, input, output): r"""Forward hook that calls observer on the output""" return self.activation_post_process(output) def _observer_forward_pre_hook(self, input): r"""Forward pre hook that calls observer on the output""" return self.activation_post_process(input[0]) def _register_activation_post_process_hook(module, pre_hook=False): assert hasattr( module, "activation_post_process" ), "Expect activation_post_process attribute already attached to the module" if pre_hook: module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True) else: module.register_forward_hook(_observer_forward_hook, prepend=True) def _add_observer_( module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None, ): r"""Add observer for the leaf child of the module. This function insert observer module to all leaf child module that has a valid qconfig attribute. Args: module: input module with qconfig attributes for all the leaf modules that we want to quantize qconfig_propagation_list: a list of quantizable modules that will have observers added to them if they are leaf nodes device: parent device, if any non_leaf_module_list: list of non-leaf modules we want to add observer Return: None, module is modified inplace with added observer modules and forward_hooks """ if qconfig_propagation_list is None: qconfig_propagation_list = get_default_qconfig_propagation_list() if custom_module_class_mapping is None: custom_module_class_mapping = {} # respect device affinity when adding observers if device is None: devices = _get_unique_devices_(module) assert ( len(devices) <= 1 ), f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" device = next(iter(devices)) if len(devices) > 0 else None def get_activation_post_process(qconfig, device, special_act_post_process=None): activation = ( qconfig.activation() if special_act_post_process is None else special_act_post_process() 
) if device is not None: activation.to(device) return activation def needs_observation(m): return hasattr(m, "qconfig") and m.qconfig is not None def insert_activation_post_process(m, special_act_post_process=None): """Adds an activation post process module and register a pre or post hook that calls the module """ # We don't insert observer/fake_quantize for DeQuantStub if needs_observation(m) and not isinstance(m, DeQuantStub): # observer and hook will be gone after we swap the module m.add_module( "activation_post_process", get_activation_post_process( m.qconfig, device, special_act_post_process ), ) # Register observer as the first entry in the hook list # All post forward hooks are preserved and will be executed after the observer before convert _register_activation_post_process_hook( m, pre_hook=_activation_is_memoryless(m.qconfig) ) for name, child in module.named_children(): # TODO remove Dropout special after codebase stable if type_before_parametrizations(child) in [nn.Dropout]: continue elif issubclass( type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional) ): if needs_observation(child): assert hasattr( child, "activation_post_process" ), f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" child.activation_post_process = get_activation_post_process( child.qconfig, device ) elif isinstance(child, _FusedModule): # activation_post_process are now added directly to nn.Sequential/_FusedModule if needs_observation(child): insert_activation_post_process(child) elif ( non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list ): if needs_observation(child): insert_activation_post_process(child) elif _has_special_act_post_process(child): special_act_post_process = _get_special_act_post_process(child) insert_activation_post_process(child, special_act_post_process) elif ( needs_observation(child) and type_before_parametrizations(child) in 
custom_module_class_mapping ): observed_class = custom_module_class_mapping[ type_before_parametrizations(child) ] observed_child = observed_class.from_float(child) setattr(module, name, observed_child) # TODO: These are the modules that cannot be observed # Once there are more, we should move them to a separate list if not issubclass(observed_class, tuple(no_observer_set())): insert_activation_post_process(observed_child) else: _add_observer_( child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping, ) # Insert observers only for leaf nodes, note that this observer is for # the output of the module, for input QuantStub will observe them if ( has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) and type_before_parametrizations(module) in qconfig_propagation_list ): insert_activation_post_process(module) # This is a special case for AdaRound eager mode # AdaRound contains weight_fake_quant to be propagated from API to convert # leaf node check with a number of children looks naive assumption that blocks # Adding an exception case for AdaRound if ( hasattr(module, "weight_fake_quant") and not isinstance(module, torch.nn.Sequential) and type_before_parametrizations(module) in qconfig_propagation_list ): insert_activation_post_process(module) def _get_unique_devices_(module): return {p.device for p in module.parameters() if p.device.type != "meta"} | { p.device for p in module.buffers() if p.device.type != "meta" }
def add_quant_dequant(module):
    r"""Wrap every leaf submodule that has a truthy `qconfig` in a `QuantWrapper`.

    Children of `module` are replaced in place. If `module` is itself such a
    leaf, a new `QuantWrapper` around it is returned instead.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize

    Return:
        Either `module` (with its descendants wrapped in place) or a new
        `QuantWrapper` wrapping `module` when `module` is a leaf with a qconfig.
    """
    is_quantizable_leaf = (
        has_no_children_ignoring_parametrizations(module)
        and hasattr(module, "qconfig")
        and module.qconfig
    )
    if is_quantizable_leaf:
        return QuantWrapper(module)

    replacements = {
        child_name: add_quant_dequant(sub)
        for child_name, sub in module.named_children()
    }
    module._modules.update(replacements)
    return module
def prepare(
    model,
    inplace=False,
    allow_list=None,
    observer_non_leaf_module_list=None,
    prepare_custom_config_dict=None,
):
    r"""Prepare a model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned beforehand to individual
    submodules through their `.qconfig` attribute. qconfig is propagated through
    the hierarchy and observer / fake-quant modules are attached.

    Args:
        `model`: input model
        `inplace`: if True, mutate `model`; otherwise work on a deep copy
        `allow_list`: list of quantizable modules (defaults to
            ``get_default_qconfig_propagation_list()``)
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for
            prepare function (defaults to ``get_default_custom_config_dict()``)

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }

    Return:
        The prepared model (`model` itself when `inplace` is True).
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    custom_config = (
        get_default_custom_config_dict()
        if prepare_custom_config_dict is None
        else prepare_custom_config_dict
    )
    float_to_observed = custom_config.get("float_to_observed_custom_module_class", {})

    target = model if inplace else copy.deepcopy(model)

    # TODO: remove allow_list
    propagation_list = (
        get_default_qconfig_propagation_list() if allow_list is None else allow_list
    )

    propagate_qconfig_(target, qconfig_dict=None)

    # Catch the common mistake of calling prepare without any qconfig set.
    if all(not (hasattr(sub, "qconfig") and sub.qconfig) for sub in target.modules()):
        warnings.warn(
            "None of the submodule got qconfig applied. Make sure you "
            "passed correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly on submodules"
        )

    _add_observer_(
        target,
        propagation_list,
        observer_non_leaf_module_list,
        custom_module_class_mapping=float_to_observed,
    )
    return target
def _remove_activation_post_process(module):
    r"""Detach the observer attached by prepare and its observer hooks from `module`.

    The `activation_post_process` attribute is deleted only when it is an
    observer / fake-quant module. Every `_observer_forward_pre_hook` and
    `_observer_forward_hook` registered on `module` is removed regardless.
    """
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    has_observer = hasattr(module, "activation_post_process")
    if has_observer and _is_activation_post_process(module.activation_post_process):
        delattr(module, "activation_post_process")

    # remove activation_post_process pre and post hooks (pre hooks first)
    hook_tables = (
        (module._forward_pre_hooks, _observer_forward_pre_hook),
        (module._forward_hooks, _observer_forward_hook),
    )
    for registry, observer_hook in hook_tables:
        stale_ids = [hid for hid, fn in registry.items() if fn is observer_hook]
        for hid in stale_ids:
            registry.pop(hid)


# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be
    propagated.

    Children are cleaned first (post-order), then `module` loses its `qconfig`
    attribute and its observer / observer hooks.

    Args:
        module: module to be cleaned up
    """
    for sub in module.children():
        _remove_qconfig(sub)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)
def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize a float model with post-training static quantization.

    The model is switched to eval mode, prepared for calibration, calibrated by
    calling ``run_fn(model, *run_args)``, and finally converted.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        mapping: correspondence between original module types and quantized
            counterparts (defaults to ``get_default_static_quant_module_mappings()``)
        inplace: if True, mutate `model`; otherwise work on a deep copy

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    module_mapping = (
        get_default_static_quant_module_mappings() if mapping is None else mapping
    )
    target = model if inplace else copy.deepcopy(model)
    target.eval()
    prepare(target, inplace=True)
    run_fn(target, *run_args)
    convert(target, module_mapping, inplace=True)
    return target
def quantize_dynamic(
    model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False
):
    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.

    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
    by default is performed for layers with large weights size - i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
    If `qconfig` is provided, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to quantization
              configuration, qconfig applies to all submodules of a given
              module unless qconfig for the submodules are specified (when the
              submodule already has qconfig attribute). Entries in the dictionary
              need to be QConfig instances.

            - A set (or frozenset) of types and/or submodule names to apply dynamic
              quantization to, in which case the `dtype` argument is used to specify
              the bit-width

            - None, in which case a default per-`dtype` table is used

        dtype: quantized dtype selecting the default qconfig (qint8, float16,
            quint8 or quint4x2)
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
            with which the submodule needs to be replaced

    Raises:
        ValueError: if `qconfig_spec` is None and `dtype` has no default table.
        RuntimeError: if `qconfig_spec` is a set and `dtype` is not supported.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear: default_dynamic_qconfig,
                nn.LSTM: default_dynamic_qconfig,
                nn.GRU: default_dynamic_qconfig,
                nn.LSTMCell: default_dynamic_qconfig,
                nn.RNNCell: default_dynamic_qconfig,
                nn.GRUCell: default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear: float16_dynamic_qconfig,
                nn.LSTM: float16_dynamic_qconfig,
                nn.GRU: float16_dynamic_qconfig,
                nn.LSTMCell: float16_dynamic_qconfig,
                nn.RNNCell: float16_dynamic_qconfig,
                nn.GRUCell: float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
                nn.Embedding: float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please"
            )
    elif isinstance(qconfig_spec, (set, frozenset)):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            # A single formatted message: passing several positional arguments
            # to the exception would render the message as a tuple.
            raise RuntimeError(
                f"Unknown dtype specified for quantize_dynamic: {dtype}"
            )
        # Every listed type / submodule name gets the dtype's default qconfig.
        qconfig_spec = dict.fromkeys(qconfig_spec, default_qconfig)

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model
def prepare_qat(model, mapping=None, inplace=False):
    r"""Prepare a model in training mode for quantization-aware training.

    qconfig is propagated, float submodules are swapped according to `mapping`
    (QAT counterparts by default) while keeping their qconfig, and observers /
    fake-quant modules are then attached, treating the swapped module types as
    non-leaf modules that still get observed. Quantization configuration should
    be assigned beforehand to individual submodules via `.qconfig`.

    Args:
        model: input model; must be in training mode
        mapping: dictionary that maps float modules to the modules that
                 replace them (defaults to ``get_default_qat_module_mappings()``)
        inplace: if True, mutate `model`; otherwise work on a deep copy

    Return:
        The prepared model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    qat_mapping = get_default_qat_module_mappings() if mapping is None else mapping

    target = model if inplace else copy.deepcopy(model)

    propagate_qconfig_(target, qconfig_dict=None)
    convert(target, mapping=qat_mapping, inplace=True, remove_qconfig=False)
    swapped_types = set(qat_mapping.values())
    prepare(target, observer_non_leaf_module_list=swapped_types, inplace=True)
    return target
def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Run quantization-aware training and return a quantized model.

    The model is put in training mode, prepared with `prepare_qat`, trained or
    evaluated by calling ``run_fn(model, *run_args)``, then converted.

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
                function that simply runs the prepared model or a training
                loop
        run_args: positional arguments for `run_fn`
        inplace: if True, mutate `model`; otherwise work on a deep copy

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    target = model if inplace else copy.deepcopy(model)
    target.train()
    prepare_qat(target, inplace=True)
    run_fn(target, *run_args)
    convert(target, inplace=True)
    return target
def convert(
    module,
    mapping=None,
    inplace=False,
    remove_qconfig=True,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Swap submodules of `module` according to `mapping`, calling `from_float`
    on the target module classes, and optionally strip qconfig afterwards.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
                   module type, can be overwritten to allow swapping user defined
                   Modules
        `inplace`: if True, mutate `module`; otherwise work on a deep copy
        `remove_qconfig`: if True, remove qconfig and observers after swapping
        `is_reference`: a flag to enable quantized reference module
        `convert_custom_config_dict`: custom configuration dictionary for convert function
        `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant

    .. code-block:: python

       # Example of convert_custom_config_dict:
       convert_custom_config_dict = {
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
           "observed_to_quantized_custom_module_class": {
               ObservedCustomModule: QuantizedCustomModule
           }
       }

    Return:
        The converted module.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    result = module if inplace else copy.deepcopy(module)
    _convert(
        result,
        mapping,
        inplace=True,
        is_reference=is_reference,
        convert_custom_config_dict=convert_custom_config_dict,
        use_precomputed_fake_quant=use_precomputed_fake_quant,
    )
    if remove_qconfig:
        _remove_qconfig(result)
    return result
def _convert(
    module,
    mapping=None,
    inplace=False,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Recursively swap submodules of `module` according to `mapping`.

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules (defaults to the static or, if `is_reference`, the
                 reference static quant mappings)
        inplace: if True, mutate `module`; otherwise work on a deep copy
        is_reference: a flag to enable quantized reference module
        convert_custom_config_dict: custom configuration dictionary; its
            "observed_to_quantized_custom_module_class" entry is used
        use_precomputed_fake_quant: a flag to enable use of precomputed fake quant

    Return:
        The module with its children swapped.
    """
    if mapping is None:
        if is_reference:
            mapping = get_default_static_quant_reference_module_mappings()
        else:
            mapping = get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get(
        "observed_to_quantized_custom_module_class", {}
    )

    if not inplace:
        module = copy.deepcopy(module)

    swapped_children = {}
    for child_name, child in module.named_children():
        # Fused modules and observed custom modules are swapped as one unit,
        # so their own children are not visited.
        swap_as_unit = (
            isinstance(child, _FusedModule)
            or type_before_parametrizations(child) in custom_module_class_mapping
        )
        if not swap_as_unit:
            _convert(
                child,
                mapping,
                True,  # inplace
                is_reference,
                convert_custom_config_dict,
                use_precomputed_fake_quant=use_precomputed_fake_quant,
            )
        swapped_children[child_name] = swap_module(
            child, mapping, custom_module_class_mapping, use_precomputed_fake_quant
        )

    module._modules.update(swapped_children)
    return module
def swap_module(
    mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False
):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module
        custom_module_class_mapping: a dictionary that maps observed custom
            module types to classes providing ``from_observed``; checked
            before `mapping`
        use_precomputed_fake_quant: forwarded to ``from_float`` when the target
            class accepts that keyword

    Return:
        The corresponding quantized module of `mod` (or `mod` itself when it
        has no qconfig or no matching entry). User-registered forward pre/post
        hooks are copied to the swapped module; the observer hooks are not.
    """
    new_mod = mod
    if hasattr(mod, "qconfig") and mod.qconfig is not None:
        swapped = False
        mod_type = type_before_parametrizations(mod)
        if mod_type in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[mod_type].from_observed(mod)
            swapped = True
        elif mod_type in mapping:
            qmod = mapping[mod_type]
            if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
                assert mod.qconfig is not None
                weight_post_process = mod.qconfig.weight()
                weight_post_process(mod.weight)
                weight_qparams = get_qparam_dict(weight_post_process)
                new_mod = qmod.from_float(mod, weight_qparams)
            else:
                sig = inspect.signature(qmod.from_float)
                if "use_precomputed_fake_quant" in sig.parameters:
                    new_mod = qmod.from_float(
                        mod, use_precomputed_fake_quant=use_precomputed_fake_quant
                    )
                else:
                    new_mod = qmod.from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input.
            # Skip the observer pre hook: it calls `self.activation_post_process`,
            # which belongs to the float module. If copied, it would observe
            # quantized input, and it fails when the swapped module lacks that
            # attribute and convert(..., remove_qconfig=False) leaves the hook in place.
            for pre_hook_fn in mod._forward_pre_hooks.values():
                if pre_hook_fn is not _observer_forward_pre_hook:
                    new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = _get_unique_devices_(mod)
            assert len(devices) <= 1 or (
                len(devices) == 2 and torch.device("meta") in devices
            ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod
def _get_observer_dict(mod, target_dict, prefix=""): r"""Traverse the modules and save all observers into dict. This is mainly used for quantization accuracy debug Args: mod: the top module we want to save all observers prefix: the prefix for the current module target_dict: the dictionary used to save all the observers """ def get_prefix(prefix): return prefix if prefix == "" else prefix + "." if hasattr(mod, "activation_post_process"): target_dict[ get_prefix(prefix) + "activation_post_process" ] = mod.activation_post_process for name, child in mod.named_children(): module_prefix = get_prefix(prefix) + name if prefix else name _get_observer_dict(child, target_dict, module_prefix)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy