Skip to content

[XLA:GPU] Enable strength reduction for s32xs32->s32 dots #97356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions third_party/xla/xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1815,17 +1815,18 @@ TEST_F(GpuCompilerTest,
// CHECK: %[[bitcast:.+]] = {{.+}} bitcast(%[[rs]])
// CHECK: ROOT {{.+}} = {{.+}} dynamic-update-slice({{.+}}, %[[bitcast]], {{.+}})
// CHECK: ENTRY
// CHECK: %[[wrapped_transpose:.+]] = {{.+}} fusion({{.+}}), kind=kInput
// CHECK: %[[input_reduce_fusion:.+]] = {{.+}} fusion({{.+}}), kind=kInput
// CHECK: %[[fusion_start:.+]] = {{.+}} fusion-start({{.+}}), kind=kCustom, {{.+}}"name":"dynamic_address_computation"
// CHECK-NEXT: %[[wrapped_dot:.+]] = {{.+}} fusion({{.+}}), kind=kLoop
// CHECK-NEXT: %[[fusion_done:.+]] = {{.+}} fusion-done(%[[fusion_start]]), {{.+}}"name":"dynamic_address_computation"
// CHECK: ROOT {{.+}} = {{.+}} tuple(%[[fusion_done]], %[[wrapped_dot]])
// CHECK: ROOT {{.+}} = {{.+}} tuple(%[[fusion_done]], %[[input_reduce_fusion]])
)";
EXPECT_THAT(
RunFileCheck(exec->module().ToString(HloPrintOptions{}
.set_print_operand_shape(false)
.set_print_metadata(false)),
kExpected),
::tsl::testing::IsOkAndHolds(true));
auto output = exec->module().ToString(
HloPrintOptions{}.set_print_operand_shape(false).set_print_metadata(
false));
VLOG(0) << "output: " << output;
EXPECT_THAT(RunFileCheck(output, kExpected),
::tsl::testing::IsOkAndHolds(true));

if (test_runner().device_count() < 2) {
GTEST_SKIP() << "Skipping test as it requires at least 2 devices.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ bool DotStrengthReduction::InstructionMatchesPattern(
const bool rhs_is_vector = (dnums.rhs_batch_dimensions_size() +
dnums.rhs_contracting_dimensions_size() ==
rhs->shape().dimensions().size());
if (!lhs_is_vector && !rhs_is_vector) {
const bool are_both_operands_and_result_s32 =
lhs->shape().element_type() == S32 &&
rhs->shape().element_type() == S32 && dot->shape().element_type() == S32;
// s32 x s32 -> s32 dots are not supported by the hardware, so we want to
// strength-reduce them in any case.
if (!lhs_is_vector && !rhs_is_vector && !are_both_operands_and_result_s32) {
return false;
}
// Strength-reduce vector-vector dots since they are not supported by
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,50 @@ TEST_F(DotStrengthReductionTest, DotStrengthReductionMixedOperandTypes) {
EXPECT_TRUE(filecheck_result);
}

// Runs DotStrengthReduction on a matrix-matrix dot whose operands and result
// are all s32, and expects the pass to report a change and to leave behind a
// multiply followed by a reduce in the printed module.
TEST_F(DotStrengthReductionTest, S32MatrixMatrixDotShouldBeStrengthReduced) {
  // FileCheck directives the rewritten module text has to satisfy, in order.
  const char* const expected_pattern = R"(
// CHECK: s32[32,70,50]{{[^ ]*}} multiply
// CHECK: s32[32,70]{{[^ ]*}} reduce
)";

  // Input module: s32[32,50] x s32[50,70] -> s32[32,70].
  const std::string hlo_text = R"(
HloModule m

ENTRY entry {
p0 = s32[32, 50] parameter(0)
p1 = s32[50, 70] parameter(1)
ROOT dot = s32[32, 70] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  DotStrengthReduction pass{
      se::GpuComputeCapability(se::CudaComputeCapability::Ampere())};
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  EXPECT_TRUE(changed);

  // The rewritten module must still pass verification.
  CHECK_OK(module->Verify());

  TF_ASSERT_OK_AND_ASSIGN(bool pattern_matched,
                          RunFileCheck(module->ToString(), expected_pattern));
  EXPECT_TRUE(pattern_matched);
}

// Runs DotStrengthReduction on an f32 matrix-matrix dot and expects the pass
// to report that it changed nothing.
TEST_F(DotStrengthReductionTest, F32MatrixMatrixDotShouldNotBeStrengthReduced) {
  // Input module: f32[32,500] x f32[500,700] -> f32[32,700].
  const std::string hlo_text = R"(
HloModule m

ENTRY entry {
p0 = f32[32, 500] parameter(0)
p1 = f32[500, 700] parameter(1)
ROOT dot = f32[32, 700] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  DotStrengthReduction pass{
      se::GpuComputeCapability(se::CudaComputeCapability::Ampere())};
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  EXPECT_FALSE(changed);
}

} // namespace
} // namespace gpu
} // namespace xla
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy