Skip to content

Commit f8bac76

Browse files
reedwmtensorflower-gardener
authored andcommitted
Parse and ignore 'mode' attribute on collectives.
The attribute was added and used in openxla/xla@cf3dfa9 but rolled back in openxla/xla@5542ebc since it broke TPU tests. However, HLO text generated between these two commits have the 'mode' attribute. So parse and ignore it for now to not break parsing such HLOs. In the future, once the TPU tests are fixed, the full commit will be rolled forward. Like in the original commit, I move the CollectiveOpGroupMode to it's own file to keep the parsing logic the same as the original commit. PiperOrigin-RevId: 786115058
1 parent f015a18 commit f8bac76

11 files changed

+389
-95
lines changed

third_party/xla/xla/hlo/ir/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ cc_library(
5959
],
6060
deps = [
6161
":backend_config",
62+
":collective_op_group_mode",
6263
":hlo_op_metadata",
6364
":hlo_sharding",
6465
":ptrvec",
@@ -374,3 +375,27 @@ xla_cc_test(
374375
"@com_google_googletest//:gtest",
375376
],
376377
)
378+
379+
cc_library(
380+
name = "collective_op_group_mode",
381+
srcs = ["collective_op_group_mode.cc"],
382+
hdrs = ["collective_op_group_mode.h"],
383+
deps = [
384+
"//xla:util",
385+
"@com_google_absl//absl/log:check",
386+
"@com_google_absl//absl/status:statusor",
387+
"@com_google_absl//absl/strings:string_view",
388+
],
389+
)
390+
391+
xla_cc_test(
392+
name = "collective_op_group_mode_test",
393+
srcs = ["collective_op_group_mode_test.cc"],
394+
deps = [
395+
":collective_op_group_mode",
396+
"//xla/tests:xla_internal_test_main",
397+
"//xla/tsl/lib/core:status_test_util",
398+
"@com_google_absl//absl/status:statusor",
399+
"@com_google_googletest//:gtest",
400+
],
401+
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/hlo/ir/collective_op_group_mode.h"
17+
18+
#include <optional>
19+
20+
#include "absl/log/check.h"
21+
#include "absl/status/statusor.h"
22+
#include "absl/strings/string_view.h"
23+
#include "xla/util.h"
24+
25+
namespace xla {
26+
namespace {
27+
28+
struct CollectiveOpGroupModeInfo {
29+
CollectiveOpGroupMode mode;
30+
absl::string_view name;
31+
};
32+
33+
const CollectiveOpGroupModeInfo kGroupModeInfos[] = {
34+
{CollectiveOpGroupMode::kCrossReplica, "cross_replica"},
35+
{CollectiveOpGroupMode::kCrossPartition, "cross_partition"},
36+
{CollectiveOpGroupMode::kCrossReplicaAndPartition,
37+
"cross_replica_and_partition"},
38+
{CollectiveOpGroupMode::kFlattenedID, "flattened_id"},
39+
};
40+
41+
} // namespace
42+
43+
absl::string_view CollectiveOpGroupModeToString(
44+
CollectiveOpGroupMode group_mode) {
45+
for (const CollectiveOpGroupModeInfo& info : kGroupModeInfos) {
46+
if (info.mode == group_mode) {
47+
return info.name;
48+
}
49+
}
50+
CHECK(false) << "Unknown collective op group mode: "
51+
<< static_cast<int>(group_mode);
52+
}
53+
54+
absl::StatusOr<CollectiveOpGroupMode> StringToCollectiveOpGroupMode(
55+
absl::string_view name) {
56+
for (const CollectiveOpGroupModeInfo& info : kGroupModeInfos) {
57+
if (info.name == name) {
58+
return info.mode;
59+
}
60+
}
61+
return InvalidArgument("Invalid collective op group mode: %s", name);
62+
}
63+
64+
// Returns the group formation mode implied by (a) whether the operation has
65+
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
66+
absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
67+
bool has_channel_id, std::optional<bool> use_global_device_ids) {
68+
if (!has_channel_id) {
69+
if (use_global_device_ids.has_value() && *use_global_device_ids) {
70+
return InvalidArgument(
71+
"Cannot have use_global_device_ids=true without channel_id");
72+
}
73+
return CollectiveOpGroupMode::kCrossReplica;
74+
}
75+
if (!use_global_device_ids.has_value()) {
76+
return CollectiveOpGroupMode::kCrossPartition;
77+
}
78+
if (!*use_global_device_ids) {
79+
return CollectiveOpGroupMode::kCrossReplicaAndPartition;
80+
}
81+
return CollectiveOpGroupMode::kFlattenedID;
82+
}
83+
84+
} // namespace xla
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_
17+
#define XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_
18+
19+
#include <optional>
20+
21+
#include "absl/status/statusor.h"
22+
#include "absl/strings/string_view.h"
23+
24+
namespace xla {
25+
26+
// There are broadly 4 modes that collective communication ops use to describe
27+
// which sets of devices are participating with a given device in the operation.
28+
// These modes are determined by the values of channel_id (optional) and
29+
// use_global_device_ids (optional). The modes are as follows:
30+
//
31+
// kCrossReplica:
32+
// implied by: no channel id, use_global_device_ids = false, or
33+
// no channel_id, no use_global_device_ids:
34+
// replica_groups contain replica_id, group contains all replicas for the
35+
// current partition
36+
//
37+
// kCrossPartition:
38+
// implied by: channel_id is set, no use_global_device_ids:
39+
// replica_groups contain partition_id, group contains all partitions for the
40+
// current replica.
41+
//
42+
// kCrossReplicaAndPartition:
43+
// implied by: channel_id is set, use_global_device_ids = false:
44+
// replica_groups contain replica_id, group contains all replicas for all
45+
// partitions (as opposed to just current partition).
46+
//
47+
// kFlattenedID:
48+
// implied by: channel_id is set, use_global_device_ids = true:
49+
// replica_groups contain flattened-ids, group contains devices that are
50+
// listed in the flattened-id list.
51+
//
52+
// Rest of the combinations are invalid.
53+
//
54+
// Since the actual value of channel_id does not matter, we use a bool argument
55+
// `has_channel_id`, and optional<bool> for use_global_device_ids.
56+
// Note that use_global_device_ids true requires channel_id to be set as well.
57+
// Additionally, if use_global_device_ids = true, replica groups cannot be
58+
// empty (verified in the HLO verifier).
59+
enum class CollectiveOpGroupMode {
60+
kCrossReplica,
61+
kCrossPartition,
62+
kCrossReplicaAndPartition,
63+
kFlattenedID,
64+
};
65+
66+
absl::string_view CollectiveOpGroupModeToString(
67+
CollectiveOpGroupMode group_mode);
68+
69+
absl::StatusOr<CollectiveOpGroupMode> StringToCollectiveOpGroupMode(
70+
absl::string_view name);
71+
72+
// Returns the group formation mode implied by (a) whether the operation has
73+
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
74+
absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
75+
bool has_channel_id, std::optional<bool> use_global_device_ids);
76+
77+
} // namespace xla
78+
79+
#endif // XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/hlo/ir/collective_op_group_mode.h"
17+
18+
#include <optional>
19+
#include <sstream>
20+
#include <string>
21+
#include <vector>
22+
23+
#include <gtest/gtest.h>
24+
#include "absl/status/statusor.h"
25+
#include "xla/tsl/lib/core/status_test_util.h"
26+
27+
namespace xla {
28+
namespace {
29+
30+
TEST(CollectiveOpGroupModeTest, ToString) {
31+
EXPECT_EQ(CollectiveOpGroupModeToString(CollectiveOpGroupMode::kCrossReplica),
32+
"cross_replica");
33+
EXPECT_EQ(
34+
CollectiveOpGroupModeToString(CollectiveOpGroupMode::kCrossPartition),
35+
"cross_partition");
36+
EXPECT_EQ(CollectiveOpGroupModeToString(
37+
CollectiveOpGroupMode::kCrossReplicaAndPartition),
38+
"cross_replica_and_partition");
39+
EXPECT_EQ(CollectiveOpGroupModeToString(CollectiveOpGroupMode::kFlattenedID),
40+
"flattened_id");
41+
}
42+
43+
TEST(CollectiveOpGroupModeTest, FromString) {
44+
EXPECT_EQ(StringToCollectiveOpGroupMode("cross_replica").value(),
45+
CollectiveOpGroupMode::kCrossReplica);
46+
EXPECT_EQ(StringToCollectiveOpGroupMode("cross_partition").value(),
47+
CollectiveOpGroupMode::kCrossPartition);
48+
EXPECT_EQ(
49+
StringToCollectiveOpGroupMode("cross_replica_and_partition").value(),
50+
CollectiveOpGroupMode::kCrossReplicaAndPartition);
51+
EXPECT_EQ(StringToCollectiveOpGroupMode("flattened_id").value(),
52+
CollectiveOpGroupMode::kFlattenedID);
53+
}
54+
55+
// Tests for GetCollectOpGroupMode
56+
namespace GetCollectiveOpGroupModeTest {
57+
struct TestCase {
58+
bool has_channel_id;
59+
std::optional<bool> use_global_device_ids;
60+
std::optional<xla::CollectiveOpGroupMode> expected;
61+
62+
std::string ToString() const {
63+
std::ostringstream s;
64+
s << (has_channel_id ? "chnl" : "nochnl");
65+
s << "_"
66+
<< (use_global_device_ids
67+
? (*use_global_device_ids ? "ugdi_true" : "ugdi_false")
68+
: "nougdi");
69+
return s.str();
70+
}
71+
};
72+
73+
std::vector<TestCase> GetTestCases() {
74+
const std::vector<TestCase> test_cases = {
75+
// clang-format off
76+
// has_channel_id, use_global_device_ids, expected mode
77+
{false, std::nullopt, CollectiveOpGroupMode::kCrossReplica},
78+
{false, false, CollectiveOpGroupMode::kCrossReplica},
79+
{false, true, std::nullopt},
80+
{true, std::nullopt, CollectiveOpGroupMode::kCrossPartition},
81+
{true, false, CollectiveOpGroupMode::kCrossReplicaAndPartition},
82+
{true, true, CollectiveOpGroupMode::kFlattenedID},
83+
// clang-format on
84+
};
85+
return test_cases;
86+
}
87+
88+
class GetCollectOpGroupModeTest : public testing::TestWithParam<TestCase> {};
89+
90+
TEST_P(GetCollectOpGroupModeTest, Test) {
91+
const TestCase &tc = GetParam();
92+
absl::StatusOr<CollectiveOpGroupMode> actual =
93+
GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids);
94+
if (tc.expected) {
95+
TF_ASSERT_OK(actual.status());
96+
EXPECT_EQ(*actual, *tc.expected);
97+
} else {
98+
EXPECT_FALSE(actual.ok());
99+
}
100+
}
101+
102+
INSTANTIATE_TEST_SUITE_P(GetCollectOpGroupMode, GetCollectOpGroupModeTest,
103+
testing::ValuesIn(GetTestCases()));
104+
} // namespace GetCollectiveOpGroupModeTest
105+
} // namespace
106+
} // namespace xla

third_party/xla/xla/hlo/parser/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ cc_library(
3636
"//xla:types",
3737
"//xla:util",
3838
"//xla:xla_data_proto_cc",
39+
"//xla/hlo/ir:collective_op_group_mode",
3940
"//xla/hlo/ir:hlo",
4041
"//xla/hlo/ir:tile_assignment",
4142
"//xla/service:computation_layout",
@@ -76,6 +77,7 @@ xla_cc_test(
7677
"//xla:window_util",
7778
"//xla:xla_data_proto_cc",
7879
"//xla/hlo/builder:xla_builder",
80+
"//xla/hlo/ir:collective_op_group_mode",
7981
"//xla/hlo/ir:hlo",
8082
"//xla/hlo/testlib:pattern_matcher_gmock",
8183
"//xla/hlo/testlib:verified_hlo_module",
@@ -92,6 +94,7 @@ xla_cc_test(
9294
"@com_google_absl//absl/status",
9395
"@com_google_absl//absl/status:statusor",
9496
"@com_google_absl//absl/strings",
97+
"@com_google_absl//absl/strings:str_format",
9598
"@com_google_absl//absl/types:span",
9699
"@com_google_googletest//:gtest",
97100
],

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy