
sdpython/onnx-diagnostic



onnx-diagnostic: investigate onnx models

License: MIT

The main feature is a set of patches that help export PyTorch models to ONNX. They are mostly designed for LLMs that use dynamic caches. Patches are enabled with a context manager:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

# model, args, kwargs and dynamic_shapes are defined by the user
with torch_export_patches(patch_transformers=True) as f:
    ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
    # ...
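Because the patches are enabled through a `with` block, it is reasonable to assume they are active only inside that block. The snippet below is a minimal sketch of that general pattern: it temporarily replaces a method and restores it on exit. It is not the library's implementation. `temporary_patch`, `Cache` and `patched_describe` are hypothetical names.

```python
from contextlib import contextmanager


@contextmanager
def temporary_patch(owner, name, replacement):
    """Replaces owner.name with replacement and restores the original on exit."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        # The original is restored even if the code inside the block raises.
        setattr(owner, name, original)


class Cache:  # hypothetical class standing in for a transformers cache
    def describe(self):
        return "original"


def patched_describe(self):
    return "patched"


with temporary_patch(Cache, "describe", patched_describe):
    inside = Cache().describe()
outside = Cache().describe()
```

Restoring in `finally` keeps a failed export from leaving the process in a patched state.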

Dynamic shapes are difficult to guess for caches. The function all_dynamic_shape_from_inputs returns a structure that marks every dimension as dynamic. You then need to remove the dimensions that are not dynamic in your model.

from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

# cache is the input structure (for example, a model cache) to describe
dynamic_shapes = all_dynamic_shape_from_inputs(cache)
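The sketch below shows the idea in pure Python. It assumes that tensors are described only by their shapes. Every dimension first receives a name, and the fixed dimensions are then removed by hand. `all_dynamic_sketch` and `FakeTensor` are hypothetical. The real helper works on torch tensors and caches, and its naming scheme may differ.

```python
from collections import namedtuple

# Hypothetical stand-in for a tensor: only its shape matters here.
FakeTensor = namedtuple("FakeTensor", ["shape"])


def all_dynamic_sketch(obj, prefix="d"):
    """Marks every dimension of every tensor in a nested structure as dynamic.

    Dimension names are plain strings derived from the position in the structure.
    """
    if isinstance(obj, FakeTensor):
        return {i: f"{prefix}_{i}" for i in range(len(obj.shape))}
    if isinstance(obj, (list, tuple)):
        return type(obj)(all_dynamic_sketch(o, f"{prefix}_{k}") for k, o in enumerate(obj))
    if isinstance(obj, dict):
        return {k: all_dynamic_sketch(v, f"{prefix}_{k}") for k, v in obj.items()}
    return None  # non-tensor leaves get no dynamic dimension


# Key and value tensors, assuming the usual (batch, heads, sequence, head_dim) layout.
cache = [FakeTensor((2, 8, 30, 64)), FakeTensor((2, 8, 30, 64))]
ds = all_dynamic_sketch(cache)

# The number of heads and head_dim are fixed by the model, so they are removed again.
for dims in ds:
    dims.pop(1)
    dims.pop(3)
```

After the cleanup, only the batch and sequence axes remain dynamic.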

It also implements tools to investigate and validate exported models (ExportedProgram, ONNXProgram, ...). See the documentation of onnx-diagnostic and torch_export_patches.

Getting started

git clone https://github.com/sdpython/onnx-diagnostic.git
cd onnx-diagnostic
pip install -e .

or

pip install onnx-diagnostic

Enlightening Examples

Where to start when exporting a model

Torch Export

Investigate ONNX models

Snapshot of useful tools


torch_export_rewrite

import torch
from onnx_diagnostic.torch_export_patches import torch_export_rewrite

# rewrite lists the methods rewritten inside the with block (here Model.forward)
with torch_export_rewrite(rewrite=[Model.forward]) as f:
    ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
    # ...
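torch.export cannot export a Python `if` statement whose condition depends on tensor values. Assuming torch_export_rewrite targets this kind of control flow, the sketch below shows the general shape of such a rewrite in pure Python. Each branch becomes a function with the same inputs, and a `cond` helper picks one of them. This mirrors the `torch.cond(pred, true_fn, false_fn, operands)` calling convention. `cond`, `forward_original` and `forward_rewritten` are hypothetical names, and lists of floats stand in for tensors.

```python
def cond(pred, true_fn, false_fn, operands):
    """Pure-Python stand-in with the same calling convention as torch.cond."""
    return true_fn(*operands) if pred else false_fn(*operands)


def forward_original(x, y):
    # Data-dependent branch: the condition depends on the values of x.
    if sum(x) > 0:
        return [a + b for a, b in zip(x, y)]
    return [a - b for a, b in zip(x, y)]


def forward_rewritten(x, y):
    # Each branch becomes a function taking the same operands.
    def branch_true(x, y):
        return [a + b for a, b in zip(x, y)]

    def branch_false(x, y):
        return [a - b for a, b in zip(x, y)]

    return cond(sum(x) > 0, branch_true, branch_false, (x, y))
```

Both versions compute the same result. The rewritten one, however, exposes both branches as functions that an exporter can capture.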

string_type

import torch
from onnx_diagnostic.helpers import string_type

inputs = (
    torch.rand((3, 4), dtype=torch.float16),
    [torch.rand((5, 6), dtype=torch.float16), torch.rand((5, 6, 7), dtype=torch.float16)],
)

# with shapes
print(string_type(inputs, with_shape=True))
>>> (T10s3x4,#2[T10s5x6,T10s5x6x7])
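The printed output can be read as follows: `T` stands for a tensor, followed by its ONNX element type (10 is FLOAT16), then `s` and the dimensions joined by `x`. A list of n elements is written `#n[...]`, and a tuple is written `(...)`. This decoding is inferred from the example above, not taken from the library. The sketch below reproduces it on hypothetical `FakeTensor(itype, shape)` objects. `string_type_sketch` is not part of onnx-diagnostic.

```python
from collections import namedtuple

# Hypothetical stand-in for a tensor: ONNX element type and shape.
FakeTensor = namedtuple("FakeTensor", ["itype", "shape"])


def string_type_sketch(obj):
    """Encodes a nested structure the way the string_type output above looks."""
    if isinstance(obj, FakeTensor):  # checked first: a namedtuple is also a tuple
        return f"T{obj.itype}s" + "x".join(str(d) for d in obj.shape)
    if isinstance(obj, tuple):
        return "(" + ",".join(string_type_sketch(o) for o in obj) + ")"
    if isinstance(obj, list):
        return f"#{len(obj)}[" + ",".join(string_type_sketch(o) for o in obj) + "]"
    return type(obj).__name__


inputs = (
    FakeTensor(10, (3, 4)),
    [FakeTensor(10, (5, 6)), FakeTensor(10, (5, 6, 7))],
)
print(string_type_sketch(inputs))
```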

onnx_dtype_name

import onnx
from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name

itype = onnx.TensorProto.BFLOAT16
print(onnx_dtype_name(itype))
print(onnx_dtype_name(7))
>>> BFLOAT16
>>> INT64
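The integers are the element types defined by the ONNX specification in `TensorProto.DataType`. 7 is INT64 and 16 is BFLOAT16. 10 is FLOAT16, which is presumably why float16 tensors print as `T10` in the string_type output. When onnx is installed, `onnx.TensorProto.DataType.Name(7)` also returns `"INT64"`. The dependency-free sketch below only covers the first seventeen values, and `onnx_dtype_name_sketch` is a hypothetical name.

```python
# Element types from the ONNX specification (TensorProto.DataType), values 0 to 16.
ONNX_DTYPE_NAMES = {
    0: "UNDEFINED",
    1: "FLOAT",
    2: "UINT8",
    3: "INT8",
    4: "UINT16",
    5: "INT16",
    6: "INT32",
    7: "INT64",
    8: "STRING",
    9: "BOOL",
    10: "FLOAT16",
    11: "DOUBLE",
    12: "UINT32",
    13: "UINT64",
    14: "COMPLEX64",
    15: "COMPLEX128",
    16: "BFLOAT16",
}


def onnx_dtype_name_sketch(itype):
    """Returns the name of an ONNX element type, limited to the subset above."""
    try:
        return ONNX_DTYPE_NAMES[itype]
    except KeyError:
        raise ValueError(f"unknown or unsupported element type {itype}") from None
```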

max_diff

import torch
from onnx_diagnostic.helpers import max_diff

print(
    max_diff(
        (torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
        (torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
    )
)
>>> {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 4.0, "dnan": 0.0}
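The keys suggest the following quantities:

* abs: the largest absolute difference
* rel: the largest relative difference
* sum: the sum of absolute differences
* n: the number of compared values
* dnan: a NaN-related count

The sketch below computes these quantities under that reading, in pure Python on nested lists and tuples. The exact definitions used by max_diff, in particular for rel and dnan, are assumptions here, and `max_diff_sketch` is hypothetical.

```python
import math


def flatten(obj):
    """Yields every number of a nested structure of lists and tuples."""
    if isinstance(obj, (list, tuple)):
        for item in obj:
            yield from flatten(item)
    else:
        yield float(obj)


def max_diff_sketch(expected, got, eps=1e-10):
    """Compares two nested structures element by element (hypothetical helper)."""
    e, g = list(flatten(expected)), list(flatten(got))
    if len(e) != len(g):
        raise ValueError(f"size mismatch: {len(e)} != {len(g)}")
    abs_err = rel_err = total = dnan = 0.0
    for a, b in zip(e, g):
        if math.isnan(a) or math.isnan(b):
            # Counts positions where only one side is NaN; NaN pairs are skipped.
            dnan += float(math.isnan(a) != math.isnan(b))
            continue
        d = abs(a - b)
        abs_err = max(abs_err, d)
        rel_err = max(rel_err, d / max(abs(a), eps))  # relative to the expected value
        total += d
    return {"abs": abs_err, "rel": rel_err, "sum": total, "n": float(len(e)), "dnan": dnan}
```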

guess_dynamic_shapes

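import torch
from onnx_diagnostic.export import ModelInputs

# model is a torch.nn.Module taking two tensors as positional inputs
inputs = [
    (torch.randn((5, 6)), torch.randn((1, 6))),
    (torch.randn((7, 8)), torch.randn((1, 8))),
]
ds = ModelInputs(model, inputs).guess_dynamic_shapes(auto="dim")
print(ds)
>>> (({0: 'dim_0I0', 1: 'dim_0I1'}, {1: 'dim_1I1'}), {})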
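The printed result is consistent with a simple rule: a dimension is dynamic when its size differs between the input sets. For the first input, the sizes are 5/7 and 6/8, so both axes are dynamic. For the second input, the leading 1 never changes, so only axis 1 is dynamic. The sketch below applies that rule to shapes only and assumes every input keeps the same rank across calls. `guess_dynamic_shapes_sketch` is hypothetical and only handles positional inputs. It reuses the `dim_{input}I{axis}` naming visible in the output.

```python
def guess_dynamic_shapes_sketch(shape_sets):
    """shape_sets holds one entry per call, each a tuple of shapes (one per input)."""
    first = shape_sets[0]
    result = []
    for i, shape in enumerate(first):
        dims = {}
        for axis in range(len(shape)):
            # A dimension is dynamic if its size is not the same in every call.
            sizes = {call[i][axis] for call in shape_sets}
            if len(sizes) > 1:
                dims[axis] = f"dim_{i}I{axis}"
        result.append(dims)
    return tuple(result)


shape_sets = [((5, 6), (1, 6)), ((7, 8), (1, 8))]
print(guess_dynamic_shapes_sketch(shape_sets))
```

With a single input set, the rule finds nothing dynamic. This is why several input sets with different sizes are needed.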
