Skip to content

Commit 42a44b1

Browse files
Prevent XLA from crashing when a Literal is too big to fit in memory.
The `Literal(Shape)` constructor is unsafe: it may crash when the shape is too large to fit in memory. Deprecate it in favor of a static factory `Literal::Make(Shape)`, which returns a `Status` on allocation failure instead of crashing. PiperOrigin-RevId: 786179524
1 parent c174396 commit 42a44b1

File tree

4 files changed

+61
-31
lines changed

4 files changed

+61
-31
lines changed

third_party/xla/xla/literal.cc

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,6 @@ Shape* MutableLiteralBase::mutable_shape_do_not_use() {
249249

250250
Literal::Literal() : Literal(NilShape()) {}
251251

252-
Literal::Literal(const Shape& shape)
253-
: Literal(shape, /*allocate_arrays=*/true) {}
254-
255252
void Literal::SetShape(const Shape& shape) {
256253
if (const Shape* intered_shape_ptr = TryInternShape(shape)) {
257254
shape_ = intered_shape_ptr;
@@ -273,15 +270,15 @@ void Literal::SetShape(const Shape& shape) {
273270
shape_ = std::move(owning_shape_ptr);
274271
}
275272

276-
void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays,
277-
ArrayValueState leaf_array_value_state) {
273+
absl::Status Literal::SetPiece(const Shape& shape, Piece* piece,
274+
bool allocate_arrays,
275+
ArrayValueState leaf_array_value_state) {
278276
if (shape.IsTuple()) {
279277
for (const Shape& subshape : shape.tuple_shapes()) {
280278
Piece child_piece;
281279
child_piece.set_subshape(&subshape);
282-
283-
SetPiece(subshape, &child_piece, allocate_arrays, leaf_array_value_state);
284-
280+
TF_RETURN_IF_ERROR(SetPiece(subshape, &child_piece, allocate_arrays,
281+
leaf_array_value_state));
285282
piece->emplace_back(std::move(child_piece));
286283
}
287284
} else if (shape.IsArray()) {
@@ -292,22 +289,34 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays,
292289
piece->set_array_value_state(leaf_array_value_state);
293290
if (leaf_array_value_state == LiteralBase::ArrayValueState::kKnown &&
294291
allocate_arrays) {
295-
piece->AllocateBuffers();
292+
TF_RETURN_IF_ERROR(piece->AllocateBuffers());
296293
}
297294
}
295+
return absl::OkStatus();
298296
}
299297

300-
Literal::Literal(const Shape& shape, bool allocate_arrays,
301-
ArrayValueState leaf_array_value_state) {
302-
SetShape(shape);
298+
absl::StatusOr<Literal> Literal::Make(
299+
const Shape& shape, const bool allocate_arrays,
300+
const ArrayValueState leaf_array_value_state) {
301+
// We cannot use the default constructor because it is implemented in terms
302+
// of Literal::Make(), leading to infinite recursion.
303+
Literal literal(UninitializedLiteralTag{});
304+
literal.SetShape(shape);
303305
CHECK(leaf_array_value_state != ArrayValueState::kKnown ||
304-
LayoutUtil::HasLayout(*shape_));
305-
root_piece_.set_subshape(shape_.get());
306-
CHECK(&root_piece_.subshape() == shape_.get());
306+
LayoutUtil::HasLayout(*literal.shape_));
307+
literal.root_piece_.set_subshape(literal.shape_.get());
308+
CHECK(&literal.root_piece_.subshape() == literal.shape_.get());
307309

308-
SetPiece(*shape_, &root_piece_, allocate_arrays, leaf_array_value_state);
310+
TF_RETURN_IF_ERROR(literal.SetPiece(*literal.shape_, &literal.root_piece_,
311+
allocate_arrays, leaf_array_value_state));
312+
return literal;
309313
}
310314

315+
Literal::Literal(const Shape& shape, bool allocate_arrays,
316+
ArrayValueState leaf_array_value_state)
317+
: Literal(Literal::Make(shape, allocate_arrays, leaf_array_value_state)
318+
.value()) {}
319+
311320
Literal::~Literal() { DeallocateBuffers(); }
312321

313322
void Literal::DeallocateBuffers() {
@@ -622,16 +631,20 @@ void LiteralBase::Piece::SetDynamicSize(int64_t dim_index, int32_t size) {
622631
dynamic_size_buffer()[dim_index] = size;
623632
}
624633

625-
void LiteralBase::Piece::AllocateBuffers() {
634+
absl::Status LiteralBase::Piece::AllocateBuffers() {
626635
const int64_t bytes = total_bytes_dense();
627636
if (bytes > kMaxInlinedBytes) {
628637
CHECK_EQ(buffer(), nullptr);
629638
storage_.Emplace<DenseRep>(
630639
static_cast<char*>(tsl::port::AlignedMalloc(bytes, kMinimumAlignment)));
631-
CHECK_NE(buffer(), nullptr) << "Failed to allocate buffer for Literal";
640+
if (buffer() == nullptr) {
641+
return absl::ResourceExhaustedError(
642+
"Failed to allocate buffer for Literal");
643+
}
632644
} else {
633645
storage_.Emplace<DenseInlinedRep>();
634646
}
647+
return absl::OkStatus();
635648
}
636649

637650
void LiteralBase::Piece::DeallocateBuffers() {
@@ -700,7 +713,7 @@ absl::Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
700713
CHECK(src.array_value_state_ == ArrayValueState::kKnown);
701714
if (array_value_state_ == ArrayValueState::kUndetermined ||
702715
array_value_state_ == ArrayValueState::kUnknown) {
703-
AllocateBuffers();
716+
TF_RETURN_IF_ERROR(AllocateBuffers());
704717
}
705718
array_value_state_ = src.array_value_state_;
706719
}

third_party/xla/xla/literal.h

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ class LiteralBase {
857857

858858
DynamicSizeType GetDynamicSize(int64_t dim_index) const;
859859
void SetDynamicSize(int64_t dim_index, DynamicSizeType size);
860-
void AllocateBuffers();
860+
absl::Status AllocateBuffers();
861861
void DeallocateBuffers();
862862
// Gets/sets the buffer holding the array data.
863863
const char* buffer() const;
@@ -1526,10 +1526,6 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal);
15261526
class Literal : public MutableLiteralBase {
15271527
public:
15281528
Literal();
1529-
1530-
// Create a literal of the given shape. The literal is allocated sufficient
1531-
// memory to hold the shape. Memory is uninitialized.
1532-
explicit Literal(const Shape& shape);
15331529
~Literal() override;
15341530

15351531
// Literals are moveable, but not copyable. To copy a literal use
@@ -1538,13 +1534,25 @@ class Literal : public MutableLiteralBase {
15381534
Literal(const Literal& other) = delete;
15391535
Literal& operator=(const Literal& other) = delete;
15401536
Literal(Literal&& other);
1537+
// Create a literal of the given shape.
15411538
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
15421539
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
1543-
// to nullptr.
1544-
Literal(const Shape& shape, bool allocate_arrays,
1545-
ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
1540+
// to nullptr. If true, the buffers are allocated but uninitialized.
1541+
ABSL_DEPRECATED(
1542+
"This ctor may crash if allocation fails. Use Literal::Make() instead.")
1543+
explicit Literal(
1544+
const Shape& shape, bool allocate_arrays = true,
1545+
ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
15461546
Literal& operator=(Literal&& other);
15471547

1548+
// Creates a literal of the given shape.
1549+
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
1550+
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
1551+
// to nullptr.
1552+
static absl::StatusOr<Literal> Make(
1553+
const Shape& shape, bool allocate_arrays = true,
1554+
ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
1555+
15481556
// Similar to CopyFrom, but with move semantics. The subshape of this literal
15491557
// rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
15501558
// (layouts and shapes must match), but need not be arrays. The memory
@@ -1586,6 +1594,11 @@ class Literal : public MutableLiteralBase {
15861594
private:
15871595
friend class LiteralBase;
15881596
friend class MutableLiteralBase;
1597+
struct UninitializedLiteralTag {};
1598+
1599+
// Creates an uninitialized literal.
1600+
explicit Literal(UninitializedLiteralTag) {}
1601+
15891602
const Piece& root_piece() const final { return root_piece_; };
15901603
// Deallocate the buffers held by this literal.
15911604
void DeallocateBuffers();
@@ -1598,7 +1611,7 @@ class Literal : public MutableLiteralBase {
15981611
// Recursively sets the subshapes and buffers of all subpieces rooted at
15991612
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
16001613
// the shape.
1601-
void SetPiece(
1614+
absl::Status SetPiece(
16021615
const Shape& shape, Piece* piece, bool allocate_arrays,
16031616
ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
16041617
Piece root_piece_;

third_party/xla/xla/pjrt/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ cc_library(
317317
"//xla/service:computation_placer_hdr",
318318
"//xla/service:hlo_cost_analysis",
319319
"//xla/tsl/framework:allocator",
320+
"//xla/tsl/platform:errors",
320321
"//xla/tsl/platform:statusor",
321322
"//xla/tsl/protobuf:coordination_service_proto_cc",
322323
"@com_google_absl//absl/base",

third_party/xla/xla/pjrt/pjrt_client.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ limitations under the License.
5656
#include "xla/shape.h"
5757
#include "xla/shape_util.h"
5858
#include "xla/tsl/framework/allocator.h"
59+
#include "xla/tsl/platform/errors.h"
60+
#include "xla/tsl/platform/statusor.h"
5961
#include "xla/tsl/protobuf/coordination_service.pb.h"
6062
#include "xla/util.h"
6163
#include "xla/xla_data.pb.h"
@@ -1128,9 +1130,10 @@ class PjRtBuffer {
11281130
// layout.
11291131
absl::StatusOr<std::shared_ptr<Literal>> ToLiteralSync() {
11301132
TF_ASSIGN_OR_RETURN(Shape host_shape, HostShape());
1131-
auto literal = std::make_shared<Literal>(host_shape);
1132-
TF_RETURN_IF_ERROR(ToLiteralSync(literal.get()));
1133-
return literal;
1133+
TF_ASSIGN_OR_RETURN(auto literal, Literal::Make(host_shape));
1134+
auto shared_literal = std::make_shared<Literal>(std::move(literal));
1135+
TF_RETURN_IF_ERROR(ToLiteralSync(shared_literal.get()));
1136+
return shared_literal;
11341137
}
11351138

11361139
// Returns the number of bytes of the buffer storage on the device.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy