Skip to content

Commit 7816210

Browse files
shawnwang18tensorflower-gardener
authored andcommitted
PR #29204: [XLA:GPU] Remove Stream Id from command buffer
Imported from GitHub PR openxla/xla#29204 Command buffer now uses DAG to specify dependency, so stream id is totally not used, this PR remove the Stream ID from command buffer implementation. Copybara import of the project: -- 0592cf18da946ab0f18bad6a16399c263a191dd6 by Shawn Wang <shawnw@nvidia.com>: Remove Stream Id from command buffer Merging this change closes #29204 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#29204 from shawnwang18:shawnw/remove_stream_from_command_buffer 0592cf18da946ab0f18bad6a16399c263a191dd6 PiperOrigin-RevId: 786128597
1 parent 61f818d commit 7816210

File tree

9 files changed

+193
-331
lines changed

9 files changed

+193
-331
lines changed

third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc

Lines changed: 67 additions & 103 deletions
Large diffs are not rendered by default.

third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h

Lines changed: 28 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ namespace xla::gpu {
8989
V(kBarrierCmd, "BarrierCmd") \
9090
V(kCollectiveCmd, "CollectiveCmd") \
9191
V(kAllReduceCmd, "AllReduceCmd") \
92-
V(kReduceScatter, "ReduceScatterCmd") \
93-
V(kAllToAll, "AllToAllCmd") \
92+
V(kReduceScatterCmd, "ReduceScatterCmd") \
93+
V(kAllToAllCmd, "AllToAllCmd") \
9494
V(kAllGatherCmd, "AllGatherCmd") \
9595
V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd") \
9696
V(kDynamicSliceFusionCmd, "DynamicSliceFusionCmd") \
@@ -124,12 +124,9 @@ using ResourceUseVector = absl::InlinedVector<ResourceUse, 1>;
124124

125125
class CommandBufferCmd {
126126
public:
127-
CommandBufferCmd(CommandBufferCmdType cmd_type,
128-
ExecutionStreamId execution_stream_id,
129-
ResourceUseVector resources,
127+
CommandBufferCmd(CommandBufferCmdType cmd_type, ResourceUseVector resources,
130128
se::StreamPriority priority = se::StreamPriority::Default)
131129
: cmd_type_(cmd_type),
132-
execution_stream_id_(execution_stream_id),
133130
resources_(std::move(resources)),
134131
priority_(priority) {}
135132

@@ -309,12 +306,9 @@ class CommandBufferCmd {
309306
return CommandBufferCmdString(cmd_type_);
310307
}
311308

312-
ExecutionStreamId execution_stream_id() const { return execution_stream_id_; }
313-
314309
private:
315310
std::string profile_annotation_;
316311
CommandBufferCmdType cmd_type_;
317-
ExecutionStreamId execution_stream_id_;
318312
ResourceUseVector resources_;
319313

320314
// Command priority, currently only support default, lowest and highest
@@ -521,7 +515,6 @@ class TracedCommandBuffer : public CommandBufferCmd::State {
521515
class TracedCommandBufferCmd : public CommandBufferCmd {
522516
protected:
523517
explicit TracedCommandBufferCmd(CommandBufferCmdType cmd_type,
524-
ExecutionStreamId execution_stream_id,
525518
ResourceUseVector resources = {});
526519

527520
// Creates a command buffer by calling a user-provided `trace` function and
@@ -540,8 +533,7 @@ class TracedCommandBufferCmd : public CommandBufferCmd {
540533

541534
class EmptyCmd : public CommandBufferCmd {
542535
public:
543-
explicit EmptyCmd(ExecutionStreamId execution_stream_id,
544-
ResourceUseVector resources = {});
536+
explicit EmptyCmd(ResourceUseVector resources = {});
545537

546538
absl::StatusOr<const se::CommandBuffer::Command*> Record(
547539
const Thunk::ExecuteParams& execute_params,
@@ -559,8 +551,7 @@ class ComputationIdCmd : public CommandBufferCmd {
559551
public:
560552
enum class Kind { kReplica, kPartition };
561553

562-
ComputationIdCmd(ExecutionStreamId execution_stream_id,
563-
BufferAllocation::Slice dest, Kind kind,
554+
ComputationIdCmd(BufferAllocation::Slice dest, Kind kind,
564555
ResourceUseVector resources = {});
565556

566557
absl::StatusOr<const se::CommandBuffer::Command*> Record(
@@ -581,7 +572,7 @@ class ComputationIdCmd : public CommandBufferCmd {
581572

582573
class LaunchCmd : public CommandBufferCmd {
583574
public:
584-
LaunchCmd(ExecutionStreamId execution_stream_id, std::string kernel_name,
575+
LaunchCmd(std::string kernel_name,
585576
absl::Span<const BufferAllocation::Slice> args,
586577
absl::Span<const BufferUse::MemoryAccess> args_access,
587578
LaunchDimensions dims, int64_t shmem_bytes,
@@ -617,8 +608,7 @@ class LaunchCmd : public CommandBufferCmd {
617608

618609
class CustomKernelLaunchCmd : public CommandBufferCmd {
619610
public:
620-
CustomKernelLaunchCmd(ExecutionStreamId execution_stream_id,
621-
absl::Span<const BufferAllocation::Slice> args,
611+
CustomKernelLaunchCmd(absl::Span<const BufferAllocation::Slice> args,
622612
absl::Span<const BufferUse::MemoryAccess> args_access,
623613
CustomKernel custom_kernel,
624614
ResourceUseVector resources = {});
@@ -651,8 +641,7 @@ class CustomKernelLaunchCmd : public CommandBufferCmd {
651641

652642
class MemcpyDeviceToDeviceCmd : public CommandBufferCmd {
653643
public:
654-
MemcpyDeviceToDeviceCmd(ExecutionStreamId execution_stream_id,
655-
BufferAllocation::Slice dst,
644+
MemcpyDeviceToDeviceCmd(BufferAllocation::Slice dst,
656645
BufferAllocation::Slice src, int64_t num_bytes,
657646
ResourceUseVector resources = {});
658647

@@ -675,8 +664,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd {
675664

676665
class MemzeroCmd : public CommandBufferCmd {
677666
public:
678-
MemzeroCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice dst,
679-
ResourceUseVector resources = {});
667+
MemzeroCmd(BufferAllocation::Slice dst, ResourceUseVector resources = {});
680668

681669
absl::StatusOr<const se::CommandBuffer::Command*> Record(
682670
const Thunk::ExecuteParams& execute_params,
@@ -695,8 +683,7 @@ class MemzeroCmd : public CommandBufferCmd {
695683

696684
class Memset32Cmd : public CommandBufferCmd {
697685
public:
698-
Memset32Cmd(ExecutionStreamId execution_stream_id,
699-
BufferAllocation::Slice dst, uint32_t bit_pattern,
686+
Memset32Cmd(BufferAllocation::Slice dst, uint32_t bit_pattern,
700687
ResourceUseVector resources = {});
701688

702689
absl::StatusOr<const se::CommandBuffer::Command*> Record(
@@ -717,8 +704,8 @@ class Memset32Cmd : public CommandBufferCmd {
717704

718705
class CaseCmd : public CommandBufferCmd {
719706
public:
720-
CaseCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice index,
721-
bool index_is_bool, std::vector<CommandBufferCmdExecutor> branches,
707+
CaseCmd(BufferAllocation::Slice index, bool index_is_bool,
708+
std::vector<CommandBufferCmdExecutor> branches,
722709
ResourceUseVector resources = {});
723710

724711
absl::Status Initialize(const Thunk::InitializeParams& params,
@@ -747,8 +734,7 @@ class CaseCmd : public CommandBufferCmd {
747734

748735
class WhileCmd : public CommandBufferCmd {
749736
public:
750-
WhileCmd(ExecutionStreamId execution_stream_id, BufferAllocation::Slice pred,
751-
CommandBufferCmdExecutor cond_commands,
737+
WhileCmd(BufferAllocation::Slice pred, CommandBufferCmdExecutor cond_commands,
752738
CommandBufferCmdExecutor body_commands,
753739
ResourceUseVector resources = {});
754740

@@ -778,8 +764,7 @@ class WhileCmd : public CommandBufferCmd {
778764

779765
class GemmCmd : public TracedCommandBufferCmd {
780766
public:
781-
GemmCmd(ExecutionStreamId execution_stream_id, GemmConfig config,
782-
const BufferAllocation::Slice& lhs_buffer,
767+
GemmCmd(GemmConfig config, const BufferAllocation::Slice& lhs_buffer,
783768
const BufferAllocation::Slice& rhs_buffer,
784769
const BufferAllocation::Slice& output_buffer,
785770
const BufferAllocation::Slice& workspace, bool deterministic,
@@ -813,8 +798,7 @@ class GemmCmd : public TracedCommandBufferCmd {
813798

814799
class CublasLtCmd : public TracedCommandBufferCmd, public CublasLtMatmulThunk {
815800
public:
816-
CublasLtCmd(ExecutionStreamId execution_stream_id,
817-
const CublasLtMatmulThunk& matmul_thunk,
801+
CublasLtCmd(const CublasLtMatmulThunk& matmul_thunk,
818802
ResourceUseVector resources = {});
819803

820804
absl::Status Initialize(const Thunk::InitializeParams& params,
@@ -841,8 +825,7 @@ class CublasLtCmd : public TracedCommandBufferCmd, public CublasLtMatmulThunk {
841825

842826
class CuDnnCmd : public TracedCommandBufferCmd {
843827
public:
844-
CuDnnCmd(ExecutionStreamId execution_stream_id,
845-
absl::Span<const BufferAllocation::Slice> args,
828+
CuDnnCmd(absl::Span<const BufferAllocation::Slice> args,
846829
std::shared_ptr<se::dnn::LazyDnnGraph> graph,
847830
ResourceUseVector resources = {});
848831

@@ -875,28 +858,26 @@ class CustomCallCmd : public CommandBufferCmd {
875858

876859
// This is a legacy custom call API that is discouraged, and will be
877860
// deprecated once XLA:FFI mechanism is ready.
878-
CustomCallCmd(ExecutionStreamId execution_stream_id, std::string target_name,
879-
CustomCallTarget call_target,
861+
CustomCallCmd(std::string target_name, CustomCallTarget call_target,
880862
std::vector<std::optional<Slice>> operands,
881863
std::vector<std::optional<Slice>> results,
882864
absl::string_view opaque, ResourceUseVector resources = {})
883865
: CommandBufferCmd(CommandBufferCmdType::kCustomCallCmd,
884-
execution_stream_id, resources),
866+
std::move(resources)),
885867
target_name_(std::move(target_name)),
886868
call_target_(std::move(call_target)),
887869
opaque_(opaque),
888870
operands_(std::move(operands)),
889871
results_(std::move(results)) {}
890872

891-
CustomCallCmd(ExecutionStreamId execution_stream_id, std::string target_name,
892-
XLA_FFI_Handler* handler,
873+
CustomCallCmd(std::string target_name, XLA_FFI_Handler* handler,
893874
std::vector<std::optional<Slice>> operands,
894875
std::vector<std::optional<Slice>> results,
895876
ffi::CallFrame call_frame,
896877
const HloComputation* called_computation,
897878
ResourceUseVector resources = {})
898879
: CommandBufferCmd(CommandBufferCmdType::kCustomCallCmd,
899-
execution_stream_id, resources),
880+
std::move(resources)),
900881
target_name_(std::move(target_name)),
901882
handler_(handler),
902883
call_frame_(std::move(call_frame)),
@@ -955,9 +936,7 @@ class CustomCallCmd : public CommandBufferCmd {
955936

956937
class CollectiveCmd : public CommandBufferCmd {
957938
public:
958-
CollectiveCmd(CommandBufferCmdType cmd_type,
959-
ExecutionStreamId execution_stream_id,
960-
ExecutionStreamId async_from_stream_id, CollectiveConfig config,
939+
CollectiveCmd(CommandBufferCmdType cmd_type, CollectiveConfig config,
961940
ResourceUseVector resources = {});
962941

963942
absl::Status Prepare(
@@ -974,27 +953,12 @@ class CollectiveCmd : public CommandBufferCmd {
974953
se::CommandBuffer* command_buffer,
975954
absl::FunctionRef<absl::Status(se::Stream*)> trace);
976955

977-
virtual AsyncStreamKind GetAsyncStreamKind() = 0;
978-
virtual CollectiveStreamId GetAsyncStreamId() = 0;
979-
980-
bool IsAsync() const {
981-
return async_from_stream_id_ != execution_stream_id();
982-
}
983-
984-
CollectiveStreamId nccl_stream_id() {
985-
return xla::gpu::GetCollectiveStreamId(IsAsync(), GetAsyncStreamId(),
986-
GetAsyncStreamKind());
987-
}
988-
989-
ExecutionStreamId async_from_stream_id() const {
990-
return async_from_stream_id_;
991-
}
956+
bool IsAsync() const { return false; }
992957

993958
protected:
994959
const CollectiveConfig& config() const { return config_; }
995960

996961
private:
997-
ExecutionStreamId async_from_stream_id_;
998962
CollectiveConfig config_;
999963
};
1000964

@@ -1004,9 +968,7 @@ class CollectiveCmd : public CommandBufferCmd {
1004968

1005969
class AllReduceCmd : public CollectiveCmd {
1006970
public:
1007-
AllReduceCmd(ExecutionStreamId execution_stream_id,
1008-
ExecutionStreamId async_from_stream_id, CollectiveConfig config,
1009-
ReductionKind reduction_kind,
971+
AllReduceCmd(CollectiveConfig config, ReductionKind reduction_kind,
1010972
absl::Span<const CollectiveThunk::Buffer> buffers,
1011973
ResourceUseVector resources = {});
1012974

@@ -1017,13 +979,6 @@ class AllReduceCmd : public CollectiveCmd {
1017979

1018980
BufferUseVector buffers() const override;
1019981

1020-
AsyncStreamKind GetAsyncStreamKind() override {
1021-
return AsyncStreamKind::kCollective;
1022-
};
1023-
CollectiveStreamId GetAsyncStreamId() override {
1024-
return CollectiveStreamId(1);
1025-
};
1026-
1027982
private:
1028983
ReductionKind reduction_kind_;
1029984
std::vector<CollectiveThunk::Buffer> buffers_;
@@ -1035,9 +990,7 @@ class AllReduceCmd : public CollectiveCmd {
1035990

1036991
class ReduceScatterCmd : public CollectiveCmd {
1037992
public:
1038-
ReduceScatterCmd(ExecutionStreamId execution_stream_id,
1039-
ExecutionStreamId async_from_stream_id,
1040-
CollectiveConfig config, ReductionKind reduction_kind,
993+
ReduceScatterCmd(CollectiveConfig config, ReductionKind reduction_kind,
1041994
absl::Span<const CollectiveThunk::Buffer> buffers,
1042995
ResourceUseVector resources = {});
1043996

@@ -1048,13 +1001,6 @@ class ReduceScatterCmd : public CollectiveCmd {
10481001

10491002
BufferUseVector buffers() const override;
10501003

1051-
AsyncStreamKind GetAsyncStreamKind() override {
1052-
return AsyncStreamKind::kCollective;
1053-
};
1054-
CollectiveStreamId GetAsyncStreamId() override {
1055-
return CollectiveStreamId(1);
1056-
};
1057-
10581004
private:
10591005
ReductionKind reduction_kind_;
10601006
std::vector<CollectiveThunk::Buffer> buffers_;
@@ -1066,9 +1012,7 @@ class ReduceScatterCmd : public CollectiveCmd {
10661012

10671013
class AllToAllCmd : public CollectiveCmd {
10681014
public:
1069-
AllToAllCmd(ExecutionStreamId execution_stream_id,
1070-
ExecutionStreamId async_from_stream_id, CollectiveConfig config,
1071-
bool has_split_dimension,
1015+
AllToAllCmd(CollectiveConfig config, bool has_split_dimension,
10721016
absl::Span<const CollectiveThunk::Buffer> buffers,
10731017
ResourceUseVector resources = {});
10741018

@@ -1079,13 +1023,6 @@ class AllToAllCmd : public CollectiveCmd {
10791023

10801024
BufferUseVector buffers() const override;
10811025

1082-
AsyncStreamKind GetAsyncStreamKind() override {
1083-
return AsyncStreamKind::kCollective;
1084-
};
1085-
CollectiveStreamId GetAsyncStreamId() override {
1086-
return CollectiveStreamId(1);
1087-
};
1088-
10891026
private:
10901027
bool has_split_dimension_;
10911028
std::vector<CollectiveThunk::Buffer> buffers_;
@@ -1097,8 +1034,7 @@ class AllToAllCmd : public CollectiveCmd {
10971034

10981035
class AllGatherCmd : public CollectiveCmd {
10991036
public:
1100-
AllGatherCmd(ExecutionStreamId execution_stream_id,
1101-
ExecutionStreamId async_from_stream_id, CollectiveConfig config,
1037+
AllGatherCmd(CollectiveConfig config,
11021038
absl::Span<const CollectiveThunk::Buffer> buffers,
11031039
ResourceUseVector resources = {});
11041040

@@ -1109,13 +1045,6 @@ class AllGatherCmd : public CollectiveCmd {
11091045

11101046
BufferUseVector buffers() const override;
11111047

1112-
AsyncStreamKind GetAsyncStreamKind() override {
1113-
return AsyncStreamKind::kCollective;
1114-
};
1115-
CollectiveStreamId GetAsyncStreamId() override {
1116-
return CollectiveStreamId(1);
1117-
};
1118-
11191048
private:
11201049
std::vector<CollectiveThunk::Buffer> buffers_;
11211050
};
@@ -1126,9 +1055,7 @@ class AllGatherCmd : public CollectiveCmd {
11261055

11271056
class CollectiveBroadcastCmd : public CollectiveCmd {
11281057
public:
1129-
CollectiveBroadcastCmd(ExecutionStreamId execution_stream_id,
1130-
ExecutionStreamId async_from_stream_id,
1131-
CollectiveConfig config,
1058+
CollectiveBroadcastCmd(CollectiveConfig config,
11321059
absl::Span<const CollectiveThunk::Buffer> buffers,
11331060
ResourceUseVector resources = {});
11341061

@@ -1150,7 +1077,6 @@ class CollectiveBroadcastCmd : public CollectiveCmd {
11501077
class DynamicSliceFusionCmd : public CommandBufferCmd {
11511078
public:
11521079
DynamicSliceFusionCmd(
1153-
ExecutionStreamId execution_stream_id,
11541080
CommandBufferCmdExecutor embedded_commands,
11551081
std::vector<std::optional<BufferAllocation::Slice>> arguments,
11561082
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_,
@@ -1210,8 +1136,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd {
12101136
// buffer to another, it is only supported for static slice.
12111137
class DynamicSliceCopyFusionCmd : public CommandBufferCmd {
12121138
public:
1213-
DynamicSliceCopyFusionCmd(ExecutionStreamId execution_stream_id,
1214-
const BufferAllocation::Slice& source_buffer,
1139+
DynamicSliceCopyFusionCmd(const BufferAllocation::Slice& source_buffer,
12151140
const BufferAllocation::Slice& destination_buffer,
12161141
uint64_t mem_size,
12171142
DynamicMemcpyThunk::Offsets offsets,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy