Skip to content

Commit b3527c4

Browse files
seantalts authored and
tensorflower-gardener committed
[XLA:CPU] Lower to xla.rsqrt for all rsqrt ops.
Plumb a target_machine through to SupportedVectorTypes and CreateDefinition for rsqrt. Fall back to 1/sqrt(x) if no avx, or f64 without avx512f. PiperOrigin-RevId: 785605830
1 parent 5c43408 commit b3527c4

File tree

19 files changed

+415
-79
lines changed

19 files changed

+415
-79
lines changed

third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ llvm::Error IrCompiler::RunIrPasses(llvm::Module& module,
328328
std::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
329329
target_library_info_impl->addVectorizableFunctions(
330330
PolynomialApproximationsVectorization());
331-
codegen::MathFunctionLib math_lib;
331+
codegen::MathFunctionLib math_lib(target_machine);
332332
target_library_info_impl->addVectorizableFunctions(math_lib.Vectorizations());
333333

334334
fam.registerPass(

third_party/xla/xla/codegen/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ cc_library(
192192
"//xla/codegen/math:intrinsic",
193193
"//xla/codegen/math:ldexp",
194194
"//xla/codegen/math:log1p",
195+
"//xla/codegen/math:rsqrt",
195196
"//xla/codegen/math:string_interner",
196197
"//xla/codegen/math:vec_name_mangler",
197198
"//xla/service/llvm_ir:llvm_util",

third_party/xla/xla/codegen/emitters/transforms/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ cc_library(
7979
"//xla/codegen/math:fptrunc",
8080
"//xla/codegen/math:intrinsic",
8181
"//xla/codegen/math:log1p",
82+
"//xla/codegen/math:rsqrt",
8283
"//xla/hlo/analysis:indexing_analysis",
8384
"//xla/mlir_hlo",
8485
"//xla/mlir_hlo:map_mhlo_to_scalar_op",

third_party/xla/xla/codegen/emitters/transforms/lower_xla_math_lib.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ limitations under the License.
3737
#include "xla/codegen/math/fptrunc.h"
3838
#include "xla/codegen/math/intrinsic.h"
3939
#include "xla/codegen/math/log1p.h"
40+
#include "xla/codegen/math/rsqrt.h"
4041

4142
namespace xla {
4243
namespace emitters {
@@ -216,6 +217,32 @@ class LowerTruncF32BF16FPattern
216217
mlir::ModuleOp& module_op_;
217218
};
218219

220+
class RsqrtPattern : public mlir::OpRewritePattern<mlir::math::RsqrtOp> {
221+
public:
222+
RsqrtPattern(mlir::MLIRContext* context, mlir::ModuleOp& module_op)
223+
: OpRewritePattern(context), module_op_(module_op) {}
224+
225+
mlir::LogicalResult matchAndRewrite(
226+
mlir::math::RsqrtOp op, mlir::PatternRewriter& rewriter) const override {
227+
// Don't change if not f32 or f64:
228+
auto src_type = op.getOperand().getType();
229+
if (!mlir::isa<mlir::Float32Type>(src_type) &&
230+
!mlir::isa<mlir::Float64Type>(src_type)) {
231+
return rewriter.notifyMatchFailure(op, "Not f32 or f64");
232+
}
233+
234+
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
235+
auto rsqrt_decl = codegen::intrinsics::Rsqrt::GetOrInsertDeclaration(
236+
rewriter, module_op_, Type::TypeFromIrType(op.getOperand().getType()));
237+
auto call_op = b.create<mlir::func::CallOp>(rsqrt_decl, op.getOperand());
238+
rewriter.replaceOp(op, call_op->getResults());
239+
return mlir::success();
240+
}
241+
242+
private:
243+
mlir::ModuleOp& module_op_;
244+
};
245+
219246
class LowerXlaMathLibPass
220247
: public impl::LowerXlaMathLibPassBase<LowerXlaMathLibPass> {
221248
public:
@@ -226,7 +253,8 @@ class LowerXlaMathLibPass
226253
mlir::ModuleOp module_op = getOperation();
227254
mlir::RewritePatternSet patterns(&getContext());
228255
patterns.add<LowerExpOpPattern, LowerLog1pPattern, LowerErfPattern,
229-
LowerTruncF32BF16FPattern>(&getContext(), module_op);
256+
LowerTruncF32BF16FPattern, RsqrtPattern>(&getContext(),
257+
module_op);
230258

231259
if (mlir::failed(
232260
mlir::applyPatternsGreedily(module_op, std::move(patterns)))) {

third_party/xla/xla/codegen/emitters/transforms/tests/lower_xla_math_lib.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,32 @@ module {
7878
// CHECK-NOT: math.erf
7979
// CHECK: %[[ERF_CALL:.*]] = call @erf
8080
// CHECK: return %[[ERF_CALL]]
81+
82+
83+
// -----
84+
85+
module {
86+
func.func @rsqrt(%arg0: f32) -> f32 {
87+
%ret = math.rsqrt %arg0 : f32
88+
return %ret : f32
89+
}
90+
}
91+
92+
// CHECK-LABEL: @local_xla.rsqrt.f32
93+
// CHECK-NOT: math.rsqrt
94+
// CHECK: %[[RSQRT_CALL:.*]] = call @local_xla.rsqrt.f32
95+
// CHECK: return %[[RSQRT_CALL]]
96+
97+
// -----
98+
99+
module {
100+
func.func @rsqrt(%arg0: f64) -> f64 {
101+
%ret = math.rsqrt %arg0 : f64
102+
return %ret : f64
103+
}
104+
}
105+
106+
// CHECK-LABEL: @local_xla.rsqrt.f64
107+
// CHECK-NOT: math.rsqrt
108+
// CHECK: %[[RSQRT_CALL:.*]] = call @local_xla.rsqrt.f64
109+
// CHECK: return %[[RSQRT_CALL]]

third_party/xla/xla/codegen/math/BUILD

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ cc_library(
3535
"@com_google_absl//absl/status:statusor",
3636
"@com_google_absl//absl/strings",
3737
"@llvm-project//llvm:Support",
38+
"@llvm-project//llvm:Target",
3839
"@llvm-project//llvm:ir_headers",
3940
"@llvm-project//mlir:FuncDialect",
4041
"@llvm-project//mlir:IR",
@@ -125,14 +126,18 @@ cc_library(
125126
"//xla/service/cpu:orc_jit_memory_mapper",
126127
"//xla/service/llvm_ir:llvm_util",
127128
"//xla/tsl/util:safe_reinterpret_cast",
129+
"@com_google_absl//absl/base",
128130
"@com_google_absl//absl/base:dynamic_annotations",
129131
"@com_google_absl//absl/container:flat_hash_map",
130132
"@com_google_absl//absl/log",
131133
"@com_google_absl//absl/log:check",
132134
"@llvm-project//llvm:ExecutionEngine",
133135
"@llvm-project//llvm:JITLink",
136+
"@llvm-project//llvm:MC",
134137
"@llvm-project//llvm:OrcJIT", # buildcleaner: keep
135138
"@llvm-project//llvm:Support",
139+
"@llvm-project//llvm:Target",
140+
"@llvm-project//llvm:TargetParser",
136141
"@llvm-project//llvm:ir_headers",
137142
] + if_llvm_aarch64_available([
138143
"@llvm-project//llvm:AArch64AsmParser", # fixdeps: keep
@@ -306,14 +311,13 @@ cc_library(
306311
deps = [
307312
":intrinsic",
308313
"//xla:xla_data_proto_cc",
309-
"//xla/service/llvm_ir:llvm_util",
310314
"@com_google_absl//absl/log",
311315
"@com_google_absl//absl/log:check",
312-
"@com_google_absl//absl/status",
313316
"@com_google_absl//absl/status:statusor",
314317
"@com_google_absl//absl/strings",
315318
"@llvm-project//llvm:Core", # buildcleaner: keep
316319
"@llvm-project//llvm:Support",
320+
"@llvm-project//llvm:Target",
317321
"@llvm-project//llvm:ir_headers",
318322
],
319323
)
@@ -358,7 +362,9 @@ xla_cc_test(
358362
":test_matchers",
359363
"//xla:shape_util",
360364
"//xla:xla_data_proto_cc",
365+
"@com_google_absl//absl/log",
361366
"@com_google_googletest//:gtest_main",
367+
"@eigen_archive//:eigen3",
362368
"@llvm-project//llvm:JITLink",
363369
"@llvm-project//llvm:Support",
364370
"@llvm-project//llvm:TargetParser",
@@ -379,6 +385,7 @@ xla_cc_test(
379385
"//xla:xla_data_proto_cc",
380386
"//xla/tsl/platform:test_benchmark",
381387
"//xla/tsl/platform:test_main",
388+
"@llvm-project//llvm:Target",
382389
"@llvm-project//llvm:ir_headers",
383390
],
384391
)

third_party/xla/xla/codegen/math/intrinsic.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License.
2929
#include "absl/strings/str_join.h"
3030
#include "llvm/IR/Function.h"
3131
#include "llvm/IR/Module.h"
32+
#include "llvm/Target/TargetMachine.h"
3233
#include "mlir/Dialect/Func/IR/FuncOps.h"
3334
#include "mlir/IR/Builders.h"
3435
#include "mlir/IR/BuiltinOps.h"

third_party/xla/xla/codegen/math/rsqrt.cc

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@ limitations under the License.
1616
#include "xla/codegen/math/rsqrt.h"
1717

1818
#include <cstddef>
19+
#include <vector>
1920

2021
#include "absl/log/check.h"
2122
#include "absl/log/log.h"
22-
#include "absl/status/status.h"
2323
#include "absl/status/statusor.h"
24-
#include "absl/strings/str_cat.h"
2524
#include "llvm/ADT/APInt.h"
2625
#include "llvm/IR/Argument.h"
2726
#include "llvm/IR/BasicBlock.h"
@@ -38,8 +37,8 @@ limitations under the License.
3837
#include "llvm/IR/Value.h"
3938
#include "llvm/Support/Casting.h"
4039
#include "llvm/Support/TypeSize.h"
40+
#include "llvm/Target/TargetMachine.h"
4141
#include "xla/codegen/math/intrinsic.h"
42-
#include "xla/service/llvm_ir/llvm_util.h"
4342
#include "xla/xla_data.pb.h"
4443

4544
namespace xla::codegen::intrinsics {
@@ -68,19 +67,19 @@ static llvm::Value* NewtonRaphsonRsqrtIteration(llvm::IRBuilder<>& builder,
6867

6968
struct RsqrtIntrinsic {
7069
llvm::Intrinsic::ID id;
71-
int mask_bits; // Some avx512 calls require masks.
72-
bool needs_insert_element;
70+
int mask_bits; // Some avx512 calls require masks.
71+
int needs_insert_element_size; // Some avx512 calls require padding.
7372

7473
static RsqrtIntrinsic ForF32(size_t num_elements) {
7574
switch (num_elements) {
7675
case 1:
77-
return {llvm::Intrinsic::x86_sse_rsqrt_ss, 0, true};
76+
return {llvm::Intrinsic::x86_sse_rsqrt_ss, 0, 4};
7877
case 4:
79-
return {llvm::Intrinsic::x86_sse_rsqrt_ps, 0, false};
78+
return {llvm::Intrinsic::x86_sse_rsqrt_ps, 0, 0};
8079
case 8:
81-
return {llvm::Intrinsic::x86_avx_rsqrt_ps_256, 0, false};
80+
return {llvm::Intrinsic::x86_avx_rsqrt_ps_256, 0, 0};
8281
case 16:
83-
return {llvm::Intrinsic::x86_avx512_rsqrt14_ps_512, 16, false};
82+
return {llvm::Intrinsic::x86_avx512_rsqrt14_ps_512, 16, 0};
8483
default:
8584
LOG(FATAL) << "Unsupported vector width for rsqrt: " << num_elements;
8685
}
@@ -89,12 +88,18 @@ struct RsqrtIntrinsic {
8988
static RsqrtIntrinsic ForF64(size_t num_elements) {
9089
// We assume AVX512 is available for F64.
9190
switch (num_elements) {
91+
case 1:
92+
// Assuming AVX512 is available.
93+
// We don't use x86_avx512_rsqrt14_sd because it also requires padding
94+
// into <2 x double> vectors and it takes an additional source vector
95+
// for the upper bits of the result.
96+
return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_128, 8, 2};
9297
case 2:
93-
return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_128, 8, false};
98+
return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_128, 8, 0};
9499
case 4:
95-
return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_256, 8, false};
100+
return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_256, 8, 0};
96101
case 8:
97-
return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_512, 8, false};
102+
return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_512, 8, 0};
98103
default:
99104
LOG(FATAL) << "Unsupported vector width for rsqrt: " << num_elements;
100105
}
@@ -106,36 +111,39 @@ struct RsqrtIntrinsic {
106111
llvm::Intrinsic::getOrInsertDeclaration(module, id);
107112

108113
llvm::Value* y_approx;
109-
if (needs_insert_element) {
114+
std::vector<llvm::Value*> args = {x};
115+
if (needs_insert_element_size > 0) {
116+
// Pad into a vector of size `needs_insert_element_size`.
110117
llvm::Type* sse_vec_type = llvm::VectorType::get(
111-
x->getType()->getScalarType(), llvm::ElementCount::getFixed(4));
118+
x->getType()->getScalarType(),
119+
llvm::ElementCount::getFixed(needs_insert_element_size));
112120
llvm::Value* vec_x = llvm::UndefValue::get(sse_vec_type);
113121
vec_x = builder.CreateInsertElement(vec_x, x, builder.getInt32(0));
114-
llvm::Value* approx_vec =
115-
builder.CreateCall(rsqrt_intrinsic, {vec_x}, "y_approx.vec");
116-
y_approx = builder.CreateExtractElement(approx_vec, builder.getInt32(0),
117-
"y_approx");
118-
} else if (mask_bits > 0) {
119-
llvm::Value* dest = llvm::ConstantFP::get(x->getType(), 0.0);
122+
args[0] = vec_x;
123+
}
124+
if (mask_bits > 0) {
125+
llvm::Value* src = llvm::ConstantFP::get(args[0]->getType(), 0.0);
120126
llvm::Value* mask = llvm::ConstantInt::get(
121127
builder.getContext(), llvm::APInt(mask_bits, -1, true));
122-
y_approx =
123-
builder.CreateCall(rsqrt_intrinsic, {x, dest, mask}, "y_approx");
124-
125-
} else {
126-
y_approx = builder.CreateCall(rsqrt_intrinsic, {x}, "y_approx");
128+
args.push_back(src);
129+
args.push_back(mask);
130+
}
131+
y_approx = builder.CreateCall(rsqrt_intrinsic, args, "y_approx");
132+
if (needs_insert_element_size > 0) {
133+
// Extract the result from the padded vector.
134+
y_approx = builder.CreateExtractElement(y_approx, builder.getInt32(0),
135+
"y_approx");
127136
}
128137
return y_approx;
129138
}
130139
};
131140

132-
absl::StatusOr<llvm::Function*> Rsqrt::CreateDefinition(llvm::Module* module,
133-
Type type) {
141+
absl::StatusOr<llvm::Function*> Rsqrt::CreateDefinition(
142+
llvm::Module* module, llvm::TargetMachine* target_machine, Type type) {
143+
CHECK(type.element_type() == F64 || type.element_type() == F32)
144+
<< type.name();
134145
llvm::Type* input_type = Type::TypeToIrType(type, module->getContext());
135146
CHECK(input_type != nullptr);
136-
CHECK(input_type->isFloatingPointTy() || input_type->isVectorTy());
137-
CHECK(input_type->getScalarType()->isFloatTy() ||
138-
input_type->getScalarType()->isDoubleTy());
139147

140148
llvm::LLVMContext& context = module->getContext();
141149
llvm::IRBuilder<> builder(context);
@@ -151,30 +159,44 @@ absl::StatusOr<llvm::Function*> Rsqrt::CreateDefinition(llvm::Module* module,
151159
module->getOrInsertFunction(Rsqrt::Name(type), function_type)
152160
.getCallee());
153161

154-
llvm::Argument* input_x_arg = func->getArg(0);
155-
input_x_arg->setName("x");
162+
llvm::Argument* x = func->getArg(0);
163+
x->setName("x");
156164
llvm::BasicBlock* entry_bb = llvm::BasicBlock::Create(context, "entry", func);
157-
llvm::Value* x = input_x_arg;
158165
builder.SetInsertPoint(entry_bb);
166+
167+
if ((type.element_type() == F64 &&
168+
!target_machine->getTargetFeatureString().contains("+avx512f")) ||
169+
!target_machine->getTargetFeatureString().contains("+avx")) {
170+
LOG_EVERY_N(INFO, 1000)
171+
<< "avx not available, falling back to 1 / sqrt(x) for " << type.name();
172+
// The approximation-based algorithm below cannot be used for F64 without
173+
// AVX512F, or for any type on targets without AVX (including non-x86).
174+
llvm::Value* one = llvm::ConstantFP::get(input_type, 1.0);
175+
llvm::Value* sqrt_x =
176+
builder.CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, x);
177+
llvm::Value* inv_sqrt_x = builder.CreateFDiv(one, sqrt_x, "inv_sqrt_x");
178+
builder.CreateRet(inv_sqrt_x);
179+
return func;
180+
}
181+
159182
RsqrtIntrinsic rsqrt_intrinsic = input_type->getScalarType()->isFloatTy()
160183
? RsqrtIntrinsic::ForF32(num_elements)
161184
: RsqrtIntrinsic::ForF64(num_elements);
162185
llvm::Value* y_approx = rsqrt_intrinsic.CreateCall(builder, x);
163186

164187
llvm::Value* refined_result =
165-
NewtonRaphsonRsqrtIteration(builder, input_x_arg, y_approx, input_type);
188+
NewtonRaphsonRsqrtIteration(builder, x, y_approx, input_type);
166189
if (input_type->getScalarType()->isDoubleTy()) {
167190
// Do an additional refinement step for F64.
168-
refined_result = NewtonRaphsonRsqrtIteration(builder, input_x_arg,
169-
refined_result, input_type);
191+
refined_result =
192+
NewtonRaphsonRsqrtIteration(builder, x, refined_result, input_type);
170193
}
171194

172-
// Create a mask for special cases (denormals and infinities) to fall back
173-
// to the intrinsic's result, matching Eigen's behavior.
174195
const llvm::fltSemantics& semantics =
175196
input_type->getScalarType()->getFltSemantics();
176-
llvm::Constant* flt_min = llvm::ConstantFP::get(
177-
input_type, llvm::APFloat::getSmallestNormalized(semantics));
197+
llvm::APFloat flt_min_val = llvm::APFloat::getSmallestNormalized(semantics);
198+
llvm::Constant* flt_min = llvm::ConstantFP::get(input_type, flt_min_val);
199+
178200
llvm::Constant* inf =
179201
llvm::ConstantFP::get(input_type, llvm::APFloat::getInf(semantics));
180202

@@ -183,11 +205,24 @@ absl::StatusOr<llvm::Function*> Rsqrt::CreateDefinition(llvm::Module* module,
183205
llvm::Value* use_hw_approx_mask =
184206
builder.CreateOr(lt_min_mask, inf_mask, "use_hw_approx_mask");
185207

208+
if (type.element_type() == F64) {
209+
// If the input is very large, the result should be 0.
210+
// This is effectively calculating 1.0 / FLT_MIN, which is the threshold
211+
// where rsqrt(x*x) would be close to the smallest representable number.
212+
llvm::APFloat rsqrt_is_zero_val = llvm::APFloat(semantics, 1);
213+
rsqrt_is_zero_val.divide(flt_min_val, llvm::APFloat::rmTowardZero);
214+
llvm::Constant* rsqrt_is_zero =
215+
llvm::ConstantFP::get(input_type, rsqrt_is_zero_val);
216+
llvm::Value* rsqrt_is_zero_mask =
217+
builder.CreateFCmpOGE(x, rsqrt_is_zero, "rsqrt_is_zero_mask");
218+
use_hw_approx_mask =
219+
builder.CreateOr(use_hw_approx_mask, rsqrt_is_zero_mask);
220+
}
221+
186222
// If input is normal and finite, use the refined result. Otherwise, use the
187223
// raw hardware approximation.
188224
llvm::Value* result = builder.CreateSelect(use_hw_approx_mask, y_approx,
189225
refined_result, "result");
190-
191226
builder.CreateRet(result);
192227
return func;
193228
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy