Skip to content

Commit 6093970

Browse files
majiddadashitensorflower-gardener
authored andcommitted
Add the skeleton for PropagateQsvPass in the tflite converter.
This pass will infer QSV from surrounding tensors and in cases modifies the graph to comply with the quantization constraints of the tflite kernels. PiperOrigin-RevId: 786055945
1 parent d71af31 commit 6093970

File tree

5 files changed

+104
-1
lines changed

5 files changed

+104
-1
lines changed

tensorflow/compiler/mlir/lite/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,6 +1354,7 @@ cc_library(
13541354
"transforms/prepare_quantize.cc",
13551355
"transforms/prepare_quantize_dynamic_range.cc",
13561356
"transforms/prepare_quantize_helper.cc",
1357+
"transforms/quantization/propagate_qsv_pass.cc",
13571358
"transforms/quantize.cc",
13581359
"transforms/quantize_variables.cc",
13591360
"utils/generated_op_quant_spec_getters.inc",

tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def SameOperandsAndResultsScale : OpInterface<"SameScalesOpInterface"> {
137137
];
138138

139139
let verify = [{
140-
return TFL::VerifySameScales($_op);
140+
return mlir::TFL::VerifySameScales($_op);
141141
}];
142142
}
143143

tensorflow/compiler/mlir/lite/transforms/passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> CreateDefaultQuantizePass();
120120

121121
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerQuantAnnotationsPass();
122122

123+
// Creates an instance of the TFLite PropagateQsv pass which propagates scale
124+
// and zero point (QSV) information through the graph.
125+
std::unique_ptr<OperationPass<ModuleOp>> CreatePropagateQsvPass();
126+
123127
// Overloading of CreateQuantizePass which takes only necessary flags to reduce
124128
// the binary size.
125129
std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizePass(

tensorflow/compiler/mlir/lite/transforms/passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,19 @@ def LowerQuantAnnotationsPass : Pass<"tfl-lower-quant-annotations", "mlir::Modul
354354
];
355355
}
356356

357+
def PropagateQsvPass : Pass<"tfl-propagate-qsv", "mlir::ModuleOp"> {
358+
let summary = "Propagates Quantization Scale/Value (QSV) information through the graph.";
359+
let description = [{
360+
This transformation pass propagates the QSV data across operations in the
361+
TensorFlow Lite dialect.
362+
}];
363+
let constructor = "CreatePropagateQsvPass()";
364+
let dependentDialects = [
365+
"TFL::TensorFlowLiteDialect",
366+
"mlir::quant::QuantDialect"
367+
];
368+
}
369+
357370
def QuantizeVariablesPass : Pass<"tfl-quantize-variables", "mlir::ModuleOp"> {
358371
let summary = "Quantize variables";
359372
let constructor = "CreatePrepareQuantizeVariablesPass()";
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
// This transformation pass propagates QSV information through the model.
17+
18+
#include <memory>
19+
#include <utility>
20+
21+
#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project
22+
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
23+
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
24+
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
25+
#include "mlir/IR/Diagnostics.h" // from @llvm-project
26+
#include "mlir/IR/PatternMatch.h" // from @llvm-project
27+
#include "mlir/Pass/Pass.h" // from @llvm-project
28+
#include "mlir/Support/LLVM.h" // from @llvm-project
29+
#include "mlir/Support/LogicalResult.h" // from @llvm-project
30+
#include "mlir/Support/TypeID.h" // from @llvm-project
31+
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
32+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
33+
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep
34+
35+
namespace mlir {
36+
namespace TFL {
37+
namespace {
38+
39+
#define GEN_PASS_DEF_PROPAGATEQSVPASS
40+
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
41+
42+
//===----------------------------------------------------------------------===//
43+
// Pass Definition
44+
//===----------------------------------------------------------------------===//
45+
46+
struct PropagateQsvPass : public impl::PropagateQsvPassBase<PropagateQsvPass> {
47+
public:
48+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PropagateQsvPass)
49+
50+
void runOnOperation() override;
51+
};
52+
53+
//===----------------------------------------------------------------------===//
54+
// Pass Implementation
55+
//===----------------------------------------------------------------------===//
56+
57+
void PropagateQsvPass::runOnOperation() {
58+
MLIRContext* ctx = &getContext();
59+
mlir::ModuleOp module = getOperation();
60+
61+
RewritePatternSet patterns(ctx);
62+
63+
// Configure the greedy pattern rewrite driver.
64+
GreedyRewriteConfig greedy_config;
65+
greedy_config.enableFolding(false);
66+
67+
// Apply the patterns.
68+
if (failed(
69+
applyPatternsGreedily(module, std::move(patterns), greedy_config))) {
70+
signalPassFailure();
71+
}
72+
}
73+
74+
} // namespace
75+
76+
//===----------------------------------------------------------------------===//
77+
// Pass Creation Function
78+
//===----------------------------------------------------------------------===//
79+
80+
std::unique_ptr<OperationPass<mlir::ModuleOp>> CreatePropagateQsvPass() {
81+
return std::make_unique<PropagateQsvPass>();
82+
}
83+
84+
} // namespace TFL
85+
} // namespace mlir

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy