threadwise_tensor_slice_transfer_v7r3_scatter.hpp Source File

threadwise_tensor_slice_transfer_v7r3_scatter.hpp Source File

Composable Kernel: threadwise_tensor_slice_transfer_v7r3_scatter.hpp Source File
threadwise_tensor_slice_transfer_v7r3_scatter.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
14
15namespace ck {
16// Thread-level multi-source, multi-destination tensor slice data movement
17// Assume:
18// 1. All sources and destinations are DynamicBuffer
19// 2. Same SrcVectorDim/SrcScalarPerVector for all sources, and same DstVectorDim/DstScalarPerVector for all destinations
20// 3. DstInMemOps are per destination tensor
21// 4. SrcResetCoordinateAfterRunFlags are per source tensor
22// 5. DstResetCoordinateAfterRunFlags are per destination tensor
23// 6. Does not need to know src_descs and dst_descs at compile-time
24// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time
25//
26// Does following things to avoid scratch memory issue
27// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
28// 2. Pass tensor descriptors by reference (or tuple of references)
29// 3. Does not keep reference to tensor descriptor
30// 4. Does not construct new tensor coordinate when call Run()
31template <typename SrcDatas,
32 typename DstDatas,
33 typename SrcDescs,
34 typename DstDescs,
35 typename ElementwiseOperation,
36 typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
37 typename SliceLengths,
38 typename SrcDimAccessOrder,
39 typename DstDimAccessOrder,
40 index_t SrcVectorDim,
41 index_t DstVectorDim,
42 typename SrcScalarPerVectors,
43 index_t DstScalarPerVector,
44 typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
45 typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
46 typename IndexType,
47 index_t ScatterDim = 1,
48 bool OutputScatter = true,
49 index_t ScatterWeightIdx = 3,
50 index_t NumThreadScratch = 1>
52{
53 static constexpr auto I0 = Number<0>{};
54 static constexpr auto I1 = Number<1>{};
55 static constexpr auto I2 = Number<2>{};
56 static constexpr auto I3 = Number<3>{};
57
58 static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0];
59
60 static constexpr index_t nDim = SliceLengths::Size();
61
62 static constexpr index_t nSrc = SrcDescs::Size();
63 static constexpr index_t nDst = DstDescs::Size();
64
66 static constexpr index_t scatter_num = SliceLengths{}.At(Number<ScatterDim>{});
67
68 // return a tuple of coordiantes for a tuple of tensor
69 template <typename Descs,
70 typename Indices,
71 enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
72 static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
73 {
74 return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
75 Number<Descs::Size()>{});
76 }
77
80
81 // scalar per access on each dim
82 // FIXME: don't use lambda_scalar_per_access
85
88
90 SrcDimAccessOrder,
92 false>;
93
95 DstDimAccessOrder,
97 false>;
98
100 const SrcDescs& src_descs,
101 const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
102 const DstDescs& dst_descs,
103 const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
104 const ElementwiseOperation& element_op)
105 : src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
106 dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
107 element_op_(element_op)
108 {
109 static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
110 "wrong! cannot evenly divide");
111
112 static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
113 "wrong! cannot evenly divide");
114 }
115
116 template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
117 __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
118 const Indices& src_slice_origin_idxs)
119 {
120 static_for<0, nSrc, 1>{}([&](auto i) {
121 src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
122 });
123 }
124
125 template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
126 __device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
127 const Indices& dst_slice_origin_idxs)
128 {
129 static_for<0, nDst, 1>{}([&](auto i) {
130 dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
131 });
132 }
133
134 template <typename DataTypes, index_t ScalarPerVector>
135 __device__ static auto generate_vectors()
136 {
137 auto data_types = DataTypes{};
138
139 constexpr index_t num = data_types.Size();
140
141 return generate_tuple(
142 [&](auto i) {
143 using DataType = remove_cvref_t<decltype(data_types[i])>;
144
146 },
147 Number<num>{});
148 }
149
150 // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
151 // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
152 template <typename SrcBuffers,
153 index_t ThreadScratchId = 0,
154 enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
155 __device__ void RunRead(const SrcDescs& src_descs,
156 const SrcBuffers& src_bufs,
158 {
159 // loop over space-filling curve
160 static_for<0, src_num_access, 1>{}([&](auto iAccess) {
163
164 bool oob_val = true;
165
166 // copy data from src_bufs into src_vectors
167 static_for<0, nSrc, 1>{}([&](auto i) {
168 using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
169
170 const bool is_src_valid =
172 src_coords_[i]);
173
174 oob_val = oob_val & is_src_valid;
175 src_vectors(i).template AsType<src_vector_t>()(I0) =
176 src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
177 });
178
179 constexpr auto get_elem_op_vec_len = []() {
180 if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
181 {
182 if constexpr(decltype(element_op_)::is_pack8_invocable)
183 return math::min(8, SrcScalarPerVector);
184 }
185 if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
186 {
187 if constexpr(decltype(element_op_)::is_pack4_invocable)
188 return math::min(4, SrcScalarPerVector);
189 }
190 if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
191 {
192 if constexpr(decltype(element_op_)::is_pack2_invocable)
193 return math::min(2, SrcScalarPerVector);
194 }
195 return 1;
196 };
197
198 constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
199
200 // apply pointwise function
201 static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
202 // get reference to src data
203 const auto src_data_refs = generate_tie(
204 // return type should be lvalue
205 [&](auto iSrc) -> const auto& {
206 using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
207
208 using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
209
210 return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
211 },
212 Number<nSrc>{});
213
214 // get reference to dst data
215 auto dst_data_refs = generate_tie(
216 // return type should be lvalue
217 [&](auto iDst) -> auto& {
218 using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
219
220 using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
221
222 return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
223 },
224 Number<nDst>{});
225
226 // apply pointwise function
227 // pointwise function signature:
228 // element_op_(dst_data_refs[I0],
229 // dst_data_refs[I1],
230 // ...,
231 // src_data_refs[I0],
232 // src_data_refs[I1],
233 // ...)
234 unpack2(element_op_, dst_data_refs, src_data_refs);
235 });
236
237 elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
238 oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
239
240 // move coordinate
241 if constexpr(iAccess.value != src_num_access - 1)
242 {
243 constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
244
245 static_for<0, nSrc, 1>{}([&](auto i) {
246 move_tensor_coordinate(src_descs[i],
247 src_coords_(i),
248 make_tensor_coordinate_step(src_descs[i], forward_step));
249 });
250 }
251 });
252
253 // move coordinate back to slice origin (or not)
254 static_for<0, nSrc, 1>{}([&](auto i) {
255 if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
256 {
257 const auto src_reset_step =
259
260 move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
261 }
262 });
263 }
264
265#if 1
266 template <index_t ThreadScratchId = 0>
267 __device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
268 {
269 // loop over space-filling curve
270 static_for<0, src_num_access, 1>{}([&](auto iAccess) {
271 auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
272 auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
273
274 static_for<0, nDst, 1>{}([&](auto i) {
275 using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
276 elm_vectors(i).template AsType<elm_vector_t>()(I0) =
277 oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
278 });
279
280 elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
281 });
282 }
283#endif
284
285 template <index_t ThreadScratchId = 0>
286 __device__ void
288 {
289 using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
290
291 using ElmThreadScratch =
292 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
293 DstData,
296 true>;
297 using DstThreadScratch =
298 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
299 DstData,
300 DstScalarPerVector,
302 true>;
303
304 ElmThreadScratch elm_thread_scratch_;
305 DstThreadScratch dst_thread_scratch_;
306
307 elm_thread_scratch_.data_ =
308 bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
309
310 if constexpr(SrcVectorDim != DstVectorDim &&
311 ((is_same<half_t, remove_cvref_t<DstData>>::value &&
312 SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
313 (is_same<f8_t, remove_cvref_t<DstData>>::value &&
314 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
315 (is_same<int8_t, remove_cvref_t<DstData>>::value &&
316 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
317 {
318 // each transpose does
319 // DstScalarPerVector # of src vectors in src_thread_scratch_
320 // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
321 constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
322 constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
323
324 // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
325 // TODO: make this logic generic for all scenario
326
327 constexpr auto src_scalar_step_in_vector = generate_sequence(
328 detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
329
330 constexpr auto dst_scalar_step_in_vector = generate_sequence(
331 detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
332
333 constexpr auto scalar_per_access = generate_sequence(
334 detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
336 DstVectorDim,
337 DstScalarPerVector>{},
338 Number<nDim>{});
339
340 constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
341
342 static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
343 constexpr auto data_idx = access_idx * scalar_per_access;
344
345 constexpr auto data_idx_seq = generate_sequence_v2(
346 [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
347
350
351 // get DstScalarPerVector # of read-only references to src vectors from
352 // src_thread_scratch_
353 const auto src_vector_refs = generate_tie(
354 [&](auto i) -> const src_vector_t& {
355 // i increment corresponds to movement in DstVectorDim
356 return elm_thread_scratch_.GetVectorTypeReference(
357 data_idx_seq + i * dst_scalar_step_in_vector);
358 },
360
361 // get SrcScalarPerVector # of references to dst vectors from
362 // dst_thread_scratch_
363 auto dst_vector_refs = generate_tie(
364 [&](auto i) -> dst_vector_t& {
365 // i increment corresponds to movement in SrcVectorDim
366 return dst_thread_scratch_.GetVectorTypeReference(
367 data_idx_seq + i * src_scalar_step_in_vector);
368 },
370
371 // do data transpose
372 transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
373 src_vector_refs, dst_vector_refs);
374 });
375 }
376 else
377 {
378 static_ford<SliceLengths>{}(
379 [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
380 }
381
382 dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
383 }
384
385 // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
386 // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
387 template <typename DstBuffers,
388 index_t ThreadScratchId = 0,
389 enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
390 __device__ void RunWrite(const DstDescs& dst_descs,
391 DstBuffers dst_bufs,
394 {
395 OOBCheck(thread_scratch_id);
396 TransposeFromElmToDst(thread_scratch_id);
397
398 // loop over space-filling curve
399 static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
400 auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
401 IndexType scatter_offset = 0;
402 if constexpr(OutputScatter)
403 {
404 constexpr auto iScatter =
406 scatter_offset = scatter_offsets(Number<iScatter>{});
407 }
408 // copy data from buf_vectors into dst_bufs
409 static_for<0, nDst, 1>{}([&](auto i) {
410 using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
411 IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
412 const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
413 constexpr InMemoryDataOperationEnum DstInMemOp =
414 static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
415 dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
416 dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
417 });
418
419 // move coordinate
420 if constexpr(iAccess.value != dst_num_access - 1)
421 {
422 constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
423
424 auto forward_step_scatter = [&]() constexpr {
425 Index step_;
426
427 static_for<0, nDim, 1>{}([&](auto i) {
428 step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
429 });
430
431 return step_;
432 }();
433 static_for<0, nDst, 1>{}([&](auto i) {
435 dst_descs[i],
436 dst_coords_(i),
437 make_tensor_coordinate_step(dst_descs[i], forward_step_scatter));
438 });
439 }
440 });
441
442 static_for<0, nDst, 1>{}([&](auto i) {
443 if constexpr(DstResetCoordinateAfterRunFlags::At(i))
444 {
445 const auto dst_reset_step =
447
448 move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
449 }
450 });
451 }
452
453 // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
454 // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
455 // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
456 // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
457 template <typename SrcBuffers,
458 typename DstBuffers,
459 enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
460 DstDescs::Size() == DstBuffers::Size(),
461 bool> = false>
462 __device__ void Run(const SrcDescs& src_descs,
463 const SrcBuffers& src_bufs,
464 const DstDescs& dst_descs,
465 DstBuffers dst_bufs,
467 {
468 RunRead(src_descs, src_bufs);
469 RunWrite(dst_descs, dst_bufs, scatter_offsets);
470 }
471
472 __device__ static constexpr auto GetSrcCoordinateResetStep()
473 {
474 if constexpr(src_num_access == 0)
475 {
476 return typename SrcSpaceFillingCurve::Index{};
477 }
478 else
479 {
481 }
482 }
483
484 __device__ static constexpr auto GetDstCoordinateResetStep()
485 {
486 if constexpr(dst_num_access == 0)
487 {
488 return typename DstSpaceFillingCurve::Index{};
489 }
490 else
491 {
492 constexpr auto reset_step =
494 auto reset_step_scatter = [&]() constexpr {
495 Index step_;
496 static_for<0, nDim, 1>{}([&](auto i) {
497 step_(i) =
498 (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
499 });
500
501 return step_;
502 }();
503 return reset_step_scatter;
504 }
505 }
506
507 __device__ static constexpr auto GetSrcThreadScratchDescriptor()
508 {
509 // constexpr auto src_scalar_per_access = generate_sequence(
510 // detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{},
511 // Number<nDim>{});
512
513 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
514
515 constexpr auto src_access_lengths_and_vector_length = container_push_back(
517
518 // 1st stage of transforms
519 constexpr auto desc0 =
520 make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
521
522 // 2nd stage of transforms
523 constexpr auto transforms = generate_tuple(
524 [&](auto i) {
525 if constexpr(i == SrcVectorDim)
526 {
528 make_tuple(src_access_lengths_and_vector_length[i],
529 src_access_lengths_and_vector_length[Number<nDim>{}]));
530 }
531 else
532 {
533 return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
534 }
535 },
536 Number<nDim>{});
537
538 constexpr auto low_dim_idss = generate_tuple(
539 [&](auto i) {
540 if constexpr(i == SrcVectorDim)
541 {
542 return Sequence<i.value, nDim>{};
543 }
544 else
545 {
546 return Sequence<i.value>{};
547 }
548 },
549 Number<nDim>{});
550
551 constexpr auto up_dim_idss =
552 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
553
554 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
555 }
556
557 __device__ static constexpr auto GetDstThreadScratchDescriptor()
558 {
559 // 1st stage of transforms
560 // constexpr auto dst_scalar_per_access = generate_sequence(
561 // detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{},
562 // Number<nDim>{});
563
564 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
565
566 constexpr auto dst_access_lengths_and_vector_length = container_push_back(
568
569 constexpr auto desc0 =
570 make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
571
572 // 2nd stage of transforms
573 constexpr auto transforms = generate_tuple(
574 [&](auto i) {
575 if constexpr(i == DstVectorDim)
576 {
578 make_tuple(dst_access_lengths_and_vector_length[i],
579 dst_access_lengths_and_vector_length[Number<nDim>{}]));
580 }
581 else
582 {
583 return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
584 }
585 },
586 Number<nDim>{});
587
588 constexpr auto low_dim_idss = generate_tuple(
589 [&](auto i) {
590 if constexpr(i == DstVectorDim)
591 {
592 return Sequence<i.value, nDim>{};
593 }
594 else
595 {
596 return Sequence<i.value>{};
597 }
598 },
599 Number<nDim>{});
600
601 constexpr auto up_dim_idss =
602 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
603
604 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
605 }
606
607 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
608 template <index_t ISrc>
609 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
610 Number<ISrc> iSrc,
611 const Index& src_slice_origin_step_idx)
612 {
613 // if src coord was not reset by RunRead(), then need to adjust the step here
614 const auto adjusted_step_idx =
615 SrcResetCoordinateAfterRunFlags::At(iSrc)
616 ? src_slice_origin_step_idx
617 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
618
619 // is it OK to construct a new step every time?
620 const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
621
622 move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
623 }
624
625 // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
626 template <index_t IDst>
627 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
628 Number<IDst> iDst,
629 const Index& dst_slice_origin_step_idx)
630 {
631 // if dst coord was not reset by Run(), then need to adjust the step here
632 const auto adjusted_step_idx =
633 DstResetCoordinateAfterRunFlags::At(iDst)
634 ? dst_slice_origin_step_idx
635 : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
636
637 auto adjusted_step_idx_scatter = [&]() {
638 Index step_;
639 static_for<0, nDim, 1>{}([&](auto i) {
640 step_(i) =
641 (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
642 });
643
644 return step_;
645 }();
646 // is it OK to construct a new step every time?
647 const auto adjusted_step =
648 make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
649
650 move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
651 }
652
653 private:
654 using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
655 using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
656 using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
657
658 static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
659 static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
660
663
666
669
670 SrcCoords src_coords_;
671 DstCoords dst_coords_;
672 const ElementwiseOperation element_op_;
673};
674
675} // namespace ck
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
decltype(ck::declval< T & >().is_pack8_invocable) is_pack8_invocable_t
Definition is_detected.hpp:43
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
decltype(ck::declval< T & >().is_pack4_invocable) is_pack4_invocable_t
Definition is_detected.hpp:40
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
decltype(ck::declval< T & >().is_pack2_invocable) is_pack2_invocable_t
Definition is_detected.hpp:37
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr Index GetIndex(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:81
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
__device__ void SetDstSliceOrigins(const DstDescs &dst_descs, const Indices &dst_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:126
__device__ void SetSrcSliceOrigins(const SrcDescs &src_descs, const Indices &src_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:117
__device__ void TransposeFromElmToDst(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:287
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r3_scatter(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_slice_origins, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:99
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:609
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, StaticallyIndexedArray< IndexType, scatter_num > &scatter_offsets, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:390
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:507
__device__ void OOBCheck(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:267
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:472
static __device__ auto generate_vectors()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:135
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:484
static constexpr auto MakeCoordinates(const Descs &descs, const Indices &indices)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:72
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:155
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:557
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:627
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs, StaticallyIndexedArray< IndexType, scatter_num > &scatter_offsets)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:462
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33