Skip to content

Commit 2ddb5fd

Browse files
ezhulenev authored and tensorflower-gardener committed
[xla] Optimize ShapeTree construction time for arrays
1. Add an in-place constructor that avoids the InlinedVector copy constructor
2. Do not compute the index table for non-tuple types
PiperOrigin-RevId: 786322120
1 parent b193e8d commit 2ddb5fd

File tree

4 files changed

+42
-41
lines changed

4 files changed

+42
-41
lines changed

third_party/xla/xla/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1037,6 +1037,7 @@ cc_library(
10371037
"@com_google_absl//absl/status",
10381038
"@com_google_absl//absl/status:statusor",
10391039
"@com_google_absl//absl/types:span",
1040+
"@com_google_absl//absl/utility",
10401041
],
10411042
)
10421043

third_party/xla/xla/shape_tree.cc

Lines changed: 14 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -16,32 +16,22 @@ limitations under the License.
1616
#include "xla/shape_tree.h"
1717

1818
#include <cstddef>
19-
#include <cstdint>
2019

2120
#include "absl/base/optimization.h"
2221
#include "absl/log/check.h"
2322
#include "absl/types/span.h"
2423
#include "xla/shape.h"
25-
#include "xla/shape_util.h"
26-
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep
2724

28-
namespace xla {
29-
namespace internal {
25+
namespace xla::internal {
3026

31-
// Computes the total size of all nested tuples in the given shape.
32-
//
33-
// If `is_known_tuple` is true, then the shape is known to be a tuple, and we
34-
// can skip the run time check for `IsTuple()`.
35-
template <bool is_known_tuple = false>
27+
// Computes the total size of all nested tuples in the given tuple shape.
3628
static size_t IndexTableTuplesSize(const Shape& shape) {
37-
if (!is_known_tuple && ABSL_PREDICT_TRUE(!shape.IsTuple())) {
38-
return 0;
39-
}
29+
DCHECK(shape.IsTuple()) << "Shape must be a tuple";
4030

4131
size_t size = shape.tuple_shapes().size();
4232
for (const Shape& subshape : shape.tuple_shapes()) {
4333
if (ABSL_PREDICT_FALSE(subshape.IsTuple())) {
44-
size += IndexTableTuplesSize</*is_known_tuple=*/true>(subshape);
34+
size += IndexTableTuplesSize(subshape);
4535
}
4636
}
4737

@@ -75,22 +65,18 @@ static void InitializeIndexTable(const Shape& shape,
7565
}
7666
}
7767

78-
IndexTable::IndexTable(const Shape& shape)
79-
: entries_(1 + IndexTableTuplesSize(shape)) {
68+
IndexTable::IndexTable(const Shape& shape) {
69+
if (!shape.IsTuple()) {
70+
return;
71+
}
72+
73+
// Allocate storage for the index table.
74+
entries_.emplace(1 + IndexTableTuplesSize(shape));
75+
8076
size_t next_node_id = 0;
8177
size_t next_children_start_index = 1;
82-
InitializeIndexTable(shape, absl::MakeSpan(entries_), 0, next_node_id,
78+
InitializeIndexTable(shape, absl::MakeSpan(*entries_), 0, next_node_id,
8379
next_children_start_index);
8480
}
8581

86-
const IndexTable::Entry& IndexTable::operator[](ShapeIndexView index) const {
87-
const Entry* result = &entries_.front();
88-
for (int64_t i : index) {
89-
CHECK_GE(result->children_start_id, 0);
90-
result = &entries_[result->children_start_id + i];
91-
}
92-
return *result;
93-
}
94-
95-
} // namespace internal
96-
} // namespace xla
82+
} // namespace xla::internal

third_party/xla/xla/shape_tree.h

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <cstdint>
2222
#include <iterator>
2323
#include <memory>
24+
#include <optional>
2425
#include <type_traits>
2526
#include <utility>
2627

@@ -30,6 +31,7 @@ limitations under the License.
3031
#include "absl/status/status.h"
3132
#include "absl/status/statusor.h"
3233
#include "absl/types/span.h"
34+
#include "absl/utility/utility.h"
3335
#include "xla/shape.h"
3436
#include "xla/shape_util.h"
3537
#include "xla/tsl/lib/gtl/iterator_range.h"
@@ -55,12 +57,25 @@ class IndexTable {
5557
IndexTable() = default;
5658
explicit IndexTable(const Shape& shape);
5759

58-
bool empty() const { return entries_.empty(); }
60+
const Entry& operator[](ShapeIndexView index) const {
61+
static constexpr Entry kRootEntry = {0, -1};
5962

60-
const Entry& operator[](ShapeIndexView index) const;
63+
if (!entries_.has_value()) {
64+
DCHECK(index.empty());
65+
return kRootEntry;
66+
}
67+
68+
const Entry* result = &entries_->front();
69+
for (int64_t i : index) {
70+
DCHECK_GE(result->children_start_id, 0);
71+
result = &(*entries_)[result->children_start_id + i];
72+
}
73+
return *result;
74+
}
6175

6276
private:
63-
absl::InlinedVector<Entry, 1> entries_;
77+
// Entries are computed only if the shape is a tuple.
78+
std::optional<absl::InlinedVector<Entry, 1>> entries_;
6479
};
6580

6681
} // namespace internal
@@ -111,14 +126,14 @@ class ShapeTree {
111126
: ShapeTree(std::make_shared<Shape>(std::move(shape))) {}
112127

113128
explicit ShapeTree(const Shape* shape)
114-
: ShapeTree(shape, CreateNodes(*shape)) {}
129+
: ShapeTree(absl::in_place_t{}, shape) {}
115130

116131
// Create ShapeTree with the given shape, and init_value for all nodes.
117132
ShapeTree(Shape shape, const T& init_value)
118133
: ShapeTree(std::make_shared<Shape>(std::move(shape)), init_value) {}
119134

120135
ShapeTree(const Shape* shape, const T& init_value)
121-
: ShapeTree(shape, CreateNodes(*shape, init_value)) {}
136+
: ShapeTree(absl::in_place_t{}, shape, init_value) {}
122137

123138
// Returns the data element associated with the array in the shape at the
124139
// given index (see ShapeUtil::GetSubshape for how indexes are defined).
@@ -375,21 +390,20 @@ class ShapeTree {
375390
}
376391

377392
template <typename... Ts>
378-
static Nodes CreateNodes(const Shape& shape, Ts&&... args) {
379-
Nodes nodes;
380-
nodes.reserve(ShapeUtil::SubshapeCount(shape));
393+
ShapeTree(absl::in_place_t, const Shape* shape, Ts&&... args)
394+
: index_table_(*shape), shape_(shape) {
395+
nodes_.reserve(ShapeUtil::SubshapeCount(*shape));
381396
ShapeUtil::ForEachSubshape(
382-
shape, [&](const Shape&, const ShapeIndex& index) {
383-
nodes.emplace_back(index, T(std::forward<Ts>(args)...));
397+
*shape, [&](const Shape&, const ShapeIndex& index) {
398+
nodes_.emplace_back(index, T(std::forward<Ts>(args)...));
384399
});
385-
return nodes;
386400
}
387401

388402
// The nodes in this shape tree.
389403
Nodes nodes_;
390404

391405
// Index table for node lookups. Each entry contains the index of the first
392-
// child of the node at that index, or -1 for leaf nodes. Evaluated lazily.
406+
// child of the node at that index, or -1 for leaf nodes.
393407
IndexTable index_table_;
394408

395409
// If we own our Shape, this field contains it, and shape_ is a pointer into

third_party/xla/xla/shape_tree_test.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -654,7 +654,7 @@ void BM_Iterate(::testing::benchmark::State& state) {
654654
}
655655

656656
#define BENCHMARK_WITH_ARGS(name) \
657-
BENCHMARK(name)->ArgPair(2, 8)->ArgPair(1, 1000)
657+
BENCHMARK(name)->ArgPair(0, 0)->ArgPair(2, 8)->ArgPair(1, 1000)
658658

659659
BENCHMARK_WITH_ARGS(BM_Construct);
660660
BENCHMARK_WITH_ARGS(BM_ConstructUnowned);

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy