Skip to content

Commit 7db3f35

Browse files
marialyu and tensorflower-gardener
authored and committed
Add int8/int16 support for SQRT op to AEQ
PiperOrigin-RevId: 785924220
1 parent 7f81be4 commit 7db3f35

File tree

10 files changed

+181
-9
lines changed

10 files changed

+181
-9
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
2121

22+
* `tf.lite`
23+
* Adds int8 and int16x8 support for SQRT operator.
24+
2225
### Bug Fixes and Other Changes
2326

2427
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>

tensorflow/compiler/mlir/lite/ir/tfl_ops.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LocalResponseNormalizationOp);
104104
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp);
105105
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundOp);
106106
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinOp);
107-
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp);
108107
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SquareOp);
109108
// go/keep-sorted end
110109

tensorflow/compiler/mlir/lite/ir/tfl_ops.td

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3401,19 +3401,29 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
34013401

34023402
def TFL_SqrtOp: TFL_Op<"sqrt", [
34033403
Pure,
3404-
TF_SameOperandsAndResultTypeResolveRef]> {
3404+
QuantizableResult,
3405+
TFL_SameFirstOperandAndFirstResultElementType,
3406+
SameOperandsAndResultShape]> {
34053407
let summary = "Square root operator";
34063408

34073409
let description = [{
34083410
Computes element-wise Square root of input
34093411
}];
34103412

3411-
let arguments = (ins TFL_FpTensor:$x);
3413+
let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
34123414

3413-
let results = (outs TFL_FpTensor:$y);
3415+
let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
34143416

34153417
let hasFolder = 1;
34163418

3419+
let builders = [
3420+
OpBuilder<(ins "Value":$input),
3421+
[{
3422+
$_state.addOperands({input});
3423+
$_state.addTypes(input.getType());
3424+
}]>
3425+
];
3426+
34173427
let extraClassDeclaration = [{
34183428
// Returns whether the return types are compatible.
34193429
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {

tensorflow/lite/core/kernels/register.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
250250
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
251251
/* min_version = */ 1,
252252
/* max_version = */ 3);
253-
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
253+
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(),
254+
/* min_version = */ 1,
255+
/* max_version = */ 2);
254256
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT(),
255257
/* min_version = */ 1,
256258
/* max_version = */ 3);

tensorflow/lite/kernels/elementwise.cc

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ bool IsRsqrtSupportedType(const TfLiteType type) {
7070
return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
7171
}
7272

73+
bool IsSqrtSupportedType(const TfLiteType type) {
74+
return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
75+
}
76+
7377
bool IsLogSupportedType(const TfLiteType type) {
7478
return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
7579
}
@@ -354,8 +358,59 @@ TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
354358
}
355359
}
356360

361+
template <typename T>
362+
TfLiteStatus SqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node) {
363+
const TfLiteTensor* input;
364+
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
365+
TfLiteTensor* output;
366+
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
367+
368+
const auto* input_params =
369+
reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
370+
const auto* output_params =
371+
reinterpret_cast<TfLiteAffineQuantization*>(output->quantization.params);
372+
const float input_scale = input_params->scale->data[0];
373+
const int input_zp = input_params->zero_point->data[0];
374+
const float output_scale = output_params->scale->data[0];
375+
const int output_zp = output_params->zero_point->data[0];
376+
377+
const int64_t num_elements = NumElements(input);
378+
const T* in_data = GetTensorData<T>(input);
379+
T* out_data = GetTensorData<T>(output);
380+
381+
const int kMin = std::numeric_limits<T>::min();
382+
const int kMax = std::numeric_limits<T>::max();
383+
384+
for (int64_t i = 0; i < num_elements; ++i) {
385+
const float dequantized_input =
386+
input_scale * (static_cast<int>(in_data[i]) - input_zp);
387+
TF_LITE_ENSURE_MSG(context, dequantized_input >= 0.0f,
388+
"Sqrt is only defined for non-negative values");
389+
const float float_output = std::sqrt(dequantized_input);
390+
const int quantized_output =
391+
static_cast<int>(float_output / output_scale) + output_zp;
392+
out_data[i] =
393+
static_cast<T>(std::min(std::max(quantized_output, kMin), kMax));
394+
}
395+
return kTfLiteOk;
396+
}
397+
357398
TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
358-
return EvalNumeric(context, node, std::sqrt);
399+
const TfLiteTensor* input;
400+
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
401+
const TfLiteType type = input->type;
402+
switch (type) {
403+
case kTfLiteFloat32:
404+
return EvalNumeric(context, node, std::sqrt);
405+
case kTfLiteInt8:
406+
return SqrtEvalQuantized<int8_t>(context, node);
407+
case kTfLiteInt16:
408+
return SqrtEvalQuantized<int16_t>(context, node);
409+
default:
410+
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
411+
TfLiteTypeGetName(type));
412+
return kTfLiteError;
413+
}
359414
}
360415

361416
TfLiteStatus RsqrtEvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
@@ -494,10 +549,11 @@ TfLiteRegistration* Register_LOG() {
494549
return &r;
495550
}
496551

497-
GENERIC_PREPARE(PrepareSqrt, elementwise::IsNumericSupportedType, "Sqrt")
552+
GENERIC_PREPARE(PrepareSqrt, elementwise::IsSqrtSupportedType, "Sqrt")
498553

499554
TfLiteRegistration* Register_SQRT() {
500-
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
555+
static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
556+
elementwise::ElementWiseQuantizedFree,
501557
PrepareSqrt, elementwise::SqrtEval};
502558
return &r;
503559
}

tensorflow/lite/kernels/elementwise_test.cc

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,91 @@ TEST(ElementWise, Sqrt) {
344344
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
345345
}
346346

347+
TEST(ElementWise, SqrtInt8) {
348+
const std::vector<float> input_data = {0, 1, 2, 9, 16, 25, 1.44, 0.5};
349+
std::vector<float> expected_output(input_data.size());
350+
for (int i = 0; i < expected_output.size(); i++) {
351+
expected_output[i] = std::sqrt(input_data[i]);
352+
}
353+
const std::vector<int> shape = {1, 8};
354+
float kInputScale = 25.0 / 255.0;
355+
float kOutputScale = 5.0 / 255.0;
356+
int32_t zero_point = -128;
357+
ElementWiseOpQuantizedModel m(
358+
BuiltinOperator_SQRT,
359+
/*input_tensor_data=*/
360+
{/*type=*/TensorType_INT8,
361+
/*shape=*/shape,
362+
/*min=*/0,
363+
/*max=*/25.0,
364+
/*scale=*/kInputScale,
365+
/*zero_point=*/zero_point,
366+
/*per_channel_quantization=*/true,
367+
/*per_channel_quantization_scales=*/{kInputScale},
368+
/*per_channel_quantization_offsets=*/{zero_point}},
369+
/*output_tensor_data=*/
370+
{/*type=*/TensorType_INT8,
371+
/*shape=*/shape,
372+
/*min=*/0,
373+
/*max=*/5.0,
374+
/*scale=*/kOutputScale,
375+
/*zero_point=*/zero_point,
376+
/*per_channel_quantization=*/true,
377+
/*per_channel_quantization_scales=*/{kOutputScale},
378+
/*per_channel_quantization_offsets=*/{zero_point}});
379+
m.QuantizeAndPopulate<int8_t>(m.input(), input_data);
380+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
381+
EXPECT_THAT(m.ExtractDequantVector<int8_t>(m.output()),
382+
ElementsAreArray(ArrayFloatNear(expected_output, kInputScale)));
383+
}
384+
385+
TEST(ElementWise, SqrtNegativeInt8) {
386+
const std::vector<float> input_data = {-1.0};
387+
float kInputScale = 1.0 / 255.0;
388+
float kOutputScale = 1.0 / 255.0;
389+
int32_t zero_point = 0;
390+
ElementWiseOpQuantizedModel m(BuiltinOperator_SQRT,
391+
{TensorType_INT8,
392+
{1, 1},
393+
0,
394+
1.0,
395+
kInputScale,
396+
zero_point,
397+
true,
398+
{kInputScale},
399+
{zero_point}},
400+
{TensorType_INT8,
401+
{1, 1},
402+
0,
403+
1.0,
404+
kOutputScale,
405+
zero_point,
406+
true,
407+
{kOutputScale},
408+
{zero_point}});
409+
m.QuantizeAndPopulate<int8_t>(m.input(), input_data);
410+
EXPECT_EQ(m.Invoke(), kTfLiteError);
411+
}
412+
413+
TEST(ElementWise, SqrtInt16) {
414+
const std::vector<float> input_data = {0, 1, 2, 9, 16, 25, 1.44, 0.5};
415+
std::vector<float> expected_output(input_data.size());
416+
for (int i = 0; i < expected_output.size(); i++) {
417+
expected_output[i] = std::sqrt(input_data[i]);
418+
}
419+
420+
const float kQuantizedTolerance = GetQuantizationStep<int16_t>(-25, 25);
421+
422+
ElementWiseOpQuantizedModel m(BuiltinOperator_SQRT,
423+
{TensorType_INT16, {1, 8}, -25, 25},
424+
{TensorType_INT16, {1, 8}, -5, 5});
425+
m.QuantizeAndPopulate<int16_t>(m.input(), input_data);
426+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
427+
EXPECT_THAT(
428+
m.ExtractDequantVector<int16_t>(m.output()),
429+
ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
430+
}
431+
347432
TEST(ElementWise, Rsqrt) {
348433
ElementWiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1});
349434
m.PopulateTensor<float>(m.input(), {1, 2, 4, 9});

tensorflow/lite/kernels/register_ref.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
445445
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
446446
/* min_version = */ 1,
447447
/* max_version = */ 3);
448-
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
448+
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(),
449+
/* min_version = */ 1,
450+
/* max_version = */ 2);
449451
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT(),
450452
/* min_version = */ 1,
451453
/* max_version = */ 3);

tensorflow/lite/tools/versioning/op_version.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
10451045
case BuiltinOperator_EXP:
10461046
case BuiltinOperator_LOG:
10471047
case BuiltinOperator_REDUCE_PROD:
1048+
case BuiltinOperator_SQRT:
10481049
if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
10491050
op_sig.inputs.at(0).type == kTfLiteInt16) {
10501051
return 2;

tensorflow/lite/tools/versioning/op_version_test.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,4 +1438,17 @@ TEST(OpVersionTest, VersioningDynamicUpdateSliceTest) {
14381438
std::vector<TfLiteType>{kTfLiteInt16, kTfLiteInt16, kTfLiteInt32});
14391439
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
14401440
}
1441+
1442+
TEST(OpVersionTest, VersioningSqrtTest) {
1443+
OpSignature fake_op_sig = {};
1444+
fake_op_sig.op = BuiltinOperator_SQRT;
1445+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
1446+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
1447+
1448+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
1449+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
1450+
1451+
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
1452+
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
1453+
}
14411454
} // namespace tflite

tensorflow/lite/tools/versioning/runtime_version.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
387387
{{BuiltinOperator_LOG, 1}, "1.14.0"},
388388
{{BuiltinOperator_LOG, 2}, "2.15.0"},
389389
{{BuiltinOperator_SQRT, 1}, "1.10.0"},
390+
{{BuiltinOperator_SQRT, 2}, "2.21.0"},
390391
{{BuiltinOperator_RSQRT, 1}, "1.10.0"},
391392
{{BuiltinOperator_RSQRT, 2}, "2.5.0"},
392393
{{BuiltinOperator_RSQRT, 3}, "2.15.0"},

0 commit comments

Comments (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy