Skip to content

Commit 997cd71

Browse files
[XLA:GPU] Enable strength reduction for s32xs32->s32 dots
Changing s32 to elementwise -> reduce can make it faster when run through Triton emitter. Change to move this to triton emitter will come in a followup. PiperOrigin-RevId: 785835110
1 parent 943eefb commit 997cd71

File tree

3 files changed

+59
-9
lines changed

3 files changed

+59
-9
lines changed

third_party/xla/xla/service/gpu/gpu_compiler_test.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,17 +1815,18 @@ TEST_F(GpuCompilerTest,
18151815
// CHECK: %[[bitcast:.+]] = {{.+}} bitcast(%[[rs]])
18161816
// CHECK: ROOT {{.+}} = {{.+}} dynamic-update-slice({{.+}}, %[[bitcast]], {{.+}})
18171817
// CHECK: ENTRY
1818+
// CHECK: %[[wrapped_transpose:.+]] = {{.+}} fusion({{.+}}), kind=kInput
1819+
// CHECK: %[[input_reduce_fusion:.+]] = {{.+}} fusion({{.+}}), kind=kInput
18181820
// CHECK: %[[fusion_start:.+]] = {{.+}} fusion-start({{.+}}), kind=kCustom, {{.+}}"name":"dynamic_address_computation"
1819-
// CHECK-NEXT: %[[wrapped_dot:.+]] = {{.+}} fusion({{.+}}), kind=kLoop
18201821
// CHECK-NEXT: %[[fusion_done:.+]] = {{.+}} fusion-done(%[[fusion_start]]), {{.+}}"name":"dynamic_address_computation"
1821-
// CHECK: ROOT {{.+}} = {{.+}} tuple(%[[fusion_done]], %[[wrapped_dot]])
1822+
// CHECK: ROOT {{.+}} = {{.+}} tuple(%[[fusion_done]], %[[input_reduce_fusion]])
18221823
)";
1823-
EXPECT_THAT(
1824-
RunFileCheck(exec->module().ToString(HloPrintOptions{}
1825-
.set_print_operand_shape(false)
1826-
.set_print_metadata(false)),
1827-
kExpected),
1828-
::tsl::testing::IsOkAndHolds(true));
1824+
auto output = exec->module().ToString(
1825+
HloPrintOptions{}.set_print_operand_shape(false).set_print_metadata(
1826+
false));
1827+
VLOG(0) << "output: " << output;
1828+
EXPECT_THAT(RunFileCheck(output, kExpected),
1829+
::tsl::testing::IsOkAndHolds(true));
18291830

18301831
if (test_runner().device_count() < 2) {
18311832
GTEST_SKIP() << "Skipping test as it requires at least 2 devices.";

third_party/xla/xla/service/gpu/transforms/dot_strength_reduction.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,12 @@ bool DotStrengthReduction::InstructionMatchesPattern(
205205
const bool rhs_is_vector = (dnums.rhs_batch_dimensions_size() +
206206
dnums.rhs_contracting_dimensions_size() ==
207207
rhs->shape().dimensions().size());
208-
if (!lhs_is_vector && !rhs_is_vector) {
208+
const bool are_both_operands_and_result_s32 =
209+
lhs->shape().element_type() == S32 &&
210+
rhs->shape().element_type() == S32 && dot->shape().element_type() == S32;
211+
// For s32xs32->s32 dots, since its not supported by the h/w we want to
212+
// strength reduce it in any case.
213+
if (!lhs_is_vector && !rhs_is_vector && !are_both_operands_and_result_s32) {
209214
return false;
210215
}
211216
// Strength-reduce vector-vector dots since they are not supported by

third_party/xla/xla/service/gpu/transforms/dot_strength_reduction_test.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,50 @@ TEST_F(DotStrengthReductionTest, DotStrengthReductionMixedOperandTypes) {
333333
EXPECT_TRUE(filecheck_result);
334334
}
335335

336+
TEST_F(DotStrengthReductionTest, S32MatrixMatrixDotShouldBeStrengthReduced) {
337+
const std::string& hlo_string = R"(
338+
HloModule m
339+
340+
ENTRY entry {
341+
p0 = s32[32, 50] parameter(0)
342+
p1 = s32[50, 70] parameter(1)
343+
ROOT dot = s32[32, 70] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
344+
})";
345+
TF_ASSERT_OK_AND_ASSIGN(auto module,
346+
ParseAndReturnVerifiedModule(hlo_string));
347+
DotStrengthReduction pass{
348+
se::GpuComputeCapability(se::CudaComputeCapability::Ampere())};
349+
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
350+
EXPECT_TRUE(changed);
351+
CHECK_OK(module->Verify());
352+
353+
const char* filecheck_pattern = R"(
354+
// CHECK: s32[32,70,50]{{[^ ]*}} multiply
355+
// CHECK: s32[32,70]{{[^ ]*}} reduce
356+
)";
357+
358+
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result,
359+
RunFileCheck(module->ToString(), filecheck_pattern));
360+
EXPECT_TRUE(filecheck_result);
361+
}
362+
363+
TEST_F(DotStrengthReductionTest, F32MatrixMatrixDotShouldNotBeStrengthReduced) {
364+
const std::string& hlo_string = R"(
365+
HloModule m
366+
367+
ENTRY entry {
368+
p0 = f32[32, 500] parameter(0)
369+
p1 = f32[500, 700] parameter(1)
370+
ROOT dot = f32[32, 700] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
371+
})";
372+
TF_ASSERT_OK_AND_ASSIGN(auto module,
373+
ParseAndReturnVerifiedModule(hlo_string));
374+
DotStrengthReduction pass{
375+
se::GpuComputeCapability(se::CudaComputeCapability::Ampere())};
376+
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
377+
EXPECT_FALSE(changed);
378+
}
379+
336380
} // namespace
337381
} // namespace gpu
338382
} // namespace xla

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy