@@ -21,6 +21,7 @@ limitations under the License.
21
21
#include < cstdint>
22
22
#include < iterator>
23
23
#include < memory>
24
+ #include < optional>
24
25
#include < type_traits>
25
26
#include < utility>
26
27
@@ -30,6 +31,7 @@ limitations under the License.
30
31
#include " absl/status/status.h"
31
32
#include " absl/status/statusor.h"
32
33
#include " absl/types/span.h"
34
+ #include " absl/utility/utility.h"
33
35
#include " xla/shape.h"
34
36
#include " xla/shape_util.h"
35
37
#include " xla/tsl/lib/gtl/iterator_range.h"
@@ -55,12 +57,25 @@ class IndexTable {
55
57
IndexTable () = default ;
56
58
explicit IndexTable (const Shape& shape);
57
59
58
- bool empty () const { return entries_.empty (); }
60
+ const Entry& operator [](ShapeIndexView index) const {
61
+ static constexpr Entry kRootEntry = {0 , -1 };
59
62
60
- const Entry& operator [](ShapeIndexView index) const ;
63
+ if (!entries_.has_value ()) {
64
+ DCHECK (index.empty ());
65
+ return kRootEntry ;
66
+ }
67
+
68
+ const Entry* result = &entries_->front ();
69
+ for (int64_t i : index) {
70
+ DCHECK_GE (result->children_start_id , 0 );
71
+ result = &(*entries_)[result->children_start_id + i];
72
+ }
73
+ return *result;
74
+ }
61
75
62
76
private:
63
- absl::InlinedVector<Entry, 1 > entries_;
77
+ // Entries are computed only if the shape is a tuple.
78
+ std::optional<absl::InlinedVector<Entry, 1 >> entries_;
64
79
};
65
80
66
81
} // namespace internal
@@ -111,14 +126,14 @@ class ShapeTree {
111
126
: ShapeTree(std::make_shared<Shape>(std::move(shape))) {}
112
127
113
128
explicit ShapeTree (const Shape* shape)
114
- : ShapeTree(shape, CreateNodes(* shape) ) {}
129
+ : ShapeTree(absl:: in_place_t {}, shape) {}
115
130
116
131
// Create ShapeTree with the given shape, and init_value for all nodes.
117
132
ShapeTree (Shape shape, const T& init_value)
118
133
: ShapeTree(std::make_shared<Shape>(std::move(shape)), init_value) {}
119
134
120
135
ShapeTree (const Shape* shape, const T& init_value)
121
- : ShapeTree(shape, CreateNodes(* shape, init_value) ) {}
136
+ : ShapeTree(absl:: in_place_t {}, shape, init_value) {}
122
137
123
138
// Returns the data element associated with the array in the shape at the
124
139
// given index (see ShapeUtil::GetSubshape for how indexes are defined).
@@ -375,21 +390,20 @@ class ShapeTree {
375
390
}
376
391
377
392
template <typename ... Ts>
378
- static Nodes CreateNodes ( const Shape& shape, Ts&&... args) {
379
- Nodes nodes;
380
- nodes .reserve (ShapeUtil::SubshapeCount (shape));
393
+ ShapeTree (absl:: in_place_t , const Shape* shape, Ts&&... args)
394
+ : index_table_(*shape), shape_(shape) {
395
+ nodes_ .reserve (ShapeUtil::SubshapeCount (* shape));
381
396
ShapeUtil::ForEachSubshape (
382
- shape, [&](const Shape&, const ShapeIndex& index) {
383
- nodes .emplace_back (index, T (std::forward<Ts>(args)...));
397
+ * shape, [&](const Shape&, const ShapeIndex& index) {
398
+ nodes_ .emplace_back (index, T (std::forward<Ts>(args)...));
384
399
});
385
- return nodes;
386
400
}
387
401
388
402
// The nodes in this shape tree.
389
403
Nodes nodes_;
390
404
391
405
// Index table for node lookups. Each entry contains the index of the first
392
- // child of the node at that index, or -1 for leaf nodes. Evaluated lazily.
406
+ // child of the node at that index, or -1 for leaf nodes.
393
407
IndexTable index_table_;
394
408
395
409
// If we own our Shape, this field contains it, and shape_ is a pointer into
0 commit comments