thread_group_tensor_slice_transfer_v7r2.hpp Source File

thread_group_tensor_slice_transfer_v7r2.hpp Source File#

Composable Kernel: thread_group_tensor_slice_transfer_v7r2.hpp Source File
thread_group_tensor_slice_transfer_v7r2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
15// Thread-group level multi-source, multi-destination tensor slice data movement
16// Assume:
17// 1. All sources and destinations are DynamicBuffer
18// 2. Same VectorDim and ScalarPerVector for all sources and destinations
19// 3. DstInMemOps are per destination tensor
20// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
21// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
22//
23// Does the following to avoid scratch memory issues:
24// 1. Passes tensor descriptors by reference (or tuple of references)
25// 2. Does not keep a reference to any tensor descriptor
26// 3. Does not construct a new tensor coordinate when Run() is called
27template <typename ThreadGroup,
28 typename SrcDatas,
29 typename DstDatas,
30 typename SrcDescs,
31 typename DstDescs,
32 typename ElementwiseOperation,
33 typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
34 typename SliceLengths,
35 typename ThreadClusterLengths,
36 typename ThreadClusterArrangeOrder,
37 typename SrcDimAccessOrder,
38 typename DstDimAccessOrder,
39 index_t SrcVectorDim,
40 index_t DstVectorDim,
41 index_t SrcScalarPerVector,
42 index_t DstScalarPerVector,
43 typename ThreadTransferSrcResetCoordinateAfterRunFlags,
44 typename ThreadTransferDstResetCoordinateAfterRunFlags,
45 index_t NumThreadScratch = 1>
47{
48 static constexpr index_t nDim =
50
53
55
// Lengths of the sub-slice handled by a single thread: the block-level
// SliceLengths divided by ThreadClusterLengths. The constructor static_asserts
// that multiplying this back by ThreadClusterLengths reproduces SliceLengths,
// so a configuration where the threads do not cover the whole slicing window
// is rejected at compile time.
56 static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
57
59 const SrcDescs& src_descs,
60 const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
61 const DstDescs& dst_descs,
62 const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
63 const ElementwiseOperation& element_op)
64 : threadwise_transfer_(src_descs,
66 dst_descs,
68 element_op)
69 {
70 static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
71 nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
72 nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
73 nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
74 "wrong!");
75
76 static_for<0, nSrc, 1>{}([&](auto i) {
77 static_assert(
78 nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
79 "wrong!");
80 });
81
82 static_for<0, nDst, 1>{}([&](auto i) {
83 static_assert(
84 nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
85 "wrong!");
86 });
87
88 static_assert(nDim == ThreadClusterLengths::Size() &&
89 nDim == ThreadClusterArrangeOrder::Size() &&
90 nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
91 "wrong! nDim not consistent");
92
93 static_assert(
94 is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
95 "wrong! threads should be mapped to cover entire slicing window");
96
97 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
98 "wrong! ThreadGroup::GetNumOfThread() too small");
99
100 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
101 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
102 {
103 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
104 make_multi_index(ThreadGroup::GetThreadId()));
105
106 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
107
108 const auto src_thread_slice_origins = generate_tuple(
109 [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
110 Number<nSrc>{});
111
112 const auto dst_thread_slice_origins = generate_tuple(
113 [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
114 Number<nDst>{});
115
116 threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
117 threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
118 }
119 }
120
121 template <typename SrcBuffers, index_t ThreadScratchId = 0>
122 __device__ void RunRead(const SrcDescs& src_descs,
123 const SrcBuffers& src_bufs,
125 {
126 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
127 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
128 {
129 threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
130 }
131 }
132
// Detection alias: the decltype is well-formed only when an lvalue of type T
// has a callable IsTuple() member. RunWrite feeds it to is_detected to decide
// whether the destination buffers it was given are already a tuple or need to
// be wrapped with tie().
133 template <typename T>
134 using is_tuple = decltype(ck::declval<T&>().IsTuple());
135
136 template <typename DstBuffers, index_t ThreadScratchId = 0>
137 __device__ void RunWrite(const DstDescs& dst_descs,
138 DstBuffers dst_bufs,
140 {
141 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
142 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
143 {
144 if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
145 threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
146 else
147 threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
148 }
149 }
150
// Convenience entry point: performs the read phase and then the write phase.
//
// src_descs / src_bufs are forwarded to RunRead, dst_descs / dst_bufs to
// RunWrite. Neither call passes a thread-scratch argument, so both use their
// default (ThreadScratchId = 0).
// dst_bufs is taken by value here and passed by value again to RunWrite.
151 template <typename SrcBuffers, typename DstBuffers>
152 __device__ void Run(const SrcDescs& src_descs,
153 const SrcBuffers& src_bufs,
154 const DstDescs& dst_descs,
155 DstBuffers dst_bufs)
156 {
157 RunRead(src_descs, src_bufs);
158 RunWrite(dst_descs, dst_bufs);
159 }
160
// Moves the slice window of one source tensor, selected by iSrc, by `step`.
//
// The call is forwarded to the per-thread transfer object only for threads
// that belong to the thread cluster: either the group has exactly as many
// threads as the cluster has elements, or this thread's id is below the
// cluster's element count. Other threads do nothing.
161 template <index_t ISrc>
162 __device__ void
163 MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
164 {
// NOTE(review): the first comparison does not involve the thread id;
// presumably it lets the thread-id test be dropped when the group exactly
// fills the cluster - confirm.
165 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
166 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
167 {
168 threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
169 }
170 }
171
// Moves the slice windows of all source tensors by the same `step`.
// Iterates at compile time over every index in [0, SrcDescs::Size()) and
// calls the single-source overload for each one.
172 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
173 {
174 static_for<0, SrcDescs::Size(), 1>{}(
175 [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
176 }
177
// Moves the slice window of one destination tensor, selected by iDst, by
// `step`.
//
// Uses the same participation guard as the source-side overload: the call is
// forwarded to the per-thread transfer object only when the group size equals
// the cluster's element count or this thread's id is below that count.
178 template <index_t IDst>
179 __device__ void
180 MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
181 {
182 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
183 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
184 {
185 threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
186 }
187 }
188
// Moves the slice windows of all destination tensors by the same `step`.
// Iterates at compile time over every index in [0, DstDescs::Size()) and
// calls the single-destination overload for each one.
189 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
190 {
191 static_for<0, DstDescs::Size(), 1>{}(
192 [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
193 }
194
195 private:
// Compile-time cluster descriptor built from ThreadClusterLengths and
// ThreadClusterArrangeOrder. The constructor uses CalculateBottomIndex on it
// to turn the flat thread id into a multi-dimensional cluster index, and every
// guarded method compares against its GetElementSize().
196 static constexpr auto thread_cluster_desc_ =
197 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
198
// Per-thread transfer type that this class delegates to. It receives the
// per-thread slice lengths (decltype(thread_slice_lengths)) in place of the
// block-level SliceLengths; the remaining template arguments are this class's
// own parameters passed through unchanged. ThreadGroup, ThreadClusterLengths
// and ThreadClusterArrangeOrder are not forwarded.
199 using ThreadwiseTransfer =
200 ThreadwiseTensorSliceTransfer_v7r2<SrcDatas,
201 DstDatas,
202 SrcDescs,
203 DstDescs,
204 ElementwiseOperation,
205 DstInMemOps,
206 decltype(thread_slice_lengths),
207 SrcDimAccessOrder,
208 DstDimAccessOrder,
209 SrcVectorDim,
210 DstVectorDim,
211 SrcScalarPerVector,
212 DstScalarPerVector,
213 ThreadTransferSrcResetCoordinateAfterRunFlags,
214 ThreadTransferDstResetCoordinateAfterRunFlags,
215 NumThreadScratch>;
216
// The per-thread transfer object; all Run*/Move* methods forward to it.
217 ThreadwiseTransfer threadwise_transfer_;
218};
219
220} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition tuple_helper.hpp:176
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
static constexpr index_t nDst
Definition thread_group_tensor_slice_transfer_v7r2.hpp:52
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v7r2.hpp:48
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r2.hpp:189
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r2.hpp:180
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r2(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_block_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_block_slice_origins, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v7r2.hpp:58
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v7r2.hpp:137
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r2.hpp:172
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v7r2.hpp:56
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition thread_group_tensor_slice_transfer_v7r2.hpp:134
static constexpr index_t nSrc
Definition thread_group_tensor_slice_transfer_v7r2.hpp:51
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v7r2.hpp:54
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r2.hpp:163
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs)
Definition thread_group_tensor_slice_transfer_v7r2.hpp:152
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v7r2.hpp:122
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r2.hpp:380
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r2.hpp:145
Definition functional2.hpp:33