Skip to content

Commit cf3dfa9

Browse files
reedwmGoogle-ML-Automation
authored andcommitted
Add 'mode' attribute to AllReduce and ReduceScatter.
This is an enum attribute which holds a CollectiveOpGroupMode. This enum determines how IDs in the replica_groups attribute are interpreted. Currently it is derived from the use_global_device_ids attribute and whether or not the channel_id attribute is present. The channel_id's integer value is not used in determine the mode: only whether it is present or not affects the mode. The channel_id's integer value is unused in all of XLA except for rare cases on TPUs. The long-term plan is to instead directly have instructions hold the CollectiveOpGroupMode, remove the use_global_device_ids attribute, and no longer use channel_id to determine the mode. An instruction will then only have a channel_id in the rare cases where it's value (and not just its presence) is used. This change is the first step of this plan. It adds the mode attribute to AllReduce and ReduceScatter. The constructor of such instructions takes an optional<CollectiveOpGroupMode> defaulting to nullopt, which causes the mode to be derived from channel_id and use_global_device_ids. The instructions store a non-optional CollectiveOpGroupMode. The HLO verifier will complain if the mode does not match channel_id and use_global_device_ids. The mode is optional in HLO text, and if not present, will be derived from channel_id and use_global_device_ids. CollectiveOpGroupMode is moved to its own file to avoid a circular dependency between collective_ops_utils.h and hlo_instructions.h. I created a duplicate CollectiveOpGroupModeProto enum, since we need a proto form for HloInstructionProto. Eventually we should get rid of the C++ enum and just use the proto enum. After this change, the next step is to add the attribute to other collectives. Then, change usages of functions like CreateAllReduce to pass the mode, and change passes and other code to directly use the instruction's mode rather than calculate it from channel_id and use_global_device IDs. The final step is to remove use_global_device_ids and to make channel_id unused except in the cases where its integer value, not just the existence of a nonzero int, is used. PiperOrigin-RevId: 783091244
1 parent c08a900 commit cf3dfa9

26 files changed

+680
-178
lines changed

xla/hlo/ir/BUILD

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ cc_library(
5959
],
6060
deps = [
6161
":backend_config",
62+
":collective_op_group_mode",
6263
":hlo_op_metadata",
6364
":hlo_sharding",
6465
":ptrvec",
@@ -374,3 +375,28 @@ xla_cc_test(
374375
"@com_google_googletest//:gtest",
375376
],
376377
)
378+
379+
cc_library(
380+
name = "collective_op_group_mode",
381+
srcs = ["collective_op_group_mode.cc"],
382+
hdrs = ["collective_op_group_mode.h"],
383+
deps = [
384+
"//xla:util",
385+
"//xla:xla_data_proto_cc",
386+
"@com_google_absl//absl/log:check",
387+
"@com_google_absl//absl/status:statusor",
388+
"@com_google_absl//absl/strings:string_view",
389+
],
390+
)
391+
392+
xla_cc_test(
393+
name = "collective_op_group_mode_test",
394+
srcs = ["collective_op_group_mode_test.cc"],
395+
deps = [
396+
":collective_op_group_mode",
397+
"//xla/tests:xla_internal_test_main",
398+
"//xla/tsl/lib/core:status_test_util",
399+
"@com_google_absl//absl/status:statusor",
400+
"@com_google_googletest//:gtest",
401+
],
402+
)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/hlo/ir/collective_op_group_mode.h"
17+
18+
#include <optional>
19+
20+
#include "absl/log/check.h"
21+
#include "absl/status/statusor.h"
22+
#include "absl/strings/string_view.h"
23+
#include "xla/util.h"
24+
25+
namespace xla {
26+
namespace {
27+
28+
struct CollectiveOpGroupModeInfo {
29+
CollectiveOpGroupMode mode;
30+
absl::string_view name;
31+
CollectiveOpGroupModeProto proto;
32+
};
33+
34+
const CollectiveOpGroupModeInfo kGroupModeInfos[] = {
35+
{CollectiveOpGroupMode::kCrossReplica, "cross_replica",
36+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_CROSS_REPLICA},
37+
{CollectiveOpGroupMode::kCrossPartition, "cross_partition",
38+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_CROSS_PARTITION},
39+
{CollectiveOpGroupMode::kCrossReplicaAndPartition,
40+
"cross_replica_and_partition",
41+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_CROSS_REPLICA_AND_PARTITION},
42+
{CollectiveOpGroupMode::kFlattenedID, "flattened_id",
43+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_FLATTENED_ID},
44+
};
45+
46+
} // namespace
47+
48+
absl::string_view CollectiveOpGroupModeToString(
49+
CollectiveOpGroupMode group_mode) {
50+
for (const CollectiveOpGroupModeInfo& info : kGroupModeInfos) {
51+
if (info.mode == group_mode) {
52+
return info.name;
53+
}
54+
}
55+
CHECK(false) << "Unknown collective op group mode: "
56+
<< static_cast<int>(group_mode);
57+
}
58+
59+
absl::StatusOr<CollectiveOpGroupMode> StringToCollectiveOpGroupMode(
60+
absl::string_view name) {
61+
for (const CollectiveOpGroupModeInfo& info : kGroupModeInfos) {
62+
if (info.name == name) {
63+
return info.mode;
64+
}
65+
}
66+
return InvalidArgument("Invalid collective op group mode: %s", name);
67+
}
68+
69+
CollectiveOpGroupModeProto CollectiveOpGroupModeToProto(
70+
CollectiveOpGroupMode group_mode) {
71+
for (const CollectiveOpGroupModeInfo& info : kGroupModeInfos) {
72+
if (info.mode == group_mode) {
73+
return info.proto;
74+
}
75+
}
76+
CHECK(false) << "Unknown collective op group mode: "
77+
<< static_cast<int>(group_mode);
78+
}
79+
80+
absl::StatusOr<CollectiveOpGroupMode> CollectiveOpGroupModeFromProto(
81+
CollectiveOpGroupModeProto proto) {
82+
for (const CollectiveOpGroupModeInfo& info : kGroupModeInfos) {
83+
if (info.proto == proto) {
84+
return info.mode;
85+
}
86+
}
87+
return InvalidArgument("Invalid collective op group mode proto: %s",
88+
CollectiveOpGroupModeProto_Name(proto));
89+
}
90+
91+
// Returns the group formation mode implied by (a) whether the operation has
92+
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
93+
absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
94+
bool has_channel_id, std::optional<bool> use_global_device_ids) {
95+
if (!has_channel_id) {
96+
if (use_global_device_ids.has_value() && *use_global_device_ids) {
97+
return InvalidArgument(
98+
"Cannot have use_global_device_ids=true without channel_id");
99+
}
100+
return CollectiveOpGroupMode::kCrossReplica;
101+
}
102+
if (!use_global_device_ids.has_value()) {
103+
return CollectiveOpGroupMode::kCrossPartition;
104+
}
105+
if (!*use_global_device_ids) {
106+
return CollectiveOpGroupMode::kCrossReplicaAndPartition;
107+
}
108+
return CollectiveOpGroupMode::kFlattenedID;
109+
}
110+
111+
} // namespace xla

xla/hlo/ir/collective_op_group_mode.h

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_
17+
#define XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_
18+
19+
#include <optional>
20+
21+
#include "absl/status/statusor.h"
22+
#include "absl/strings/string_view.h"
23+
#include "xla/xla_data.pb.h"
24+
25+
namespace xla {
26+
27+
// There are broadly 4 modes that collective communication ops use to describe
28+
// which sets of devices are participating with a given device in the operation.
29+
// These modes are determined by the values of channel_id (optional) and
30+
// use_global_device_ids (optional). The modes are as follows:
31+
//
32+
// kCrossReplica:
33+
// implied by: no channel id, use_global_device_ids = false, or
34+
// no channel_id, no use_global_device_ids:
35+
// replica_groups contain replica_id, group contains all replicas for the
36+
// current partition
37+
//
38+
// kCrossPartition:
39+
// implied by: channel_id is set, no use_global_device_ids:
40+
// replica_groups contain partition_id, group contains all partitions for the
41+
// current replica.
42+
//
43+
// kCrossReplicaAndPartition:
44+
// implied by: channel_id is set, use_global_device_ids = false:
45+
// replica_groups contain replica_id, group contains all replicas for all
46+
// partitions (as opposed to just current partition).
47+
//
48+
// kFlattenedID:
49+
// implied by: channel_id is set, use_global_device_ids = true:
50+
// replica_groups contain flattened-ids, group contains devices that are
51+
// listed in the flattened-id list.
52+
//
53+
// Rest of the combinations are invalid.
54+
//
55+
// Since the actual value of channel_id does not matter, we use a bool argument
56+
// `has_channel_id`, and optional<bool> for use_global_device_ids.
57+
// Note that use_global_device_ids true requires channel_id to be set as well.
58+
// Additionally, if use_global_device_ids = true, replica groups cannot be
59+
// empty (verified in the HLO verifier).
60+
enum class CollectiveOpGroupMode {
61+
kCrossReplica,
62+
kCrossPartition,
63+
kCrossReplicaAndPartition,
64+
kFlattenedID,
65+
};
66+
67+
absl::string_view CollectiveOpGroupModeToString(
68+
CollectiveOpGroupMode group_mode);
69+
70+
absl::StatusOr<CollectiveOpGroupMode> StringToCollectiveOpGroupMode(
71+
absl::string_view name);
72+
73+
CollectiveOpGroupModeProto CollectiveOpGroupModeToProto(
74+
CollectiveOpGroupMode group_mode);
75+
76+
absl::StatusOr<CollectiveOpGroupMode> CollectiveOpGroupModeFromProto(
77+
CollectiveOpGroupModeProto proto);
78+
79+
// Returns the group formation mode implied by (a) whether the operation has
80+
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
81+
absl::StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
82+
bool has_channel_id, std::optional<bool> use_global_device_ids);
83+
84+
} // namespace xla
85+
86+
#endif // XLA_HLO_IR_COLLECTIVE_OP_GROUP_MODE_H_
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/hlo/ir/collective_op_group_mode.h"
17+
18+
#include <optional>
19+
#include <sstream>
20+
#include <string>
21+
#include <vector>
22+
23+
#include <gtest/gtest.h>
24+
#include "absl/status/statusor.h"
25+
#include "xla/tsl/lib/core/status_test_util.h"
26+
27+
namespace xla {
28+
namespace {
29+
30+
TEST(CollectiveOpGroupModeTest, ToString) {
31+
EXPECT_EQ(CollectiveOpGroupModeToString(CollectiveOpGroupMode::kCrossReplica),
32+
"cross_replica");
33+
EXPECT_EQ(
34+
CollectiveOpGroupModeToString(CollectiveOpGroupMode::kCrossPartition),
35+
"cross_partition");
36+
EXPECT_EQ(CollectiveOpGroupModeToString(
37+
CollectiveOpGroupMode::kCrossReplicaAndPartition),
38+
"cross_replica_and_partition");
39+
EXPECT_EQ(CollectiveOpGroupModeToString(CollectiveOpGroupMode::kFlattenedID),
40+
"flattened_id");
41+
}
42+
43+
TEST(CollectiveOpGroupModeTest, FromString) {
44+
EXPECT_EQ(StringToCollectiveOpGroupMode("cross_replica").value(),
45+
CollectiveOpGroupMode::kCrossReplica);
46+
EXPECT_EQ(StringToCollectiveOpGroupMode("cross_partition").value(),
47+
CollectiveOpGroupMode::kCrossPartition);
48+
EXPECT_EQ(
49+
StringToCollectiveOpGroupMode("cross_replica_and_partition").value(),
50+
CollectiveOpGroupMode::kCrossReplicaAndPartition);
51+
EXPECT_EQ(StringToCollectiveOpGroupMode("flattened_id").value(),
52+
CollectiveOpGroupMode::kFlattenedID);
53+
}
54+
55+
TEST(CollectiveOpGroupModeTest, ToProto) {
56+
EXPECT_EQ(CollectiveOpGroupModeToProto(CollectiveOpGroupMode::kCrossReplica),
57+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_CROSS_REPLICA);
58+
EXPECT_EQ(
59+
CollectiveOpGroupModeToProto(CollectiveOpGroupMode::kCrossPartition),
60+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_CROSS_PARTITION);
61+
EXPECT_EQ(
62+
CollectiveOpGroupModeToProto(
63+
CollectiveOpGroupMode::kCrossReplicaAndPartition),
64+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_CROSS_REPLICA_AND_PARTITION);
65+
EXPECT_EQ(CollectiveOpGroupModeToProto(CollectiveOpGroupMode::kFlattenedID),
66+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_FLATTENED_ID);
67+
}
68+
69+
TEST(CollectiveOpGroupModeTest, FromProto) {
70+
EXPECT_EQ(CollectiveOpGroupModeFromProto(
71+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_CROSS_REPLICA)
72+
.value(),
73+
CollectiveOpGroupMode::kCrossReplica);
74+
EXPECT_EQ(CollectiveOpGroupModeFromProto(
75+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_CROSS_PARTITION)
76+
.value(),
77+
CollectiveOpGroupMode::kCrossPartition);
78+
EXPECT_EQ(CollectiveOpGroupModeFromProto(
79+
CollectiveOpGroupModeProto::
80+
COLLECTIVE_MODE_CROSS_REPLICA_AND_PARTITION)
81+
.value(),
82+
CollectiveOpGroupMode::kCrossReplicaAndPartition);
83+
EXPECT_EQ(CollectiveOpGroupModeFromProto(
84+
CollectiveOpGroupModeProto::COLLECTIVE_MODE_FLATTENED_ID)
85+
.value(),
86+
CollectiveOpGroupMode::kFlattenedID);
87+
}
88+
89+
// Tests for GetCollectOpGroupMode
90+
namespace GetCollectiveOpGroupModeTest {
91+
struct TestCase {
92+
bool has_channel_id;
93+
std::optional<bool> use_global_device_ids;
94+
std::optional<xla::CollectiveOpGroupMode> expected;
95+
96+
std::string ToString() const {
97+
std::ostringstream s;
98+
s << (has_channel_id ? "chnl" : "nochnl");
99+
s << "_"
100+
<< (use_global_device_ids
101+
? (*use_global_device_ids ? "ugdi_true" : "ugdi_false")
102+
: "nougdi");
103+
return s.str();
104+
}
105+
};
106+
107+
std::vector<TestCase> GetTestCases() {
108+
const std::vector<TestCase> test_cases = {
109+
// clang-format off
110+
// has_channel_id, use_global_device_ids, expected mode
111+
{false, std::nullopt, CollectiveOpGroupMode::kCrossReplica},
112+
{false, false, CollectiveOpGroupMode::kCrossReplica},
113+
{false, true, std::nullopt},
114+
{true, std::nullopt, CollectiveOpGroupMode::kCrossPartition},
115+
{true, false, CollectiveOpGroupMode::kCrossReplicaAndPartition},
116+
{true, true, CollectiveOpGroupMode::kFlattenedID},
117+
// clang-format on
118+
};
119+
return test_cases;
120+
}
121+
122+
class GetCollectOpGroupModeTest : public testing::TestWithParam<TestCase> {};
123+
124+
TEST_P(GetCollectOpGroupModeTest, Test) {
125+
const TestCase &tc = GetParam();
126+
absl::StatusOr<CollectiveOpGroupMode> actual =
127+
GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids);
128+
if (tc.expected) {
129+
TF_ASSERT_OK(actual.status());
130+
EXPECT_EQ(*actual, *tc.expected);
131+
} else {
132+
EXPECT_FALSE(actual.ok());
133+
}
134+
}
135+
136+
INSTANTIATE_TEST_SUITE_P(GetCollectOpGroupMode, GetCollectOpGroupModeTest,
137+
testing::ValuesIn(GetTestCases()));
138+
} // namespace GetCollectiveOpGroupModeTest
139+
} // namespace
140+
} // namespace xla

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy