Skip to content

Prevent XLA from crashing when a Literal is too big to fit in memory. #97400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 32 additions & 19 deletions third_party/xla/xla/literal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,6 @@ Shape* MutableLiteralBase::mutable_shape_do_not_use() {

Literal::Literal() : Literal(NilShape()) {}

Literal::Literal(const Shape& shape)
: Literal(shape, /*allocate_arrays=*/true) {}

void Literal::SetShape(const Shape& shape) {
if (const Shape* intered_shape_ptr = TryInternShape(shape)) {
shape_ = intered_shape_ptr;
Expand All @@ -273,15 +270,15 @@ void Literal::SetShape(const Shape& shape) {
shape_ = std::move(owning_shape_ptr);
}

void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays,
ArrayValueState leaf_array_value_state) {
absl::Status Literal::SetPiece(const Shape& shape, Piece* piece,
bool allocate_arrays,
ArrayValueState leaf_array_value_state) {
if (shape.IsTuple()) {
for (const Shape& subshape : shape.tuple_shapes()) {
Piece child_piece;
child_piece.set_subshape(&subshape);

SetPiece(subshape, &child_piece, allocate_arrays, leaf_array_value_state);

TF_RETURN_IF_ERROR(SetPiece(subshape, &child_piece, allocate_arrays,
leaf_array_value_state));
piece->emplace_back(std::move(child_piece));
}
} else if (shape.IsArray()) {
Expand All @@ -292,22 +289,34 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays,
piece->set_array_value_state(leaf_array_value_state);
if (leaf_array_value_state == LiteralBase::ArrayValueState::kKnown &&
allocate_arrays) {
piece->AllocateBuffers();
TF_RETURN_IF_ERROR(piece->AllocateBuffers());
}
}
return absl::OkStatus();
}

Literal::Literal(const Shape& shape, bool allocate_arrays,
ArrayValueState leaf_array_value_state) {
SetShape(shape);
absl::StatusOr<Literal> Literal::Make(
const Shape& shape, const bool allocate_arrays,
const ArrayValueState leaf_array_value_state) {
// We cannot use the default constructor because it is implemented in terms
// of Literal::Make(), leading to infinite recursion.
Literal literal(UninitializedLiteralTag{});
literal.SetShape(shape);
CHECK(leaf_array_value_state != ArrayValueState::kKnown ||
LayoutUtil::HasLayout(*shape_));
root_piece_.set_subshape(shape_.get());
CHECK(&root_piece_.subshape() == shape_.get());
LayoutUtil::HasLayout(*literal.shape_));
literal.root_piece_.set_subshape(literal.shape_.get());
CHECK(&literal.root_piece_.subshape() == literal.shape_.get());

SetPiece(*shape_, &root_piece_, allocate_arrays, leaf_array_value_state);
TF_RETURN_IF_ERROR(literal.SetPiece(*literal.shape_, &literal.root_piece_,
allocate_arrays, leaf_array_value_state));
return literal;
}

Literal::Literal(const Shape& shape, bool allocate_arrays,
ArrayValueState leaf_array_value_state)
: Literal(Literal::Make(shape, allocate_arrays, leaf_array_value_state)
.value()) {}

Literal::~Literal() { DeallocateBuffers(); }

void Literal::DeallocateBuffers() {
Expand Down Expand Up @@ -622,16 +631,20 @@ void LiteralBase::Piece::SetDynamicSize(int64_t dim_index, int32_t size) {
dynamic_size_buffer()[dim_index] = size;
}

void LiteralBase::Piece::AllocateBuffers() {
absl::Status LiteralBase::Piece::AllocateBuffers() {
const int64_t bytes = total_bytes_dense();
if (bytes > kMaxInlinedBytes) {
CHECK_EQ(buffer(), nullptr);
storage_.Emplace<DenseRep>(
static_cast<char*>(tsl::port::AlignedMalloc(bytes, kMinimumAlignment)));
CHECK_NE(buffer(), nullptr) << "Failed to allocate buffer for Literal";
if (buffer() == nullptr) {
return absl::ResourceExhaustedError(
"Failed to allocate buffer for Literal");
}
} else {
storage_.Emplace<DenseInlinedRep>();
}
return absl::OkStatus();
}

void LiteralBase::Piece::DeallocateBuffers() {
Expand Down Expand Up @@ -700,7 +713,7 @@ absl::Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
CHECK(src.array_value_state_ == ArrayValueState::kKnown);
if (array_value_state_ == ArrayValueState::kUndetermined ||
array_value_state_ == ArrayValueState::kUnknown) {
AllocateBuffers();
TF_RETURN_IF_ERROR(AllocateBuffers());
}
array_value_state_ = src.array_value_state_;
}
Expand Down
31 changes: 22 additions & 9 deletions third_party/xla/xla/literal.h
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ class LiteralBase {

DynamicSizeType GetDynamicSize(int64_t dim_index) const;
void SetDynamicSize(int64_t dim_index, DynamicSizeType size);
void AllocateBuffers();
absl::Status AllocateBuffers();
void DeallocateBuffers();
// Gets/sets the buffer holding the array data.
const char* buffer() const;
Expand Down Expand Up @@ -1526,10 +1526,6 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal);
class Literal : public MutableLiteralBase {
public:
Literal();

// Create a literal of the given shape. The literal is allocated sufficient
// memory to hold the shape. Memory is uninitialized.
explicit Literal(const Shape& shape);
~Literal() override;

// Literals are moveable, but not copyable. To copy a literal use
Expand All @@ -1538,13 +1534,25 @@ class Literal : public MutableLiteralBase {
Literal(const Literal& other) = delete;
Literal& operator=(const Literal& other) = delete;
Literal(Literal&& other);
// Create a literal of the given shape.
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
// to nullptr.
Literal(const Shape& shape, bool allocate_arrays,
ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
// to nullptr. If true, the buffers are allocated but uninitialized.
ABSL_DEPRECATED(
"This ctor may crash if allocation fails. Use Literal::Make() instead.")
explicit Literal(
const Shape& shape, bool allocate_arrays = true,
ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
Literal& operator=(Literal&& other);

// Creates a literal of the given shape.
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
// to nullptr.
static absl::StatusOr<Literal> Make(
const Shape& shape, bool allocate_arrays = true,
ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);

// Similar to CopyFrom, but with move semantics. The subshape of this literal
// rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
// (layouts and shapes must match), but need not be arrays. The memory
Expand Down Expand Up @@ -1586,6 +1594,11 @@ class Literal : public MutableLiteralBase {
private:
friend class LiteralBase;
friend class MutableLiteralBase;
struct UninitializedLiteralTag {};

// Creates an uninitialized literal.
explicit Literal(UninitializedLiteralTag) {}

const Piece& root_piece() const final { return root_piece_; };
// Deallocate the buffers held by this literal.
void DeallocateBuffers();
Expand All @@ -1598,7 +1611,7 @@ class Literal : public MutableLiteralBase {
// Recursively sets the subshapes and buffers of all subpieces rooted at
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
// the shape.
void SetPiece(
absl::Status SetPiece(
const Shape& shape, Piece* piece, bool allocate_arrays,
ArrayValueState leaf_array_value_state = ArrayValueState::kKnown);
Piece root_piece_;
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ cc_library(
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_cost_analysis",
"//xla/tsl/framework:allocator",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"//xla/tsl/protobuf:coordination_service_proto_cc",
"@com_google_absl//absl/base",
Expand Down
9 changes: 6 additions & 3 deletions third_party/xla/xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/framework/allocator.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/protobuf/coordination_service.pb.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -1128,9 +1130,10 @@ class PjRtBuffer {
// layout.
absl::StatusOr<std::shared_ptr<Literal>> ToLiteralSync() {
TF_ASSIGN_OR_RETURN(Shape host_shape, HostShape());
auto literal = std::make_shared<Literal>(host_shape);
TF_RETURN_IF_ERROR(ToLiteralSync(literal.get()));
return literal;
TF_ASSIGN_OR_RETURN(auto literal, Literal::Make(host_shape));
auto shared_literal = std::make_shared<Literal>(std::move(literal));
TF_RETURN_IF_ERROR(ToLiteralSync(shared_literal.get()));
return shared_literal;
}

// Returns the number of bytes of the buffer storage on the device.
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy