
Commit 9ebe7e3

Automated Code Change
PiperOrigin-RevId: 785633377
1 parent f015a18 commit 9ebe7e3

File tree

20 files changed: +91 −14 lines

tensorflow/compiler/jit/BUILD

Lines changed: 2 additions & 0 deletions
@@ -355,6 +355,7 @@ cc_library(
         "//tensorflow/core/tpu:tpu_defs",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/memory",
@@ -641,6 +642,7 @@ cc_library(
         "//tensorflow/core/tfrt/common:async_value_tensor",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/cleanup",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/types:span",

tensorflow/compiler/jit/kernels/xla_ops.cc

Lines changed: 2 additions & 2 deletions
@@ -604,7 +604,7 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
                        done);
   OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
                        done);
-  std::map<int, const Tensor*> resource_var_ptrs;
+  absl::flat_hash_map<int, const Tensor*> resource_var_ptrs;
   for (int i = 0; i < resources.size(); i++) {
     resource_var_ptrs[resources[i]] = variable_infos[i].var()->tensor();
   }
@@ -928,7 +928,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
   const xla::HloInputOutputAliasConfig& input_output_alias =
       closure.executable()->executable()->module().input_output_alias_config();
   absl::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
-  std::map<int, const Tensor*> snapshot_ptrs;
+  absl::flat_hash_map<int, const Tensor*> snapshot_ptrs;
   {
     tsl::profiler::TraceMe hlo_module_activity(
         [&] {
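
The hunks above are the core of this commit: point-lookup maps keyed by input index move from std::map to absl::flat_hash_map. The sketch below is illustrative only (plain `const char*` values stand in for `const Tensor*`); it shows that the swap keeps the same insert/find surface while trading ordered iteration for O(1) average-case lookups.

// Illustrative sketch, not TensorFlow code: `const char*` stands in for
// `const Tensor*`. Build against Abseil.
#include <cstdio>

#include "absl/container/flat_hash_map.h"

int main() {
  // Mirrors `resource_var_ptrs` above: keyed by resource-variable index.
  absl::flat_hash_map<int, const char*> resource_var_ptrs;
  resource_var_ptrs[0] = "var0";
  resource_var_ptrs.emplace(3, "var3");

  // Point lookup, the only access pattern in the hunks above.
  auto it = resource_var_ptrs.find(3);
  if (it != resource_var_ptrs.end()) std::printf("found %s\n", it->second);

  // Caveat of the swap: iteration order is unspecified, unlike std::map,
  // so the migration is safe only where no caller relies on sorted keys.
  return 0;
}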

tensorflow/compiler/jit/xla_compile_on_demand_op.cc

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@ limitations under the License.
 #include <utility>
 #include <vector>

+#include "absl/container/flat_hash_map.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/memory/memory.h"
@@ -130,7 +131,7 @@ absl::Status XlaCompileOnDemandOp::Run(
           ? platform_info_.xla_device_metadata()->UseMultipleStreams()
           : false);

-  std::map<int, const Tensor*> snapshot_ptrs;
+  absl::flat_hash_map<int, const Tensor*> snapshot_ptrs;
   for (auto& p : variable_args) {
     snapshot_ptrs.emplace(p.first,
                           p.second.has_value() ? &p.second.value() : nullptr);

tensorflow/compiler/jit/xla_launch_util.cc

Lines changed: 10 additions & 8 deletions
@@ -24,6 +24,7 @@ limitations under the License.

 #include "absl/algorithm/container.h"
 #include "absl/cleanup/cleanup.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
 #include "absl/types/span.h"
@@ -127,7 +128,7 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(

 // Fills in `execution_input` with `buffer` for `index`.
 static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
-                                         xla::ShapeIndex index,
+                                         const xla::ShapeIndex& index,
                                          se::DeviceMemoryBase buffer,
                                          bool donate_buffer, int device_ordinal,
                                          se::DeviceMemoryAllocator* allocator) {
@@ -149,12 +150,14 @@ absl::StatusOr<std::vector<xla::ExecutionInput>>
 XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx,
     const XlaCompiler::CompilationResult* compilation_result,
-    const std::map<int, const Tensor*>& resource_vars,
+    const absl::flat_hash_map<int, const Tensor*>& resource_vars,
     int missing_ctx_input_prefix,
     const xla::HloInputOutputAliasConfig& input_output_alias) {
   std::vector<xla::ExecutionInput> arguments;
   arguments.reserve(compilation_result->xla_input_shapes.size());

+  xla::ShapeIndex root_index = {};
+
   for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
     int arg_num = compilation_result->input_mapping[i];
     CHECK_GE(arg_num, missing_ctx_input_prefix);
@@ -176,9 +179,8 @@ XlaComputationLaunchContext::PopulateInputs(
             ? resource_var_it->second
             : &(ctx->input(arg_num - missing_ctx_input_prefix));
     CHECK(t);
-    bool donate_buffer =
-        t->RefCountIsOne() && is_updated_resource_variable &&
-        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
+    bool donate_buffer = t->RefCountIsOne() && is_updated_resource_variable &&
+                         input_output_alias.ParameterHasAlias(i, root_index);
     VLOG(3) << "Processing input: " << i
             << "; is_resource_variable=" << is_resource_variable
             << "; is_updated_resource_variable=" << is_updated_resource_variable
@@ -196,7 +198,7 @@ XlaComputationLaunchContext::PopulateInputs(
       arguments.emplace_back(&device_shape);
       xla::ExecutionInput& execution_input = arguments.back();
       se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
-      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
+      PopulateExecutionInputBuffer(execution_input, root_index, dmem,
                                    donate_buffer, device_ordinal_,
                                    xla_allocator_);
     }
@@ -222,7 +224,7 @@ static absl::StatusOr<Tensor> GetOrCreateTensorForOutput(
     int missing_ctx_input_prefix,
     const xla::HloInputOutputAliasConfig& input_output_alias,
     absl::Span<const int> input_mapping,
-    const std::map<int, const Tensor*>& resource_vars_snapshots,
+    const absl::flat_hash_map<int, const Tensor*>& resource_vars_snapshots,
     DataType output_dtype, const TensorShape& output_shape,
     Allocator* output_allocator, bool allocate_xla_tensors, se::Stream* stream,
     bool use_multiple_streams, std::shared_ptr<se::Event> definition_event) {
@@ -359,7 +361,7 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs(
     ScopedShapedBuffer output, int missing_ctx_input_prefix,
     absl::Span<VariableInfo> variable_infos,
     const xla::HloInputOutputAliasConfig& input_output_alias,
-    const std::map<int, const Tensor*>& resource_vars) {
+    const absl::flat_hash_map<int, const Tensor*>& resource_vars) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   Allocator* allocator = ctx->device()->GetAllocator({});
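
Two smaller cleanups ride along in this file: PopulateExecutionInputBuffer now takes its ShapeIndex by const reference, and a single root_index is hoisted out of the input loop instead of constructing an xla::ShapeIndex{} temporary at each use. Below is a minimal sketch of the idea, using a hypothetical stand-in type since xla::ShapeIndex stores its digits in a vector-like container and is therefore not free to copy.

// Sketch only: `IndexLike` is a hypothetical stand-in for xla::ShapeIndex.
#include <cstdint>
#include <vector>

struct IndexLike {
  std::vector<int64_t> digits;  // empty == root of the shape tree
};

// Before the change, pass-by-value copied the vector on every call; with a
// const reference, the hoisted instance is simply reused.
void PopulateBuffer(const IndexLike& index) { (void)index; }

int main() {
  IndexLike root_index = {};     // constructed once, as in the patch
  for (int i = 0; i < 8; ++i) {
    PopulateBuffer(root_index);  // no per-iteration temporaries
  }
  return 0;
}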

tensorflow/compiler/jit/xla_launch_util.h

Lines changed: 3 additions & 2 deletions
@@ -23,6 +23,7 @@ limitations under the License.
 #include <set>
 #include <vector>

+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/compiler/jit/variable_info.h"
 #include "tensorflow/compiler/jit/xla_tensor.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -188,7 +189,7 @@ class XlaComputationLaunchContext {
   absl::StatusOr<std::vector<xla::ExecutionInput>> PopulateInputs(
       OpKernelContext* ctx,
       const XlaCompiler::CompilationResult* compilation_result,
-      const std::map<int, const Tensor*>& resource_vars,
+      const absl::flat_hash_map<int, const Tensor*>& resource_vars,
       int missing_ctx_input_prefix,
       const xla::HloInputOutputAliasConfig& input_output_alias);

@@ -208,7 +209,7 @@ class XlaComputationLaunchContext {
       xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
       absl::Span<VariableInfo> variable_infos,
       const xla::HloInputOutputAliasConfig& input_output_alias,
-      const std::map<int, const Tensor*>& resource_vars);
+      const absl::flat_hash_map<int, const Tensor*>& resource_vars);

  private:
   xla::LocalClient* client_;

tensorflow/lite/java/jni/BUILD

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@ load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")

 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
-    default_visibility = ["//tensorflow/lite:__subpackages__"],
     licenses = ["notice"],
 )

third_party/xla/third_party/stablehlo/temporary.patch

Lines changed: 31 additions & 0 deletions
@@ -102,9 +102,40 @@ diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehl
    p.printRegion(cond, /*printEntryBlockArgs=*/false);
    p << " do ";
    p.printRegion(body, /*printEntryBlockArgs=*/false);
+diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp
+--- stablehlo/stablehlo/dialect/StablehloOps.cpp
++++ stablehlo/stablehlo/dialect/StablehloOps.cpp
+@@ -2201,14 +2201,14 @@
+   locs.reserve(numValues);
+   for (auto i : inputs) {
+     auto iType = cast<ShapedType>(i.getType());
+-    blockArgTypes.push_back(iType.cloneWith(
+-        llvm::ArrayRef<int64_t>(std::nullopt), iType.getElementType()));
++    blockArgTypes.push_back(
++        iType.cloneWith(llvm::ArrayRef<int64_t>(), iType.getElementType()));
+     locs.push_back(i.getLoc());
+   }
+   for (auto i : init_values) {
+     auto iType = cast<ShapedType>(i.getType());
+-    blockArgTypes.push_back(iType.cloneWith(
+-        llvm::ArrayRef<int64_t>(std::nullopt), iType.getElementType()));
++    blockArgTypes.push_back(
++        iType.cloneWith(llvm::ArrayRef<int64_t>(), iType.getElementType()));
+     locs.push_back(i.getLoc());
+   }
+
 diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp
 --- stablehlo/stablehlo/dialect/TypeInference.cpp
 +++ stablehlo/stablehlo/dialect/TypeInference.cpp
+@@ -1147,7 +1147,7 @@
+       *paddingOrErr,
+       /*lhsDilation=*/baseDilations.value_or(SmallVector<int64_t, 0>{}),
+       /*rhsDilation=*/windowDilations.value_or(SmallVector<int64_t, 0>{}),
+-      /*windowReversal=*/std::nullopt, location);
++      /*windowReversal=*/{}, location);
+   if (failed(windowOrErr)) return failure();
+
+   windowDims.append(windowDimensions.begin(), windowDimensions.end());
 @@ -2248,6 +2248,22 @@
    return success();
  }
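
The patch hunks added here replace llvm::ArrayRef<int64_t>(std::nullopt) with a default-constructed ArrayRef; the std::nullopt overload was deprecated in recent LLVM, and both spellings denote an empty (zero-element) ref. A quick sketch, assuming LLVM headers are on the include path:

// Sketch assuming LLVM headers are available; both forms below denote an
// empty ArrayRef, but the std::nullopt overload is deprecated/removed.
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"

int main() {
  llvm::ArrayRef<int64_t> empty;  // replacement: default-constructed, empty
  // llvm::ArrayRef<int64_t> old(std::nullopt);  // old spelling, now gone
  return empty.empty() ? 0 : 1;   // e.g. a rank-0 (scalar) shape: no dims
}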

third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc

Lines changed: 4 additions & 0 deletions
@@ -120,6 +120,8 @@ class NanoValue : public llvm::RTTIExtends<Self, Base> {
   // Called by subclasses to get access to client() without having to cast.
   NanoIfrtClient* nano_client() const { return client_; }

+  ifrt::UserContextRef user_context() const override { return {}; }
+
   // All nano values are immediately ready.
   ifrt::Future<> GetReadyFuture() const override { return Ready(); }

@@ -861,6 +863,8 @@ class NanoExecutable final
     return absl::UnimplementedError("Serialize is not implemented.");
   }

+  ifrt::UserContextRef user_context() const override { return {}; }
+
   ifrt::Future<> GetReadyFuture() const override { return Ready(); }

   int num_devices() const override { return 1; }

third_party/xla/xla/python/ifrt/executable.h

Lines changed: 6 additions & 0 deletions
@@ -38,6 +38,7 @@ limitations under the License.
 #include "xla/python/ifrt/future.h"
 #include "xla/python/ifrt/serdes_default_version_accessor.h"
 #include "xla/python/ifrt/serdes_version.h"
+#include "xla/python/ifrt/user_context.h"
 #include "xla/xla_data.pb.h"

 namespace xla {
@@ -161,6 +162,11 @@ class LoadedExecutable
   // serialized executable is implementation-specific.
   virtual absl::StatusOr<std::string> Serialize() const = 0;

+  // Returns the user context associated with the creation of this executable.
+  // May be `nullptr` if the user context is unset or the runtime does not
+  // support it.
+  virtual UserContextRef user_context() const = 0;
+
   // Returns a future that becomes ready when the executable is ready to be
   // used for execution.
   //
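
Because user_context() is added as a pure virtual on LoadedExecutable, every concrete implementation must override it, which is why the NanoRT stubs above and the mock below change in the same commit. A stand-alone sketch of that constraint, with hypothetical stand-in types (std::string in place of ifrt::UserContextRef):

// Hypothetical stand-ins to show the shape of the API change, not the real
// IFRT types.
#include <string>

using UserContextRefLike = std::string;

class LoadedExecutableLike {
 public:
  virtual ~LoadedExecutableLike() = default;
  // New pure virtual: subclasses won't compile until they override it.
  virtual UserContextRefLike user_context() const = 0;
};

// Runtimes without user-context tracking return an empty ref, matching the
// `return {};` stubs added to the NanoRT client above.
class MinimalExecutable final : public LoadedExecutableLike {
 public:
  UserContextRefLike user_context() const override { return {}; }
};

int main() { return MinimalExecutable().user_context().empty() ? 0 : 1; }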

third_party/xla/xla/python/ifrt/mock.cc

Lines changed: 3 additions & 0 deletions
@@ -60,6 +60,9 @@ using ::testing::_;
 // LINT.IfChange(MockArrayDelegation)
 MockArray::MockArray(xla::ifrt::ArrayRef delegated)
     : delegated_(std::move(delegated)) {
+  ON_CALL(*this, user_context).WillByDefault([this]() {
+    return delegated_->user_context();
+  });
   ON_CALL(*this, GetReadyFuture).WillByDefault([this]() {
     return delegated_->GetReadyFuture();
   });
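
MockArray forwards the new method to the wrapped real array by default, the delegation idiom gMock documents for mocks that wrap a real object. A minimal self-contained sketch of that pattern with hypothetical types (requires GoogleTest/GoogleMock):

// Sketch with hypothetical types, not the real IFRT mock.
#include <memory>
#include <string>
#include <utility>

#include <gmock/gmock.h>

class Array {
 public:
  virtual ~Array() = default;
  virtual std::string user_context() const = 0;
};

class RealArray : public Array {
 public:
  std::string user_context() const override { return "real-context"; }
};

class MockArrayLike : public Array {
 public:
  explicit MockArrayLike(std::shared_ptr<Array> delegated)
      : delegated_(std::move(delegated)) {
    // Same shape as the MockArray constructor change: unless a test sets an
    // expectation, calls fall through to the delegated real object.
    ON_CALL(*this, user_context).WillByDefault([this]() {
      return delegated_->user_context();
    });
  }
  MOCK_METHOD(std::string, user_context, (), (const, override));

 private:
  std::shared_ptr<Array> delegated_;
};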
