diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml deleted file mode 100644 index e3233268..00000000 --- a/.github/workflows/ccpp.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: C/C++ CI - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: configure - run: ./configure - - name: make - run: make - - name: make check - run: make check - - name: make distcheck - run: make distcheck diff --git a/python/mrt/pytorch/__init__.py b/python/mrt/pytorch/__init__.py new file mode 100644 index 00000000..f44caf1b --- /dev/null +++ b/python/mrt/pytorch/__init__.py @@ -0,0 +1,3 @@ +__version__='1.3.1' +from .util import fuse_model +from .util import post_training_quant \ No newline at end of file diff --git a/python/mrt/pytorch/rules/__init__.py b/python/mrt/pytorch/rules/__init__.py new file mode 100644 index 00000000..b4ee9bb6 --- /dev/null +++ b/python/mrt/pytorch/rules/__init__.py @@ -0,0 +1 @@ +from .fuse_rule import FuseRule \ No newline at end of file diff --git a/python/mrt/pytorch/rules/convbn.py b/python/mrt/pytorch/rules/convbn.py new file mode 100644 index 00000000..7c789ec3 --- /dev/null +++ b/python/mrt/pytorch/rules/convbn.py @@ -0,0 +1,54 @@ +from . import FuseRule + +import re + +_IDLE = 0 +_CONV = 1 + +class ConvBNRuleByName(FuseRule): + """Rule for searching plain BN after Convolution by default pytorch-generated name. 
+ Equivalence transformation without loss in accuracy.""" + + @staticmethod + def getInfo(keyword, name): + res = re.match(f"(.*){keyword}(\d*)", name) + if res is None: + return None, None, None + return res.group(), res.group(1), res.group(2) + + def __init__(self): + self._names_lists = list() + self._idle() + + def _idle(self): + self.cur_list = list() + self.cur_prefix = None + self.cur_suffix = None + self.state = _IDLE + + def add_module(self, m): + m_name = m[0] + if self.state == _IDLE: + check, prefix, suffix = self.getInfo("conv", m_name) + if check is not None: + self.cur_prefix = prefix + self.cur_suffix = suffix + self.cur_list = [m_name] + self.state = _CONV + elif self.state == _CONV: + check, prefix, suffix = self.getInfo("bn", m_name) + if check is None: + self._idle() + elif prefix != self.cur_prefix: + self._idle() + elif suffix == self.cur_suffix: + self.cur_list.append(m_name) + self._names_lists.append(self.cur_list) + self._idle() + else: + self._idle() + else: + raise NotImplementedError + + def names_lists(self): + return self._names_lists \ No newline at end of file diff --git a/python/mrt/pytorch/rules/downsample.py b/python/mrt/pytorch/rules/downsample.py new file mode 100644 index 00000000..d4e2cf71 --- /dev/null +++ b/python/mrt/pytorch/rules/downsample.py @@ -0,0 +1,52 @@ +from . import FuseRule + +import re + +_IDLE = 0 +_DOWNS = 1 + +class DownsampleByName(FuseRule): + """Rule for searching Downsample(BN, Convolution) by default pytorch-generated name.
+ Equivalence transformation without loss in accuracy.""" + + @staticmethod + def getInfo(keyword, name): + res = re.match(f"(.*){keyword}(\d*)", name) + if res is None: + return None, None, None + return res.group(), res.group(1), res.group(2) + + def __init__(self): + self._names_lists = list() + self._idle() + + def _idle(self): + self.cur_list = list() + self.cur_prefix = None + self.cur_suffix = None + self.state = _IDLE + + def add_module(self, m): + m_name = m[0] + if self.state == _IDLE: + check, prefix, suffix = self.getInfo("downsample.", m_name) + if check is not None: + self.cur_prefix = prefix + self.cur_suffix = suffix + self.cur_list = [m_name] + self.state = _DOWNS + elif self.state == _DOWNS: + check, prefix, suffix = self.getInfo("downsample.", m_name) + if check is None: + self._idle() + elif prefix != self.cur_prefix: + self._idle() + else: + self.cur_list.append(m_name) + self._names_lists.append(self.cur_list) + self._idle() + else: + raise NotImplementedError + + def names_lists(self): + return self._names_lists diff --git a/python/mrt/pytorch/rules/fuse_rule.py b/python/mrt/pytorch/rules/fuse_rule.py new file mode 100644 index 00000000..fc3d3b02 --- /dev/null +++ b/python/mrt/pytorch/rules/fuse_rule.py @@ -0,0 +1,26 @@ +import torch.nn as nn + +class FuseRule(object): + """Template class of rules for model fusing.""" + def __init__(self): + """ + State parameters here.
+ """ + pass + + def add_module(self, m): + """ + Args: + m: item from .named_modules() + Returns: + None + """ + assert type(m) is nn.Module + raise NotImplementedError + + def names_lists(self): + """ + Returns: + names_lists: A list of name list + """ + raise diff --git a/python/mrt/pytorch/util.py b/python/mrt/pytorch/util.py new file mode 100644 index 00000000..3b9f2c37 --- /dev/null +++ b/python/mrt/pytorch/util.py @@ -0,0 +1,57 @@ +from .rules import FuseRule +import torch.nn as nn +import torch.quantization as quant + + +def fuse_model(model, rules, inplace=True): + """Fuse the model with a list of rules. + Args: + model: A nn.Module to be fused. + rules: A list of rule object functions as FuseRule. + inplace: Bool. If True, the model object will be modified. + Return: + A new fused model, if inplace is False. + """ + assert isinstance(model, nn.Module) + + for m in model.named_modules(): + for rule in rules: + rule.add_module(m) + modules_to_fuse = list() + for rule in rules: + modules_to_fuse += rule.names_lists() + print(modules_to_fuse) + model = quant.fuse_modules(model, modules_to_fuse, inplace=inplace) + return model + + +def post_training_quant(fused_model, data_loader=None, batches=None, inplace=True): + """Quant a trained model (int8). + Args: + model: A fused nn.Module to be quanted. + data_loader: A data loader provides input data iterations. + batches: The limitation of iteration(batch) number. + inplace: Bool. If True, the model object will be modified. + Return: + A new quanted model, if inplace is False. 
+ """ + if not hasattr(fused_model, "qconfig"): + fused_model.qconfig = quant.get_default_qconfig('fbgemm') + if inplace: + quant.prepare(fused_model, inplace=True) + else: + fused_model = quant.prepare(fused_model, inplace=False) + + if data_loader is not None: + for idx, _in in enumerate(data_loader): + if batches is not None: + print(f'\r{idx} / {batches}', end='') + output = fused_model(_in) + if batches is not None and idx >= batches: + break + if inplace: + quant.convert(fused_model, inplace=True) + else: + fused_model = quant.convert(fused_model, inplace=False) + + return fused_model \ No newline at end of file diff --git a/python/mrt/yamrt/quant/quantize.py b/python/mrt/yamrt/quant/quantize.py index 5b05569c..03be85d7 100644 --- a/python/mrt/yamrt/quant/quantize.py +++ b/python/mrt/yamrt/quant/quantize.py @@ -106,7 +106,7 @@ def quantize_static(model_input, model = load_model(Path(model_input), optimize_model) - calibrator = get_calibrator(model, op_types_to_quantize, calibrate_method=calibrate_method) + calibrator = get_calibrator(model, op_types_to_quantize, method=calibrate_method) calibrator.collect_data(calibration_data_reader) tensors_range = calibrator.compute_range() diff --git a/tests/yamrt/infer.py b/tests/yamrt/infer.py index f34078cc..00a53191 100644 --- a/tests/yamrt/infer.py +++ b/tests/yamrt/infer.py @@ -46,7 +46,7 @@ def softmax(x): for i in range(data.shape[0]): norm_data[i,:,:] = (data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i] norm_data = norm_data.reshape(1, 3, 224, 224).astype('float32') - + result = session.run([output_name],{input_name:norm_data}) res = softmax(np.array(result)).tolist() idx = np.argmax(res) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy