Skip to content

Parse and ignore 'mode' attribute on collectives. #97392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions third_party/xla/xla/hlo/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ cc_library(
],
deps = [
":backend_config",
":collective_op_group_mode",
":hlo_op_metadata",
":hlo_sharding",
":ptrvec",
Expand Down Expand Up @@ -374,3 +375,27 @@ xla_cc_test(
"@com_google_googletest//:gtest",
],
)

# Library providing the CollectiveOpGroupMode enum together with helpers that
# convert it to/from its string name and derive it from the channel_id /
# use_global_device_ids attributes of a collective op.
cc_library(
name = "collective_op_group_mode",
srcs = ["collective_op_group_mode.cc"],
hdrs = ["collective_op_group_mode.h"],
deps = [
"//xla:util",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)

# Unit tests for the :collective_op_group_mode library.
xla_cc_test(
name = "collective_op_group_mode_test",
srcs = ["collective_op_group_mode_test.cc"],
deps = [
":collective_op_group_mode",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest",
],
)
84 changes: 84 additions & 0 deletions third_party/xla/xla/hlo/ir/collective_op_group_mode.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/ir/collective_op_group_mode.h"

#include <optional>

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/util.h"

namespace xla {
namespace {

// One row of the mode <-> name lookup table: pairs a CollectiveOpGroupMode
// enumerator with the string name used for it by
// CollectiveOpGroupModeToString and StringToCollectiveOpGroupMode.
struct CollectiveOpGroupModeInfo {
// The enumerator this row describes.
CollectiveOpGroupMode mode;
// The string name associated with `mode`.
absl::string_view name;
};

// Lookup table associating every CollectiveOpGroupMode enumerator with its
// string name. Both conversion functions in this file scan it linearly, so a
// new enumerator only needs a new row here.
//
// Declared `constexpr` (rather than plain `const`) so that the table is
// guaranteed to be constant-initialized and can never be observed in an
// uninitialized state during static initialization of other translation units.
constexpr CollectiveOpGroupModeInfo kGroupModeInfos[] = {
    {CollectiveOpGroupMode::kCrossReplica, "cross_replica"},
    {CollectiveOpGroupMode::kCrossPartition, "cross_partition"},
    {CollectiveOpGroupMode::kCrossReplicaAndPartition,
     "cross_replica_and_partition"},
    {CollectiveOpGroupMode::kFlattenedID, "flattened_id"},
};

} // namespace

// Returns the string name registered for `group_mode` in kGroupModeInfos.
// CHECK-fails if `group_mode` has no row in the table.
absl::string_view CollectiveOpGroupModeToString(
    CollectiveOpGroupMode group_mode) {
  for (const auto& [mode, name] : kGroupModeInfos) {
    if (mode != group_mode) {
      continue;
    }
    return name;
  }
  // Only reachable for a value that is not listed in kGroupModeInfos.
  CHECK(false) << "Unknown collective op group mode: "
               << static_cast<int>(group_mode);
}

// Looks up `name` in kGroupModeInfos (exact string comparison) and returns the
// matching enumerator, or an InvalidArgument error if no row has that name.
absl::StatusOr<CollectiveOpGroupMode> StringToCollectiveOpGroupMode(
    absl::string_view name) {
  for (const auto& [mode, mode_name] : kGroupModeInfos) {
    if (mode_name == name) {
      return mode;
    }
  }
  return InvalidArgument("Invalid collective op group mode: %s", name);
}

// Derives the group formation mode from (a) whether the op carries a
// channel_id and (b) the optional use_global_device_ids attribute.
//
// Returns InvalidArgument for the one inconsistent combination:
// use_global_device_ids=true without a channel_id.
absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
    bool has_channel_id, std::optional<bool> use_global_device_ids) {
  if (has_channel_id) {
    // With a channel_id the mode depends on whether use_global_device_ids is
    // absent, false, or true.
    if (!use_global_device_ids.has_value()) {
      return CollectiveOpGroupMode::kCrossPartition;
    }
    return *use_global_device_ids
               ? CollectiveOpGroupMode::kFlattenedID
               : CollectiveOpGroupMode::kCrossReplicaAndPartition;
  }

  // Without a channel_id only an absent or false use_global_device_ids is
  // accepted.
  if (use_global_device_ids.value_or(false)) {
    return InvalidArgument(
        "Cannot have use_global_device_ids=true without channel_id");
  }
  return CollectiveOpGroupMode::kCrossReplica;
}

} // namespace xla
79 changes: 79 additions & 0 deletions third_party/xla/xla/hlo/ir/collective_op_group_mode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_
#define XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_

#include <optional>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"

namespace xla {

// Describes which set of devices participates with a given device in a
// collective communication op. There are four modes; the mode of an op is
// determined by two optional attributes, channel_id and
// use_global_device_ids:
//
//   channel_id | use_global_device_ids | mode
//   -----------+-----------------------+---------------------------
//   absent     | absent or false       | kCrossReplica
//   present    | absent                | kCrossPartition
//   present    | false                 | kCrossReplicaAndPartition
//   present    | true                  | kFlattenedID
//
// The remaining combination (no channel_id, use_global_device_ids = true) is
// invalid: use_global_device_ids = true requires channel_id to be set.
//
// Only the presence of channel_id matters, not its value, so the mode is
// computed from a bool `has_channel_id` and an optional<bool> for
// use_global_device_ids.
//
// Additionally, if use_global_device_ids = true, replica groups cannot be
// empty (verified in the HLO verifier).
enum class CollectiveOpGroupMode {
// replica_groups contain replica_id; a group contains all replicas for the
// current partition.
kCrossReplica,
// replica_groups contain partition_id; a group contains all partitions for
// the current replica.
kCrossPartition,
// replica_groups contain replica_id; a group contains all replicas for all
// partitions (as opposed to just the current partition).
kCrossReplicaAndPartition,
// replica_groups contain flattened-ids; a group contains the devices that are
// listed in the flattened-id list.
kFlattenedID,
};

// Returns the string name of `group_mode`: "cross_replica",
// "cross_partition", "cross_replica_and_partition" or "flattened_id".
// CHECK-fails if `group_mode` is not one of the declared enumerators.
absl::string_view CollectiveOpGroupModeToString(
CollectiveOpGroupMode group_mode);

// Inverse of CollectiveOpGroupModeToString: returns the enumerator whose
// string name is exactly `name`, or an InvalidArgument error if `name` is not
// the name of any mode.
absl::StatusOr<CollectiveOpGroupMode> StringToCollectiveOpGroupMode(
absl::string_view name);

// Returns the group formation mode implied by (a) whether the operation has
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
// Returns an InvalidArgument error when use_global_device_ids is true but
// `has_channel_id` is false.
absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
bool has_channel_id, std::optional<bool> use_global_device_ids);

} // namespace xla

#endif // XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_
106 changes: 106 additions & 0 deletions third_party/xla/xla/hlo/ir/collective_op_group_mode_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/ir/collective_op_group_mode.h"

#include <optional>
#include <sstream>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "xla/tsl/lib/core/status_test_util.h"

namespace xla {
namespace {

// Checks the string name produced for each CollectiveOpGroupMode enumerator.
TEST(CollectiveOpGroupModeTest, ToString) {
  struct Expectation {
    CollectiveOpGroupMode mode;
    const char* name;
  };
  const Expectation kExpectations[] = {
      {CollectiveOpGroupMode::kCrossReplica, "cross_replica"},
      {CollectiveOpGroupMode::kCrossPartition, "cross_partition"},
      {CollectiveOpGroupMode::kCrossReplicaAndPartition,
       "cross_replica_and_partition"},
      {CollectiveOpGroupMode::kFlattenedID, "flattened_id"},
  };
  for (const Expectation& expectation : kExpectations) {
    // Identify the failing row, since all rows share one EXPECT_EQ.
    SCOPED_TRACE(expectation.name);
    EXPECT_EQ(CollectiveOpGroupModeToString(expectation.mode),
              expectation.name);
  }
}

// Checks that every known mode name parses to its enumerator and that any
// other string is rejected with a non-OK status.
TEST(CollectiveOpGroupModeTest, FromString) {
  EXPECT_EQ(StringToCollectiveOpGroupMode("cross_replica").value(),
            CollectiveOpGroupMode::kCrossReplica);
  EXPECT_EQ(StringToCollectiveOpGroupMode("cross_partition").value(),
            CollectiveOpGroupMode::kCrossPartition);
  EXPECT_EQ(
      StringToCollectiveOpGroupMode("cross_replica_and_partition").value(),
      CollectiveOpGroupMode::kCrossReplicaAndPartition);
  EXPECT_EQ(StringToCollectiveOpGroupMode("flattened_id").value(),
            CollectiveOpGroupMode::kFlattenedID);

  // Error path: the lookup is an exact string comparison, so unknown, empty
  // and differently-cased names must all fail to parse.
  EXPECT_FALSE(StringToCollectiveOpGroupMode("not_a_mode").ok());
  EXPECT_FALSE(StringToCollectiveOpGroupMode("").ok());
  EXPECT_FALSE(StringToCollectiveOpGroupMode("CROSS_REPLICA").ok());
}

// Tests for GetCollectiveOpGroupMode
namespace GetCollectiveOpGroupModeTest {
// One parameterized test input: the two arguments passed to
// GetCollectiveOpGroupMode plus the expected outcome.
struct TestCase {
  bool has_channel_id;
  std::optional<bool> use_global_device_ids;
  // Expected mode; std::nullopt means the call is expected to fail.
  std::optional<xla::CollectiveOpGroupMode> expected;

  // Builds a short label of the form "<chnl|nochnl>_<nougdi|ugdi_true|
  // ugdi_false>" describing the inputs of this case.
  std::string ToString() const {
    std::string label = has_channel_id ? "chnl" : "nochnl";
    label += "_";
    if (!use_global_device_ids.has_value()) {
      label += "nougdi";
    } else if (*use_global_device_ids) {
      label += "ugdi_true";
    } else {
      label += "ugdi_false";
    }
    return label;
  }
};

std::vector<TestCase> GetTestCases() {
const std::vector<TestCase> test_cases = {
// clang-format off
// has_channel_id, use_global_device_ids, expected mode
{false, std::nullopt, CollectiveOpGroupMode::kCrossReplica},
{false, false, CollectiveOpGroupMode::kCrossReplica},
{false, true, std::nullopt},
{true, std::nullopt, CollectiveOpGroupMode::kCrossPartition},
{true, false, CollectiveOpGroupMode::kCrossReplicaAndPartition},
{true, true, CollectiveOpGroupMode::kFlattenedID},
// clang-format on
};
return test_cases;
}

// Value-parameterized fixture; each instantiation receives one TestCase.
class GetCollectOpGroupModeTest : public testing::TestWithParam<TestCase> {};

// Calls GetCollectiveOpGroupMode with the parameter's inputs and checks that
// it either yields the expected mode or fails when no mode is expected.
TEST_P(GetCollectOpGroupModeTest, Test) {
  const TestCase& param = GetParam();
  const absl::StatusOr<CollectiveOpGroupMode> result = GetCollectiveOpGroupMode(
      param.has_channel_id, param.use_global_device_ids);

  if (!param.expected.has_value()) {
    // Invalid attribute combination: only the failure itself is checked.
    EXPECT_FALSE(result.ok());
    return;
  }
  TF_ASSERT_OK(result.status());
  EXPECT_EQ(*result, *param.expected);
}

INSTANTIATE_TEST_SUITE_P(GetCollectOpGroupMode, GetCollectOpGroupModeTest,
testing::ValuesIn(GetTestCases()));
} // namespace GetCollectiveOpGroupModeTest
} // namespace
} // namespace xla
3 changes: 3 additions & 0 deletions third_party/xla/xla/hlo/parser/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:collective_op_group_mode",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:tile_assignment",
"//xla/service:computation_layout",
Expand Down Expand Up @@ -76,6 +77,7 @@ xla_cc_test(
"//xla:window_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/ir:collective_op_group_mode",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:pattern_matcher_gmock",
"//xla/hlo/testlib:verified_hlo_module",
Expand All @@ -92,6 +94,7 @@ xla_cc_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
],
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy