@@ -249,9 +249,6 @@ Shape* MutableLiteralBase::mutable_shape_do_not_use() {
249
249
250
250
Literal::Literal () : Literal(NilShape()) {}
251
251
252
- Literal::Literal (const Shape& shape)
253
- : Literal(shape, /* allocate_arrays=*/ true ) {}
254
-
255
252
void Literal::SetShape (const Shape& shape) {
256
253
if (const Shape* intered_shape_ptr = TryInternShape (shape)) {
257
254
shape_ = intered_shape_ptr;
@@ -273,15 +270,15 @@ void Literal::SetShape(const Shape& shape) {
273
270
shape_ = std::move (owning_shape_ptr);
274
271
}
275
272
276
- void Literal::SetPiece (const Shape& shape, Piece* piece, bool allocate_arrays,
277
- ArrayValueState leaf_array_value_state) {
273
+ absl::Status Literal::SetPiece (const Shape& shape, Piece* piece,
274
+ bool allocate_arrays,
275
+ ArrayValueState leaf_array_value_state) {
278
276
if (shape.IsTuple ()) {
279
277
for (const Shape& subshape : shape.tuple_shapes ()) {
280
278
Piece child_piece;
281
279
child_piece.set_subshape (&subshape);
282
-
283
- SetPiece (subshape, &child_piece, allocate_arrays, leaf_array_value_state);
284
-
280
+ TF_RETURN_IF_ERROR (SetPiece (subshape, &child_piece, allocate_arrays,
281
+ leaf_array_value_state));
285
282
piece->emplace_back (std::move (child_piece));
286
283
}
287
284
} else if (shape.IsArray ()) {
@@ -292,22 +289,34 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays,
292
289
piece->set_array_value_state (leaf_array_value_state);
293
290
if (leaf_array_value_state == LiteralBase::ArrayValueState::kKnown &&
294
291
allocate_arrays) {
295
- piece->AllocateBuffers ();
292
+ TF_RETURN_IF_ERROR ( piece->AllocateBuffers () );
296
293
}
297
294
}
295
+ return absl::OkStatus ();
298
296
}
299
297
300
- Literal::Literal (const Shape& shape, bool allocate_arrays,
301
- ArrayValueState leaf_array_value_state) {
302
- SetShape (shape);
298
+ absl::StatusOr<Literal> Literal::Make (
299
+ const Shape& shape, const bool allocate_arrays,
300
+ const ArrayValueState leaf_array_value_state) {
301
+ // We cannot use the default constructor because it is implemented in terms
302
+ // of Literal::Make(), leading to infinite recursion.
303
+ Literal literal (UninitializedLiteralTag{});
304
+ literal.SetShape (shape);
303
305
CHECK (leaf_array_value_state != ArrayValueState::kKnown ||
304
- LayoutUtil::HasLayout (*shape_));
305
- root_piece_.set_subshape (shape_.get ());
306
- CHECK (&root_piece_.subshape () == shape_.get ());
306
+ LayoutUtil::HasLayout (*literal. shape_ ));
307
+ literal. root_piece_ .set_subshape (literal. shape_ .get ());
308
+ CHECK (&literal. root_piece_ .subshape () == literal. shape_ .get ());
307
309
308
- SetPiece (*shape_, &root_piece_, allocate_arrays, leaf_array_value_state);
310
+ TF_RETURN_IF_ERROR (literal.SetPiece (*literal.shape_ , &literal.root_piece_ ,
311
+ allocate_arrays, leaf_array_value_state));
312
+ return literal;
309
313
}
310
314
315
+ Literal::Literal (const Shape& shape, bool allocate_arrays,
316
+ ArrayValueState leaf_array_value_state)
317
+ : Literal(Literal::Make(shape, allocate_arrays, leaf_array_value_state)
318
+ .value()) {}
319
+
311
320
Literal::~Literal () { DeallocateBuffers (); }
312
321
313
322
void Literal::DeallocateBuffers () {
@@ -622,16 +631,20 @@ void LiteralBase::Piece::SetDynamicSize(int64_t dim_index, int32_t size) {
622
631
dynamic_size_buffer ()[dim_index] = size;
623
632
}
624
633
625
- void LiteralBase::Piece::AllocateBuffers () {
634
+ absl::Status LiteralBase::Piece::AllocateBuffers () {
626
635
const int64_t bytes = total_bytes_dense ();
627
636
if (bytes > kMaxInlinedBytes ) {
628
637
CHECK_EQ (buffer (), nullptr );
629
638
storage_.Emplace <DenseRep>(
630
639
static_cast <char *>(tsl::port::AlignedMalloc (bytes, kMinimumAlignment )));
631
- CHECK_NE (buffer (), nullptr ) << " Failed to allocate buffer for Literal" ;
640
+ if (buffer () == nullptr ) {
641
+ return absl::ResourceExhaustedError (
642
+ " Failed to allocate buffer for Literal" );
643
+ }
632
644
} else {
633
645
storage_.Emplace <DenseInlinedRep>();
634
646
}
647
+ return absl::OkStatus ();
635
648
}
636
649
637
650
void LiteralBase::Piece::DeallocateBuffers () {
@@ -700,7 +713,7 @@ absl::Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
700
713
CHECK (src.array_value_state_ == ArrayValueState::kKnown );
701
714
if (array_value_state_ == ArrayValueState::kUndetermined ||
702
715
array_value_state_ == ArrayValueState::kUnknown ) {
703
- AllocateBuffers ();
716
+ TF_RETURN_IF_ERROR ( AllocateBuffers () );
704
717
}
705
718
array_value_state_ = src.array_value_state_ ;
706
719
}
0 commit comments