@@ -16,12 +16,11 @@ limitations under the License.
16
16
#include " xla/codegen/math/rsqrt.h"
17
17
18
18
#include < cstddef>
19
+ #include < vector>
19
20
20
21
#include " absl/log/check.h"
21
22
#include " absl/log/log.h"
22
- #include " absl/status/status.h"
23
23
#include " absl/status/statusor.h"
24
- #include " absl/strings/str_cat.h"
25
24
#include " llvm/ADT/APInt.h"
26
25
#include " llvm/IR/Argument.h"
27
26
#include " llvm/IR/BasicBlock.h"
@@ -38,8 +37,8 @@ limitations under the License.
38
37
#include " llvm/IR/Value.h"
39
38
#include " llvm/Support/Casting.h"
40
39
#include " llvm/Support/TypeSize.h"
40
+ #include " llvm/Target/TargetMachine.h"
41
41
#include " xla/codegen/math/intrinsic.h"
42
- #include " xla/service/llvm_ir/llvm_util.h"
43
42
#include " xla/xla_data.pb.h"
44
43
45
44
namespace xla ::codegen::intrinsics {
@@ -68,19 +67,19 @@ static llvm::Value* NewtonRaphsonRsqrtIteration(llvm::IRBuilder<>& builder,
68
67
69
68
struct RsqrtIntrinsic {
70
69
llvm::Intrinsic::ID id;
71
- int mask_bits; // Some avx512 calls require masks.
72
- bool needs_insert_element;
70
+ int mask_bits; // Some avx512 calls require masks.
71
+ int needs_insert_element_size; // Some avx512 calls require padding.
73
72
74
73
static RsqrtIntrinsic ForF32 (size_t num_elements) {
75
74
switch (num_elements) {
76
75
case 1 :
77
- return {llvm::Intrinsic::x86_sse_rsqrt_ss, 0 , true };
76
+ return {llvm::Intrinsic::x86_sse_rsqrt_ss, 0 , 4 };
78
77
case 4 :
79
- return {llvm::Intrinsic::x86_sse_rsqrt_ps, 0 , false };
78
+ return {llvm::Intrinsic::x86_sse_rsqrt_ps, 0 , 0 };
80
79
case 8 :
81
- return {llvm::Intrinsic::x86_avx_rsqrt_ps_256, 0 , false };
80
+ return {llvm::Intrinsic::x86_avx_rsqrt_ps_256, 0 , 0 };
82
81
case 16 :
83
- return {llvm::Intrinsic::x86_avx512_rsqrt14_ps_512, 16 , false };
82
+ return {llvm::Intrinsic::x86_avx512_rsqrt14_ps_512, 16 , 0 };
84
83
default :
85
84
LOG (FATAL) << " Unsupported vector width for rsqrt: " << num_elements;
86
85
}
@@ -89,12 +88,18 @@ struct RsqrtIntrinsic {
89
88
static RsqrtIntrinsic ForF64 (size_t num_elements) {
90
89
// We assume AVX512 is available for F64.
91
90
switch (num_elements) {
91
+ case 1 :
92
+ // Assuming AVX512 is available.
93
+ // We don't use x86_avx512_rsqrt14_sd because it also requires padding
94
+ // into <2 x double> vectors and it takes an additional source vector
95
+ // for the upper bits of the result.
96
+ return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_128, 8 , 2 };
92
97
case 2 :
93
- return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_128, 8 , false };
98
+ return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_128, 8 , 0 };
94
99
case 4 :
95
- return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_256, 8 , false };
100
+ return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_256, 8 , 0 };
96
101
case 8 :
97
- return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_512, 8 , false };
102
+ return {llvm::Intrinsic::x86_avx512_rsqrt14_pd_512, 8 , 0 };
98
103
default :
99
104
LOG (FATAL) << " Unsupported vector width for rsqrt: " << num_elements;
100
105
}
@@ -106,36 +111,39 @@ struct RsqrtIntrinsic {
106
111
llvm::Intrinsic::getOrInsertDeclaration (module , id);
107
112
108
113
llvm::Value* y_approx;
109
- if (needs_insert_element) {
114
+ std::vector<llvm::Value*> args = {x};
115
+ if (needs_insert_element_size > 0 ) {
116
+ // Pad into a vector of size `needs_insert_element_size`.
110
117
llvm::Type* sse_vec_type = llvm::VectorType::get (
111
- x->getType ()->getScalarType (), llvm::ElementCount::getFixed (4 ));
118
+ x->getType ()->getScalarType (),
119
+ llvm::ElementCount::getFixed (needs_insert_element_size));
112
120
llvm::Value* vec_x = llvm::UndefValue::get (sse_vec_type);
113
121
vec_x = builder.CreateInsertElement (vec_x, x, builder.getInt32 (0 ));
114
- llvm::Value* approx_vec =
115
- builder.CreateCall (rsqrt_intrinsic, {vec_x}, " y_approx.vec" );
116
- y_approx = builder.CreateExtractElement (approx_vec, builder.getInt32 (0 ),
117
- " y_approx" );
118
- } else if (mask_bits > 0 ) {
119
- llvm::Value* dest = llvm::ConstantFP::get (x->getType (), 0.0 );
122
+ args[0 ] = vec_x;
123
+ }
124
+ if (mask_bits > 0 ) {
125
+ llvm::Value* src = llvm::ConstantFP::get (args[0 ]->getType (), 0.0 );
120
126
llvm::Value* mask = llvm::ConstantInt::get (
121
127
builder.getContext (), llvm::APInt (mask_bits, -1 , true ));
122
- y_approx =
123
- builder.CreateCall (rsqrt_intrinsic, {x, dest, mask}, " y_approx" );
124
-
125
- } else {
126
- y_approx = builder.CreateCall (rsqrt_intrinsic, {x}, " y_approx" );
128
+ args.push_back (src);
129
+ args.push_back (mask);
130
+ }
131
+ y_approx = builder.CreateCall (rsqrt_intrinsic, args, " y_approx" );
132
+ if (needs_insert_element_size > 0 ) {
133
+ // Extract the result from the padded vector.
134
+ y_approx = builder.CreateExtractElement (y_approx, builder.getInt32 (0 ),
135
+ " y_approx" );
127
136
}
128
137
return y_approx;
129
138
}
130
139
};
131
140
132
- absl::StatusOr<llvm::Function*> Rsqrt::CreateDefinition (llvm::Module* module ,
133
- Type type) {
141
+ absl::StatusOr<llvm::Function*> Rsqrt::CreateDefinition (
142
+ llvm::Module* module , llvm::TargetMachine* target_machine, Type type) {
143
+ CHECK (type.element_type () == F64 || type.element_type () == F32)
144
+ << type.name ();
134
145
llvm::Type* input_type = Type::TypeToIrType (type, module ->getContext ());
135
146
CHECK (input_type != nullptr );
136
- CHECK (input_type->isFloatingPointTy () || input_type->isVectorTy ());
137
- CHECK (input_type->getScalarType ()->isFloatTy () ||
138
- input_type->getScalarType ()->isDoubleTy ());
139
147
140
148
llvm::LLVMContext& context = module ->getContext ();
141
149
llvm::IRBuilder<> builder (context);
@@ -151,30 +159,44 @@ absl::StatusOr<llvm::Function*> Rsqrt::CreateDefinition(llvm::Module* module,
151
159
module ->getOrInsertFunction (Rsqrt::Name (type), function_type)
152
160
.getCallee ());
153
161
154
- llvm::Argument* input_x_arg = func->getArg (0 );
155
- input_x_arg ->setName (" x" );
162
+ llvm::Argument* x = func->getArg (0 );
163
+ x ->setName (" x" );
156
164
llvm::BasicBlock* entry_bb = llvm::BasicBlock::Create (context, " entry" , func);
157
- llvm::Value* x = input_x_arg;
158
165
builder.SetInsertPoint (entry_bb);
166
+
167
+ if ((type.element_type () == F64 &&
168
+ !target_machine->getTargetFeatureString ().contains (" +avx512f" )) ||
169
+ !target_machine->getTargetFeatureString ().contains (" +avx" )) {
170
+ LOG_EVERY_N (INFO, 1000 )
171
+ << " avx not available, falling back to 1 / sqrt(x) for " << type.name ();
172
+ // We can't use the same approximation algorithm for F64 without AVX512 or
173
+ // anything non-x86 and without avx.
174
+ llvm::Value* one = llvm::ConstantFP::get (input_type, 1.0 );
175
+ llvm::Value* sqrt_x =
176
+ builder.CreateUnaryIntrinsic (llvm::Intrinsic::sqrt, x);
177
+ llvm::Value* inv_sqrt_x = builder.CreateFDiv (one, sqrt_x, " inv_sqrt_x" );
178
+ builder.CreateRet (inv_sqrt_x);
179
+ return func;
180
+ }
181
+
159
182
RsqrtIntrinsic rsqrt_intrinsic = input_type->getScalarType ()->isFloatTy ()
160
183
? RsqrtIntrinsic::ForF32 (num_elements)
161
184
: RsqrtIntrinsic::ForF64 (num_elements);
162
185
llvm::Value* y_approx = rsqrt_intrinsic.CreateCall (builder, x);
163
186
164
187
llvm::Value* refined_result =
165
- NewtonRaphsonRsqrtIteration (builder, input_x_arg , y_approx, input_type);
188
+ NewtonRaphsonRsqrtIteration (builder, x , y_approx, input_type);
166
189
if (input_type->getScalarType ()->isDoubleTy ()) {
167
190
// Do an additional refinement step for F64.
168
- refined_result = NewtonRaphsonRsqrtIteration (builder, input_x_arg,
169
- refined_result, input_type);
191
+ refined_result =
192
+ NewtonRaphsonRsqrtIteration (builder, x, refined_result, input_type);
170
193
}
171
194
172
- // Create a mask for special cases (denormals and infinities) to fall back
173
- // to the intrinsic's result, matching Eigen's behavior.
174
195
const llvm::fltSemantics& semantics =
175
196
input_type->getScalarType ()->getFltSemantics ();
176
- llvm::Constant* flt_min = llvm::ConstantFP::get (
177
- input_type, llvm::APFloat::getSmallestNormalized (semantics));
197
+ llvm::APFloat flt_min_val = llvm::APFloat::getSmallestNormalized (semantics);
198
+ llvm::Constant* flt_min = llvm::ConstantFP::get (input_type, flt_min_val);
199
+
178
200
llvm::Constant* inf =
179
201
llvm::ConstantFP::get (input_type, llvm::APFloat::getInf (semantics));
180
202
@@ -183,11 +205,24 @@ absl::StatusOr<llvm::Function*> Rsqrt::CreateDefinition(llvm::Module* module,
183
205
llvm::Value* use_hw_approx_mask =
184
206
builder.CreateOr (lt_min_mask, inf_mask, " use_hw_approx_mask" );
185
207
208
+ if (type.element_type () == F64) {
209
+ // If the input is very large, the result should be 0.
210
+ // This is effectively calculating 1.0 / FLT_MIN, which is the threshold
211
+ // where rsqrt(x*x) would be close to the smallest representable number.
212
+ llvm::APFloat rsqrt_is_zero_val = llvm::APFloat (semantics, 1 );
213
+ rsqrt_is_zero_val.divide (flt_min_val, llvm::APFloat::rmTowardZero);
214
+ llvm::Constant* rsqrt_is_zero =
215
+ llvm::ConstantFP::get (input_type, rsqrt_is_zero_val);
216
+ llvm::Value* rsqrt_is_zero_mask =
217
+ builder.CreateFCmpOGE (x, rsqrt_is_zero, " rsqrt_is_zero_mask" );
218
+ use_hw_approx_mask =
219
+ builder.CreateOr (use_hw_approx_mask, rsqrt_is_zero_mask);
220
+ }
221
+
186
222
// If input is normal and finite, use the refined result. Otherwise, use the
187
223
// raw hardware approximation.
188
224
llvm::Value* result = builder.CreateSelect (use_hw_approx_mask, y_approx,
189
225
refined_result, " result" );
190
-
191
226
builder.CreateRet (result);
192
227
return func;
193
228
}
0 commit comments