@@ -89,8 +89,8 @@ namespace xla::gpu {
89
89
V (kBarrierCmd , " BarrierCmd" ) \
90
90
V (kCollectiveCmd , " CollectiveCmd" ) \
91
91
V (kAllReduceCmd , " AllReduceCmd" ) \
92
- V (kReduceScatter , " ReduceScatterCmd" ) \
93
- V (kAllToAll , " AllToAllCmd" ) \
92
+ V (kReduceScatterCmd , " ReduceScatterCmd" ) \
93
+ V (kAllToAllCmd , " AllToAllCmd" ) \
94
94
V (kAllGatherCmd , " AllGatherCmd" ) \
95
95
V (kCollectiveBroadcastCmd , " CollectiveBroadcastCmd" ) \
96
96
V (kDynamicSliceFusionCmd , " DynamicSliceFusionCmd" ) \
@@ -124,12 +124,9 @@ using ResourceUseVector = absl::InlinedVector<ResourceUse, 1>;
124
124
125
125
class CommandBufferCmd {
126
126
public:
127
- CommandBufferCmd (CommandBufferCmdType cmd_type,
128
- ExecutionStreamId execution_stream_id,
129
- ResourceUseVector resources,
127
+ CommandBufferCmd (CommandBufferCmdType cmd_type, ResourceUseVector resources,
130
128
se::StreamPriority priority = se::StreamPriority::Default)
131
129
: cmd_type_(cmd_type),
132
- execution_stream_id_ (execution_stream_id),
133
130
resources_ (std::move(resources)),
134
131
priority_(priority) {}
135
132
@@ -309,12 +306,9 @@ class CommandBufferCmd {
309
306
return CommandBufferCmdString (cmd_type_);
310
307
}
311
308
312
- ExecutionStreamId execution_stream_id () const { return execution_stream_id_; }
313
-
314
309
private:
315
310
std::string profile_annotation_;
316
311
CommandBufferCmdType cmd_type_;
317
- ExecutionStreamId execution_stream_id_;
318
312
ResourceUseVector resources_;
319
313
320
314
// Command priority, currently only support default, lowest and highest
@@ -521,7 +515,6 @@ class TracedCommandBuffer : public CommandBufferCmd::State {
521
515
class TracedCommandBufferCmd : public CommandBufferCmd {
522
516
protected:
523
517
explicit TracedCommandBufferCmd (CommandBufferCmdType cmd_type,
524
- ExecutionStreamId execution_stream_id,
525
518
ResourceUseVector resources = {});
526
519
527
520
// Creates a command buffer by calling a user-provided `trace` function and
@@ -540,8 +533,7 @@ class TracedCommandBufferCmd : public CommandBufferCmd {
540
533
541
534
class EmptyCmd : public CommandBufferCmd {
542
535
public:
543
- explicit EmptyCmd (ExecutionStreamId execution_stream_id,
544
- ResourceUseVector resources = {});
536
+ explicit EmptyCmd (ResourceUseVector resources = {});
545
537
546
538
absl::StatusOr<const se::CommandBuffer::Command*> Record (
547
539
const Thunk::ExecuteParams& execute_params,
@@ -559,8 +551,7 @@ class ComputationIdCmd : public CommandBufferCmd {
559
551
public:
560
552
enum class Kind { kReplica , kPartition };
561
553
562
- ComputationIdCmd (ExecutionStreamId execution_stream_id,
563
- BufferAllocation::Slice dest, Kind kind,
554
+ ComputationIdCmd (BufferAllocation::Slice dest, Kind kind,
564
555
ResourceUseVector resources = {});
565
556
566
557
absl::StatusOr<const se::CommandBuffer::Command*> Record (
@@ -581,7 +572,7 @@ class ComputationIdCmd : public CommandBufferCmd {
581
572
582
573
class LaunchCmd : public CommandBufferCmd {
583
574
public:
584
- LaunchCmd (ExecutionStreamId execution_stream_id, std::string kernel_name,
575
+ LaunchCmd (std::string kernel_name,
585
576
absl::Span<const BufferAllocation::Slice> args,
586
577
absl::Span<const BufferUse::MemoryAccess> args_access,
587
578
LaunchDimensions dims, int64_t shmem_bytes,
@@ -617,8 +608,7 @@ class LaunchCmd : public CommandBufferCmd {
617
608
618
609
class CustomKernelLaunchCmd : public CommandBufferCmd {
619
610
public:
620
- CustomKernelLaunchCmd (ExecutionStreamId execution_stream_id,
621
- absl::Span<const BufferAllocation::Slice> args,
611
+ CustomKernelLaunchCmd (absl::Span<const BufferAllocation::Slice> args,
622
612
absl::Span<const BufferUse::MemoryAccess> args_access,
623
613
CustomKernel custom_kernel,
624
614
ResourceUseVector resources = {});
@@ -651,8 +641,7 @@ class CustomKernelLaunchCmd : public CommandBufferCmd {
651
641
652
642
class MemcpyDeviceToDeviceCmd : public CommandBufferCmd {
653
643
public:
654
- MemcpyDeviceToDeviceCmd (ExecutionStreamId execution_stream_id,
655
- BufferAllocation::Slice dst,
644
+ MemcpyDeviceToDeviceCmd (BufferAllocation::Slice dst,
656
645
BufferAllocation::Slice src, int64_t num_bytes,
657
646
ResourceUseVector resources = {});
658
647
@@ -675,8 +664,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd {
675
664
676
665
class MemzeroCmd : public CommandBufferCmd {
677
666
public:
678
- MemzeroCmd (ExecutionStreamId execution_stream_id, BufferAllocation::Slice dst,
679
- ResourceUseVector resources = {});
667
+ MemzeroCmd (BufferAllocation::Slice dst, ResourceUseVector resources = {});
680
668
681
669
absl::StatusOr<const se::CommandBuffer::Command*> Record (
682
670
const Thunk::ExecuteParams& execute_params,
@@ -695,8 +683,7 @@ class MemzeroCmd : public CommandBufferCmd {
695
683
696
684
class Memset32Cmd : public CommandBufferCmd {
697
685
public:
698
- Memset32Cmd (ExecutionStreamId execution_stream_id,
699
- BufferAllocation::Slice dst, uint32_t bit_pattern,
686
+ Memset32Cmd (BufferAllocation::Slice dst, uint32_t bit_pattern,
700
687
ResourceUseVector resources = {});
701
688
702
689
absl::StatusOr<const se::CommandBuffer::Command*> Record (
@@ -717,8 +704,8 @@ class Memset32Cmd : public CommandBufferCmd {
717
704
718
705
class CaseCmd : public CommandBufferCmd {
719
706
public:
720
- CaseCmd (ExecutionStreamId execution_stream_id, BufferAllocation::Slice index,
721
- bool index_is_bool, std::vector<CommandBufferCmdExecutor> branches,
707
+ CaseCmd (BufferAllocation::Slice index, bool index_is_bool ,
708
+ std::vector<CommandBufferCmdExecutor> branches,
722
709
ResourceUseVector resources = {});
723
710
724
711
absl::Status Initialize (const Thunk::InitializeParams& params,
@@ -747,8 +734,7 @@ class CaseCmd : public CommandBufferCmd {
747
734
748
735
class WhileCmd : public CommandBufferCmd {
749
736
public:
750
- WhileCmd (ExecutionStreamId execution_stream_id, BufferAllocation::Slice pred,
751
- CommandBufferCmdExecutor cond_commands,
737
+ WhileCmd (BufferAllocation::Slice pred, CommandBufferCmdExecutor cond_commands,
752
738
CommandBufferCmdExecutor body_commands,
753
739
ResourceUseVector resources = {});
754
740
@@ -778,8 +764,7 @@ class WhileCmd : public CommandBufferCmd {
778
764
779
765
class GemmCmd : public TracedCommandBufferCmd {
780
766
public:
781
- GemmCmd (ExecutionStreamId execution_stream_id, GemmConfig config,
782
- const BufferAllocation::Slice& lhs_buffer,
767
+ GemmCmd (GemmConfig config, const BufferAllocation::Slice& lhs_buffer,
783
768
const BufferAllocation::Slice& rhs_buffer,
784
769
const BufferAllocation::Slice& output_buffer,
785
770
const BufferAllocation::Slice& workspace, bool deterministic,
@@ -813,8 +798,7 @@ class GemmCmd : public TracedCommandBufferCmd {
813
798
814
799
class CublasLtCmd : public TracedCommandBufferCmd , public CublasLtMatmulThunk {
815
800
public:
816
- CublasLtCmd (ExecutionStreamId execution_stream_id,
817
- const CublasLtMatmulThunk& matmul_thunk,
801
+ CublasLtCmd (const CublasLtMatmulThunk& matmul_thunk,
818
802
ResourceUseVector resources = {});
819
803
820
804
absl::Status Initialize (const Thunk::InitializeParams& params,
@@ -841,8 +825,7 @@ class CublasLtCmd : public TracedCommandBufferCmd, public CublasLtMatmulThunk {
841
825
842
826
class CuDnnCmd : public TracedCommandBufferCmd {
843
827
public:
844
- CuDnnCmd (ExecutionStreamId execution_stream_id,
845
- absl::Span<const BufferAllocation::Slice> args,
828
+ CuDnnCmd (absl::Span<const BufferAllocation::Slice> args,
846
829
std::shared_ptr<se::dnn::LazyDnnGraph> graph,
847
830
ResourceUseVector resources = {});
848
831
@@ -875,28 +858,26 @@ class CustomCallCmd : public CommandBufferCmd {
875
858
876
859
// This is a legacy custom call API that is discouraged, and will be
877
860
// deprecated once XLA:FFI mechanism is ready.
878
- CustomCallCmd (ExecutionStreamId execution_stream_id, std::string target_name,
879
- CustomCallTarget call_target,
861
+ CustomCallCmd (std::string target_name, CustomCallTarget call_target,
880
862
std::vector<std::optional<Slice>> operands,
881
863
std::vector<std::optional<Slice>> results,
882
864
absl::string_view opaque, ResourceUseVector resources = {})
883
865
: CommandBufferCmd(CommandBufferCmdType::kCustomCallCmd ,
884
- execution_stream_id, resources),
866
+ std::move ( resources) ),
885
867
target_name_(std::move(target_name)),
886
868
call_target_(std::move(call_target)),
887
869
opaque_(opaque),
888
870
operands_(std::move(operands)),
889
871
results_(std::move(results)) {}
890
872
891
- CustomCallCmd (ExecutionStreamId execution_stream_id, std::string target_name,
892
- XLA_FFI_Handler* handler,
873
+ CustomCallCmd (std::string target_name, XLA_FFI_Handler* handler,
893
874
std::vector<std::optional<Slice>> operands,
894
875
std::vector<std::optional<Slice>> results,
895
876
ffi::CallFrame call_frame,
896
877
const HloComputation* called_computation,
897
878
ResourceUseVector resources = {})
898
879
: CommandBufferCmd(CommandBufferCmdType::kCustomCallCmd ,
899
- execution_stream_id, resources),
880
+ std::move ( resources) ),
900
881
target_name_(std::move(target_name)),
901
882
handler_(handler),
902
883
call_frame_(std::move(call_frame)),
@@ -955,9 +936,7 @@ class CustomCallCmd : public CommandBufferCmd {
955
936
956
937
class CollectiveCmd : public CommandBufferCmd {
957
938
public:
958
- CollectiveCmd (CommandBufferCmdType cmd_type,
959
- ExecutionStreamId execution_stream_id,
960
- ExecutionStreamId async_from_stream_id, CollectiveConfig config,
939
+ CollectiveCmd (CommandBufferCmdType cmd_type, CollectiveConfig config,
961
940
ResourceUseVector resources = {});
962
941
963
942
absl::Status Prepare (
@@ -974,27 +953,12 @@ class CollectiveCmd : public CommandBufferCmd {
974
953
se::CommandBuffer* command_buffer,
975
954
absl::FunctionRef<absl::Status(se::Stream*)> trace);
976
955
977
- virtual AsyncStreamKind GetAsyncStreamKind () = 0;
978
- virtual CollectiveStreamId GetAsyncStreamId () = 0;
979
-
980
- bool IsAsync () const {
981
- return async_from_stream_id_ != execution_stream_id ();
982
- }
983
-
984
- CollectiveStreamId nccl_stream_id () {
985
- return xla::gpu::GetCollectiveStreamId (IsAsync (), GetAsyncStreamId (),
986
- GetAsyncStreamKind ());
987
- }
988
-
989
- ExecutionStreamId async_from_stream_id () const {
990
- return async_from_stream_id_;
991
- }
956
+ bool IsAsync () const { return false ; }
992
957
993
958
protected:
994
959
const CollectiveConfig& config () const { return config_; }
995
960
996
961
private:
997
- ExecutionStreamId async_from_stream_id_;
998
962
CollectiveConfig config_;
999
963
};
1000
964
@@ -1004,9 +968,7 @@ class CollectiveCmd : public CommandBufferCmd {
1004
968
1005
969
class AllReduceCmd : public CollectiveCmd {
1006
970
public:
1007
- AllReduceCmd (ExecutionStreamId execution_stream_id,
1008
- ExecutionStreamId async_from_stream_id, CollectiveConfig config,
1009
- ReductionKind reduction_kind,
971
+ AllReduceCmd (CollectiveConfig config, ReductionKind reduction_kind,
1010
972
absl::Span<const CollectiveThunk::Buffer> buffers,
1011
973
ResourceUseVector resources = {});
1012
974
@@ -1017,13 +979,6 @@ class AllReduceCmd : public CollectiveCmd {
1017
979
1018
980
BufferUseVector buffers () const override ;
1019
981
1020
- AsyncStreamKind GetAsyncStreamKind () override {
1021
- return AsyncStreamKind::kCollective ;
1022
- };
1023
- CollectiveStreamId GetAsyncStreamId () override {
1024
- return CollectiveStreamId (1 );
1025
- };
1026
-
1027
982
private:
1028
983
ReductionKind reduction_kind_;
1029
984
std::vector<CollectiveThunk::Buffer> buffers_;
@@ -1035,9 +990,7 @@ class AllReduceCmd : public CollectiveCmd {
1035
990
1036
991
class ReduceScatterCmd : public CollectiveCmd {
1037
992
public:
1038
- ReduceScatterCmd (ExecutionStreamId execution_stream_id,
1039
- ExecutionStreamId async_from_stream_id,
1040
- CollectiveConfig config, ReductionKind reduction_kind,
993
+ ReduceScatterCmd (CollectiveConfig config, ReductionKind reduction_kind,
1041
994
absl::Span<const CollectiveThunk::Buffer> buffers,
1042
995
ResourceUseVector resources = {});
1043
996
@@ -1048,13 +1001,6 @@ class ReduceScatterCmd : public CollectiveCmd {
1048
1001
1049
1002
BufferUseVector buffers () const override ;
1050
1003
1051
- AsyncStreamKind GetAsyncStreamKind () override {
1052
- return AsyncStreamKind::kCollective ;
1053
- };
1054
- CollectiveStreamId GetAsyncStreamId () override {
1055
- return CollectiveStreamId (1 );
1056
- };
1057
-
1058
1004
private:
1059
1005
ReductionKind reduction_kind_;
1060
1006
std::vector<CollectiveThunk::Buffer> buffers_;
@@ -1066,9 +1012,7 @@ class ReduceScatterCmd : public CollectiveCmd {
1066
1012
1067
1013
class AllToAllCmd : public CollectiveCmd {
1068
1014
public:
1069
- AllToAllCmd (ExecutionStreamId execution_stream_id,
1070
- ExecutionStreamId async_from_stream_id, CollectiveConfig config,
1071
- bool has_split_dimension,
1015
+ AllToAllCmd (CollectiveConfig config, bool has_split_dimension,
1072
1016
absl::Span<const CollectiveThunk::Buffer> buffers,
1073
1017
ResourceUseVector resources = {});
1074
1018
@@ -1079,13 +1023,6 @@ class AllToAllCmd : public CollectiveCmd {
1079
1023
1080
1024
BufferUseVector buffers () const override ;
1081
1025
1082
- AsyncStreamKind GetAsyncStreamKind () override {
1083
- return AsyncStreamKind::kCollective ;
1084
- };
1085
- CollectiveStreamId GetAsyncStreamId () override {
1086
- return CollectiveStreamId (1 );
1087
- };
1088
-
1089
1026
private:
1090
1027
bool has_split_dimension_;
1091
1028
std::vector<CollectiveThunk::Buffer> buffers_;
@@ -1097,8 +1034,7 @@ class AllToAllCmd : public CollectiveCmd {
1097
1034
1098
1035
class AllGatherCmd : public CollectiveCmd {
1099
1036
public:
1100
- AllGatherCmd (ExecutionStreamId execution_stream_id,
1101
- ExecutionStreamId async_from_stream_id, CollectiveConfig config,
1037
+ AllGatherCmd (CollectiveConfig config,
1102
1038
absl::Span<const CollectiveThunk::Buffer> buffers,
1103
1039
ResourceUseVector resources = {});
1104
1040
@@ -1109,13 +1045,6 @@ class AllGatherCmd : public CollectiveCmd {
1109
1045
1110
1046
BufferUseVector buffers () const override ;
1111
1047
1112
- AsyncStreamKind GetAsyncStreamKind () override {
1113
- return AsyncStreamKind::kCollective ;
1114
- };
1115
- CollectiveStreamId GetAsyncStreamId () override {
1116
- return CollectiveStreamId (1 );
1117
- };
1118
-
1119
1048
private:
1120
1049
std::vector<CollectiveThunk::Buffer> buffers_;
1121
1050
};
@@ -1126,9 +1055,7 @@ class AllGatherCmd : public CollectiveCmd {
1126
1055
1127
1056
class CollectiveBroadcastCmd : public CollectiveCmd {
1128
1057
public:
1129
- CollectiveBroadcastCmd (ExecutionStreamId execution_stream_id,
1130
- ExecutionStreamId async_from_stream_id,
1131
- CollectiveConfig config,
1058
+ CollectiveBroadcastCmd (CollectiveConfig config,
1132
1059
absl::Span<const CollectiveThunk::Buffer> buffers,
1133
1060
ResourceUseVector resources = {});
1134
1061
@@ -1150,7 +1077,6 @@ class CollectiveBroadcastCmd : public CollectiveCmd {
1150
1077
class DynamicSliceFusionCmd : public CommandBufferCmd {
1151
1078
public:
1152
1079
DynamicSliceFusionCmd (
1153
- ExecutionStreamId execution_stream_id,
1154
1080
CommandBufferCmdExecutor embedded_commands,
1155
1081
std::vector<std::optional<BufferAllocation::Slice>> arguments,
1156
1082
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_,
@@ -1210,8 +1136,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd {
1210
1136
// buffer to another, it is only supported for static slice.
1211
1137
class DynamicSliceCopyFusionCmd : public CommandBufferCmd {
1212
1138
public:
1213
- DynamicSliceCopyFusionCmd (ExecutionStreamId execution_stream_id,
1214
- const BufferAllocation::Slice& source_buffer,
1139
+ DynamicSliceCopyFusionCmd (const BufferAllocation::Slice& source_buffer,
1215
1140
const BufferAllocation::Slice& destination_buffer,
1216
1141
uint64_t mem_size,
1217
1142
DynamicMemcpyThunk::Offsets offsets,
0 commit comments