threadwise_tensor_slice_transfer_v3r2.hpp Source File

threadwise_tensor_slice_transfer_v3r2.hpp Source File

Composable Kernel: threadwise_tensor_slice_transfer_v3r2.hpp Source File
threadwise_tensor_slice_transfer_v3r2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13
14namespace ck {
15
// Per-thread transfer of one SliceLengths-shaped window from nSrc source tensors to nDst
// destination tensors, staged through scratch tensors held as members:
//   RunRead()  : every source slice -> source scratch slot `thread_scratch_id`
//   RunWrite() : source scratch -> element_op_ -> destination scratch -> every destination
//                (the scratch-to-scratch step is presumably the call on RunWrite's line 311,
//                which is missing from this rendering — confirm)
//   MoveSrcSliceWindow() / MoveDstSliceWindow() : shift the window, folding in any coordinate
//                reset that RunRead()/RunWrite() skipped
16// Assume:
17// 1. src_descs and dst_descs are not known at compile-time
18// 2. SrcBuffers and DstBuffers are DynamicBuffers
19// 3. src_slice_origins and dst_slice_origins are not known at compile-time
20// 4. Data is staged in per-thread scratch buffers
21template <typename SliceLengths,
22 typename ElementwiseOperation,
23 typename DstInMemOps, // Sequence of InMemoryDataOperationEnum values, one per destination
24 typename SrcDatas,
25 typename DstDatas,
26 typename SrcDescs,
27 typename DstDescs,
28 typename SrcDimAccessOrder,
29 typename DstDimAccessOrder,
30 index_t SrcVectorDim,
31 index_t DstVectorDim,
32 typename SrcsScalarPerVector, // Sequence: vector width along SrcVectorDim, per source
33 typename DstsScalarPerVector, // Sequence: vector width along DstVectorDim, per destination
34 typename SrcsScalarStrideInVector, // Sequence (not referenced by the visible code)
35 typename DstsScalarStrideInVector, // Sequence (not referenced by the visible code)
36 typename SrcsResetCoordinateAfterRun, // control whether to move back src coordinate after
37 // each RunRead(), will be fused with
38 // MoveSrcSliceWindow to save addr computation
39 typename DstsResetCoordinateAfterRun, // control whether to move back dst coordinate after
40 // each RunWrite(), will be fused with
41 // MoveDstSliceWindow to save addr computation
42 index_t NumThreadScratch = 1>
// NOTE(review): source line 43 (the declarator, presumably
// `struct ThreadwiseTensorSliceTransfer_v3r2`) is missing from this rendering.
44{
45 static constexpr index_t nDim = SliceLengths::Size();
// NOTE(review): the `Index` alias used throughout (presumably MultiIndex<nDim>) is declared on a
// line missing from this rendering (likely line 46) — confirm.
47
// Number of source / destination tensors handled by one transfer object.
48 static constexpr index_t nSrc = SrcDescs::Size();
49 static constexpr index_t nDst = DstDescs::Size();
50
51 // return a tuple of coordiantes for a tuple of tensor
52 template <typename Descs,
53 typename Indices,
54 enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
55 static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
56 {
57 return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
58 Number<Descs::Size()>{});
59 }
60
63
// Compile-time index 0, used for first-element / first-dimension accesses below.
64 static constexpr auto I0 = Number<0>{};
65
// Constructor: builds one coordinate per source and per destination tensor from its descriptor
// and slice origin (via MakeCoordinates) and stores a copy of the elementwise operation.
// NOTE(review): the declarator line 66 is missing from this rendering; per the definitions list it
// is `__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(`. Lines 61-62 (presumably the
// SrcCoords / DstCoords type aliases used by the members) are missing as well.
67 const SrcDescs& src_descs,
68 const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
69 const DstDescs& dst_descs,
70 const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
71 const ElementwiseOperation& element_op)
72 : src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
73 dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
74 element_op_(element_op)
75 {
76 }
77
78 template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
79 __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
80 const Indices& src_slice_origin_idxs)
81 {
82 static_for<0, nSrc, 1>{}([&](auto src_i) {
83 src_coords_(src_i) =
84 make_tensor_coordinate(src_descs.At(src_i), src_slice_origin_idxs[src_i]);
85 });
86 }
87
88 template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
89 __device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
90 const Indices& dst_slice_origin_idxs)
91 {
92 static_for<0, nDst, 1>{}([&](auto dst_i) {
93 dst_coords_(dst_i) =
94 make_tensor_coordinate(dst_descs.At(dst_i), dst_slice_origin_idxs[dst_i]);
95 });
96 }
97
98 template <typename SrcBuffers, index_t ThreadScratchId = 0>
99 __device__ void RunRead(const SrcDescs& src_descs,
100 const SrcBuffers& src_bufs,
102 {
103 // scalar per access on each dim
104 // TODO: don't use lambda_scalar_per_access
105 constexpr auto src_scalar_per_access_tuple = generate_tuple(
106 [&](auto src_i) {
107 return generate_sequence(
109 SrcsScalarPerVector::At(src_i)>{},
110 Number<nDim>{});
111 },
112 Number<nSrc>{});
113
114 constexpr auto src_access_lengths_tuple = generate_tuple(
115 [&](auto src_i) {
116 return SliceLengths{} / src_scalar_per_access_tuple.At(src_i);
117 static_assert(
118 SliceLengths::At(SrcVectorDim) % SrcsScalarPerVector::At(src_i) == 0,
119 "SliceLengths[SrcVectorDim] must be divisible by SrcsScalarPerVector");
120 },
121 Number<nSrc>{});
122
123 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
124
125 constexpr auto ordered_src_access_lengths_tuple = generate_tuple(
126 [&](auto src_i) {
127 return container_reorder_given_new2old(src_access_lengths_tuple.At(src_i),
128 src_dim_access_order);
129 },
130 Number<nSrc>{});
131
132 // make forward steps
133 const auto src_forward_steps_tuple = generate_tuple(
134 [&](auto src_i) {
135 return generate_tuple(
136 [&](auto i) {
137 Index forward_step_idx;
138
139 static_for<0, nDim, 1>{}([&](auto j) {
140 forward_step_idx(j) =
141 (i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0;
142 });
143
144 return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx);
145 },
146 Number<nDim>{});
147 },
148 Number<nSrc>{});
149
150 // make backward steps
151 const auto src_backward_steps_tuple = generate_tuple(
152 [&](auto src_i) {
153 return generate_tuple(
154 [&](auto i) {
155 Index backward_step_idx;
156
157 static_for<0, nDim, 1>{}([&](auto j) {
158 backward_step_idx(j) = (i.value == j.value)
159 ? -src_scalar_per_access_tuple.At(src_i)[i]
160 : 0;
161 });
162
163 return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx);
164 },
165 Number<nDim>{});
166 },
167 Number<nSrc>{});
168
169 // loop over tensor and copy
170 static_for<0, nSrc, 1>{}([&](auto src_i) {
171 static_ford<remove_cvref_t<decltype(ordered_src_access_lengths_tuple.At(src_i))>>{}(
172 [&](auto ordered_src_access_idx) {
173 // judge move forward or move backward
174 constexpr auto forward_sweep = [&]() {
176
177 forward_sweep_(I0) = true;
178
179 static_for<1, nDim, 1>{}([&](auto i) {
180 index_t tmp = ordered_src_access_idx[I0];
181
182 static_for<1, i, 1>{}([&](auto j) {
183 tmp = tmp * ordered_src_access_lengths_tuple[j] +
184 ordered_src_access_idx[j];
185 });
186
187 forward_sweep_(i) = tmp % 2 == 0;
188 });
189
190 return forward_sweep_;
191 }();
192
193 // calculate src data index
194 constexpr auto src_data_idx = [&]() {
195 Index ordered_idx;
196
197 static_for<0, nDim, 1>{}([&](auto i) {
198 ordered_idx(i) = forward_sweep[i]
199 ? ordered_src_access_idx[i]
200 : ordered_src_access_lengths_tuple.At(src_i)[i] -
201 1 - ordered_src_access_idx[i];
202 });
203
204 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
205 src_scalar_per_access_tuple.At(src_i);
206 }();
207
208 constexpr auto src_data_idx_seq =
209 generate_sequence_v2([&](auto i) { return Number<src_data_idx[i]>{}; },
210 Number<src_data_idx.Size()>{});
211
212 const bool is_src_valid =
214 src_descs.At(src_i), src_coords_.At(src_i));
215
217 SrcsScalarPerVector::At(src_i)>;
218 using src_vector_t = typename src_vector_type::type;
219
220 // copy data from src_buf into src_vector_container
221 auto src_vector_container =
222 src_vector_type{src_bufs.At(src_i).template Get<src_vector_t>(
223 src_coords_.At(src_i).GetOffset(), is_src_valid)};
224
225 // copy data from src_vector_container into src_thread_scratch_
226 src_thread_scratch_tuple_(thread_scratch_id)
227 .At(src_i)
228 .template SetAsType<src_vector_t>(
229 src_data_idx_seq,
230 src_vector_container.template AsType<src_vector_t>()[I0]);
231
232 constexpr auto move_on_dim = [&]() constexpr {
234
235 static_for<0, nDim, 1>{}([&](auto i) {
236 move_on_dim_(i) = ordered_src_access_idx[i] <
237 ordered_src_access_lengths_tuple.At(src_i)[i] - 1;
238
239 static_for<i + 1, nDim, 1>{}([&](auto j) {
240 move_on_dim_(i) &=
241 ordered_src_access_idx[j] ==
242 ordered_src_access_lengths_tuple.At(src_i)[j] - 1;
243 });
244 });
245
246 return move_on_dim_;
247 }();
248
249 // move src coord
250 static_for<0, nDim, 1>{}([&](auto i) {
251 if constexpr(move_on_dim[i])
252 {
253 if constexpr(forward_sweep[i])
254 {
256 src_descs.At(src_i),
257 src_coords_.At(src_i),
258 src_forward_steps_tuple.At(src_i)[src_dim_access_order[i]]);
259 }
260 else
261 {
263 src_descs.At(src_i),
264 src_coords_.At(src_i),
265 src_backward_steps_tuple.At(src_i)[src_dim_access_order[i]]);
266 }
267 }
268 });
269 });
270 });
271
272 static_for<0, nSrc, 1>{}([&](auto src_i) {
273 // move src coordinate back to slice origin (or not)
274 if constexpr(SrcsResetCoordinateAfterRun::At(src_i))
275 {
276 const auto src_reset_step = make_tensor_coordinate_step(
277 src_descs.At(src_i), GetSrcCoordinateResetStep<src_i>());
278
279 move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), src_reset_step);
280 }
281 });
282 }
283
// For every index idx in the SliceLengths index space:
//   * takes references to the element at idx in each source scratch tensor of slot
//     thread_scratch_id;
//   * takes references to the element at idx in each destination scratch tensor (the
//     destination side has a single slot);
//   * calls element_op_ through unpack2 with the destination references first and the source
//     references second (presumably element_op_(dst..., src...) — see unpack2).
// NOTE(review): the declarator line 286 is missing from this rendering; per the definitions list it
// is `TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)`.
284 template <index_t ThreadScratchId>
285 __device__ void
287 {
288 // TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
289 // (it requires to add Elementwise support in transpose_vectors)
290 static_ford<SliceLengths>{}([&](auto idx) {
// Read-only references to the element at idx in every source scratch tensor.
291 const auto src_data_refs = generate_tie(
292 [&](auto src_i) -> const auto& {
293 return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx];
294 },
295 Number<nSrc>{});
296
// Writable references to the element at idx in every destination scratch tensor.
297 auto dst_data_refs = generate_tie(
298 [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); },
299 Number<nDst>{});
300 unpack2(element_op_, dst_data_refs, src_data_refs);
301 });
302 }
303
// RunWrite: write the destination scratch tensors to every destination tensor.
//
// The destination traversal mirrors RunRead(): DstsScalarPerVector[dst_i]-wide accesses along
// DstVectorDim, dimensions ordered by DstDimAccessOrder, and direction alternating with the parity
// of the outer ordered indices.
//
// Each vector is written with the buffer's Update<DstInMemOp>():
//   * DstInMemOp is DstInMemOps[dst_i] cast to InMemoryDataOperationEnum;
//   * the validity flag comes from
//     coordinate_has_valid_offset_assuming_visible_index_is_valid() (presumably — that call
//     is on a missing line, mirroring RunRead).
//
// When DstsResetCoordinateAfterRun[dst_i] is set, the coordinate is moved back to the slice
// origin at the end.
304 template <typename DstBuffers, index_t ThreadScratchId = 0>
305 __device__ void RunWrite(const DstDescs& dst_descs,
306 DstBuffers& dst_bufs,
// NOTE(review): line 307 is missing from this rendering; per the definitions list it is
// `Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})`.
308 {
309 // if there is transpose, it's done here
310 // TODO move this elsewhere
// NOTE(review): line 311 is missing from this rendering; presumably
// `TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);` — confirm.
312
313 // dst scalar per access on each dim
314 // TODO: don't use this
315 constexpr auto dst_scalar_per_access_tuple = generate_tuple(
316 [&](auto dst_i) {
317 return generate_sequence(
319 DstsScalarPerVector::At(dst_i)>{},
320 Number<nDim>{});
321 },
322 Number<nDst>{});
323
// number of accesses on each dim
324 constexpr auto dst_access_lengths_tuple = generate_tuple(
325 [&](auto dst_i) { return SliceLengths{} / dst_scalar_per_access_tuple.At(dst_i); },
326 Number<nDst>{});
327
328 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
329
330 constexpr auto ordered_dst_access_lengths_tuple = generate_tuple(
331 [&](auto dst_i) {
332 return container_reorder_given_new2old(dst_access_lengths_tuple.At(dst_i),
333 dst_dim_access_order);
334 },
335 Number<nDst>{});
336
337 // make forward steps
338 const auto dst_forward_steps_tuple = generate_tuple(
339 [&](auto dst_i) {
340 return generate_tuple(
341 [&](auto i) {
342 Index forward_step_idx;
343
344 static_for<0, nDim, 1>{}([&](auto j) {
345 forward_step_idx(j) =
346 (i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0;
347 });
348
349 return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx);
350 },
351 Number<nDim>{});
352 },
353 Number<nDst>{});
354
355 // make backward steps
356 const auto dst_backward_steps_tuple = generate_tuple(
357 [&](auto dst_i) {
358 return generate_tuple(
359 [&](auto i) {
360 Index backward_step_idx;
361
362 static_for<0, nDim, 1>{}([&](auto j) {
363 backward_step_idx(j) = (i.value == j.value)
364 ? -dst_scalar_per_access_tuple.At(dst_i)[i]
365 : 0;
366 });
367
368 return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx);
369 },
370 Number<nDim>{});
371 },
372 Number<nDst>{});
373
374 // loop over tensor and copy
375 static_for<0, nDst, 1>{}([&](auto dst_i) {
376 static_ford<remove_cvref_t<decltype(ordered_dst_access_lengths_tuple.At(dst_i))>>{}(
377 [&](auto ordered_dst_access_idx) {
378 // judge move forward or move backward
// (line 380, presumably the declaration of forward_sweep_, is missing from this rendering)
379 constexpr auto forward_sweep = [&]() {
381
382 forward_sweep_(I0) = true;
383
384 static_for<1, nDim, 1>{}([&](auto i) {
385 index_t tmp = ordered_dst_access_idx[I0];
386
387 static_for<1, i, 1>{}([&](auto j) {
388 tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] +
389 ordered_dst_access_idx[j];
390 });
391
392 forward_sweep_(i) = tmp % 2 == 0;
393 });
394
395 return forward_sweep_;
396 }();
397
398 // calculate dst data index
399 constexpr auto dst_data_idx = [&]() {
400 Index ordered_idx;
401
402 static_for<0, nDim, 1>{}([&](auto i) {
403 ordered_idx(i) = forward_sweep[i]
404 ? ordered_dst_access_idx[i]
405 : ordered_dst_access_lengths_tuple.At(dst_i)[i] -
406 1 - ordered_dst_access_idx[i];
407 });
408
409 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
410 dst_scalar_per_access_tuple.At(dst_i);
411 }();
412
413 constexpr auto dst_data_idx_seq =
414 generate_sequence_v2([&](auto i) { return Number<dst_data_idx[i]>{}; },
415 Number<dst_data_idx.Size()>{});
416
// (line 418, the validity-check call, is missing from this rendering)
417 const bool is_dst_valid =
419 dst_descs.At(dst_i), dst_coords_.At(dst_i));
420
// (line 421, the start of the dst_vector_type alias, is missing from this rendering)
422 DstsScalarPerVector::At(dst_i)>;
423 using dst_vector_t = typename dst_vector_type::type;
424
425 // copy data from dst_thread_scratch_ into dst_vector_container
426 auto dst_vector_container = dst_vector_type{
427 dst_thread_scratch_tuple_.At(dst_i).template GetAsType<dst_vector_t>(
428 dst_data_idx_seq)};
429
// Per-destination memory operation, taken from the DstInMemOps sequence.
430 constexpr InMemoryDataOperationEnum DstInMemOp =
431 static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(dst_i.value));
432
433 // copy data from dst_vector_container to dst_buf
434 dst_bufs.At(dst_i).template Update<DstInMemOp, dst_vector_t>(
435 dst_coords_.At(dst_i).GetOffset(),
436 is_dst_valid,
437 dst_vector_container.template AsType<dst_vector_t>()[I0]);
438
// dimension i moves next iff it is not at its last index while every inner ordered
// dimension is at its last index (line 440 is missing from this rendering)
439 constexpr auto move_on_dim = [&]() constexpr {
441
442 static_for<0, nDim, 1>{}([&](auto i) {
443 move_on_dim_(i) = ordered_dst_access_idx[i] <
444 ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1;
445
446 static_for<i + 1, nDim, 1>{}([&](auto j) {
447 move_on_dim_(i) &=
448 ordered_dst_access_idx[j] ==
449 ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1;
450 });
451 });
452
453 return move_on_dim_;
454 }();
455
456 // move dst coord
457 static_for<0, nDim, 1>{}([&](auto i) {
458 if constexpr(move_on_dim[i])
459 {
460 if constexpr(forward_sweep[i])
461 {
463 dst_descs.At(dst_i),
464 dst_coords_.At(dst_i),
465 dst_forward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]);
466 }
467 else
468 {
470 dst_descs.At(dst_i),
471 dst_coords_.At(dst_i),
472 dst_backward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]);
473 }
474 }
475 });
476 });
477 });
478
479 // move dst coordinate back to slice origin (or not)
480 static_for<0, nDst, 1>{}([&](auto dst_i) {
481 if constexpr(DstsResetCoordinateAfterRun::At(dst_i))
482 {
483 const auto dst_reset_step = make_tensor_coordinate_step(
484 dst_descs.At(dst_i), GetDstCoordinateResetStep<dst_i>());
485
486 move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), dst_reset_step);
487 }
488 });
489 }
490
// Compile-time index step that moves source src_i's coordinate from the element last accessed by
// RunRead() back to the slice origin. It is the negation of that last data index, which is
// recomputed here from the same access lengths, access order and sweep-direction rule RunRead()
// uses.
491 template <index_t src_i>
492 __device__ static constexpr auto GetSrcCoordinateResetStep()
493 {
494 // scalar per access on each dim
495 // TODO: don't use lambda_scalar_per_access
496 constexpr auto src_scalar_per_access = generate_sequence(
497 detail::lambda_scalar_per_access<SrcVectorDim, SrcsScalarPerVector::At(src_i)>{},
498 Number<nDim>{});
499
500 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
501
502 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
503
504 constexpr auto ordered_src_access_lengths =
505 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
506
507 // judge move forward or move backward during the last iteration
// (line 509, presumably the declaration of forward_sweep_, is missing from this rendering)
508 constexpr auto forward_sweep = [&]() {
510
511 forward_sweep_(I0) = true;
512
513 static_for<1, nDim, 1>{}([&](auto i) {
514 index_t tmp = ordered_src_access_lengths[I0] - 1;
515
516 static_for<1, i, 1>{}([&](auto j) {
517 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
518 });
519
520 forward_sweep_(i) = tmp % 2 == 0;
521 });
522
523 return forward_sweep_;
524 }();
525
526 // calculate src data index after last iteration in RunRead(), if it has not been reset by
527 // RunRead()
528 constexpr auto src_data_idx = [&]() {
529 Index ordered_idx;
530
531 static_for<0, nDim, 1>{}([&](auto i) {
532 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
533 });
534
535 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
536 src_scalar_per_access;
537 }();
538
539 // the reset step is the negated final data index
540 constexpr auto reset_src_data_step = [&]() {
541 Index reset_src_data_step_;
542
543 static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
544
545 return reset_src_data_step_;
546 }();
547
548 return reset_src_data_step;
549 }
550
551 template <index_t dst_i>
552 __device__ static constexpr auto GetDstCoordinateResetStep()
553 {
554 // scalar per access on each dim
555 // TODO: don't use lambda_scalar_per_access
556 constexpr auto dst_scalar_per_access = generate_sequence(
557 detail::lambda_scalar_per_access<DstVectorDim, DstsScalarPerVector::At(dst_i)>{},
558 Number<nDim>{});
559
560 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
561
562 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
563
564 constexpr auto ordered_dst_access_lengths =
565 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
566
567 // judge move forward or move backward during the last iteration
568 constexpr auto forward_sweep = [&]() {
570
571 forward_sweep_(I0) = true;
572
573 static_for<1, nDim, 1>{}([&](auto i) {
574 index_t tmp = ordered_dst_access_lengths[I0] - 1;
575
576 static_for<1, i, 1>{}([&](auto j) {
577 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
578 });
579
580 forward_sweep_(i) = tmp % 2 == 0;
581 });
582
583 return forward_sweep_;
584 }();
585
586 // calculate dst data index after last iteration in RunWrite(), if it has not being reset by
587 // RunWrite()
588 constexpr auto dst_data_idx = [&]() {
589 Index ordered_idx;
590
591 static_for<0, nDim, 1>{}([&](auto i) {
592 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
593 });
594
595 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
596 dst_scalar_per_access.At(dst_i);
597 }();
598
599 //
600 constexpr auto reset_dst_data_step = [&]() {
601 Index reset_dst_data_step_;
602
603 static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
604
605 return reset_dst_data_step_;
606 }();
607
608 return reset_dst_data_step;
609 }
610
611 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
612 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
613 const Index& src_slice_origin_step_idx)
614 {
615 static_for<0, nSrc, 1>{}([&](auto src_i) {
616 // if src coord was not reset by RunRead(), then need to adjust the step here
617 const auto adjusted_step_idx =
618 SrcsResetCoordinateAfterRun::At(src_i)
619 ? src_slice_origin_step_idx
620 : src_slice_origin_step_idx + GetSrcCoordinateResetStep<src_i>();
621
622 // is it OK to construct a new step every time?
623 const auto adjusted_step =
624 make_tensor_coordinate_step(src_descs.At(src_i), adjusted_step_idx);
625
626 move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), adjusted_step);
627 });
628 }
629
630 // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
631 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
632 const Index& dst_slice_origin_step_idx)
633 {
634 static_for<0, nDst, 1>{}([&](auto dst_i) {
635 // if dst coord was not reset by RunWrite(), then need to adjust the step here
636 const auto adjusted_step_idx =
637 DstsResetCoordinateAfterRun::At(dst_i)
638 ? dst_slice_origin_step_idx
639 : dst_slice_origin_step_idx + GetDstCoordinateResetStep<dst_i>();
640
641 // is it OK to construct a new step every time?
642 const auto adjusted_step =
643 make_tensor_coordinate_step(dst_descs.At(dst_i), adjusted_step_idx);
644
645 move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), adjusted_step);
646 });
647 }
648
// Compile-time descriptor of the scratch tensor for source src_i.
//   * desc0 is a packed descriptor over src_access_lengths_and_vector_length. Entries [0, nDim)
//     are used as per-dimension access counts; entry [nDim] is the vector length.
//   * The returned descriptor has nDim visible dimensions. Each is a pass-through, except
//     SrcVectorDim, which is built from lower dimensions {SrcVectorDim, nDim}, i.e.
//     (access count, vector length).
649 template <index_t src_i>
650 __device__ static constexpr auto GetSrcThreadScratchDescriptor()
651 {
// scalar per access on each dim
652 constexpr auto src_scalar_per_access = generate_sequence(
653 detail::lambda_scalar_per_access<SrcVectorDim, SrcsScalarPerVector::At(src_i)>{},
654 Number<nDim>{});
655
656 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
657
// NOTE(review): line 659 is missing from this rendering; presumably it appends the vector length
// below to the access lengths (e.g. container_push_back(sequence_to_tuple_of_number(...))) — confirm.
658 constexpr auto src_access_lengths_and_vector_length =
660 Number<SrcsScalarPerVector::At(src_i)>{});
661
662 // 1st stage of transforms
663 constexpr auto desc0 =
664 make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
665
666 // 2nd stage of transforms
// NOTE(review): line 671 is missing from this rendering; presumably
// `return make_merge_transform_v3_division_mod(` — confirm.
667 constexpr auto transforms = generate_tuple(
668 [&](auto i) {
669 if constexpr(i == SrcVectorDim)
670 {
672 make_tuple(src_access_lengths_and_vector_length[i],
673 src_access_lengths_and_vector_length[Number<nDim>{}]));
674 }
675 else
676 {
677 return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
678 }
679 },
680 Number<nDim>{});
681
// lower (desc0) dimensions consumed by each transform
682 constexpr auto low_dim_idss = generate_tuple(
683 [&](auto i) {
684 if constexpr(i == SrcVectorDim)
685 {
686 return Sequence<i.value, nDim>{};
687 }
688 else
689 {
690 return Sequence<i.value>{};
691 }
692 },
693 Number<nDim>{});
694
// each transform produces exactly one visible dimension, in order
695 constexpr auto up_dim_idss =
696 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
697
698 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
699 }
700
// Compile-time descriptor of the scratch tensor for destination dst_i.
//   * desc0 is a packed descriptor over dst_access_lengths_and_vector_length. Entries [0, nDim)
//     are used as per-dimension access counts; entry [nDim] is the vector length.
//   * The returned descriptor has nDim visible dimensions. Each is a pass-through, except
//     DstVectorDim, which is built from lower dimensions {DstVectorDim, nDim}.
701 template <index_t dst_i>
702 __device__ static constexpr auto GetDstThreadScratchDescriptor()
703 {
704 // scalar per access on each dim
705 constexpr auto dst_scalar_per_access = generate_sequence(
706 detail::lambda_scalar_per_access<DstVectorDim, DstsScalarPerVector::At(dst_i)>{},
707 Number<nDim>{});
708
709 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
710
// NOTE(review): line 712 is missing from this rendering; presumably it appends the vector length
// below to the access lengths — confirm.
711 constexpr auto dst_access_lengths_and_vector_length =
713 Number<DstsScalarPerVector::At(dst_i)>{});
714
// 1st stage of transforms
715 constexpr auto desc0 =
716 make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
717
718 // 2nd stage of transforms
// NOTE(review): line 723 is missing from this rendering; presumably
// `return make_merge_transform_v3_division_mod(` — confirm.
719 constexpr auto transforms = generate_tuple(
720 [&](auto i) {
721 if constexpr(i == DstVectorDim)
722 {
724 make_tuple(dst_access_lengths_and_vector_length[i],
725 dst_access_lengths_and_vector_length[Number<nDim>{}]));
726 }
727 else
728 {
729 return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
730 }
731 },
732 Number<nDim>{});
733
// lower (desc0) dimensions consumed by each transform
734 constexpr auto low_dim_idss = generate_tuple(
735 [&](auto i) {
736 if constexpr(i == DstVectorDim)
737 {
738 return Sequence<i.value, nDim>{};
739 }
740 else
741 {
742 return Sequence<i.value>{};
743 }
744 },
745 Number<nDim>{});
746
// each transform produces exactly one visible dimension, in order
747 constexpr auto up_dim_idss =
748 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
749
750 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
751 }
752
// Default-constructs one scratch tensor per source. Each is parameterized by:
//   * SrcsScalarPerVector[src_i];
//   * the type of a scratch descriptor (presumably GetSrcThreadScratchDescriptor<src_i>());
//   * the flag `true`.
// Used only through decltype to form SrcThreadScratchTuple.
// NOTE(review): lines 758, 760 and 761 are missing from this rendering. They presumably name the
// descriptor getter, the scratch type (likely StaticTensorTupleOfVectorBuffer, in Vgpr address
// space) and the element type from SrcDatas — confirm.
753 __device__ static constexpr auto MakeSrcThreadScratchTuple()
754 {
755 return generate_tuple(
756 [&](auto src_i) {
757 constexpr auto src_thread_scratch_desc =
759 using SrcThreadScratch =
762 SrcsScalarPerVector::At(src_i),
763 decltype(src_thread_scratch_desc),
764 true>;
765 return SrcThreadScratch{};
766 },
767 Number<nSrc>{});
768 }
769
// Default-constructs one scratch tensor per destination. Each is parameterized by:
//   * DstsScalarPerVector[dst_i];
//   * the type of a scratch descriptor (presumably GetDstThreadScratchDescriptor<dst_i>());
//   * the flag `true`.
// Used only through decltype to form DstThreadScratchTuple.
// NOTE(review): lines 775, 777 and 778 are missing from this rendering (presumably the descriptor
// getter, the scratch type and the element type from DstDatas) — confirm.
770 __device__ static constexpr auto MakeDstThreadScratchTuple()
771 {
772 return generate_tuple(
773 [&](auto dst_i) {
774 constexpr auto dst_thread_scratch_desc =
776 using DstThreadScratch =
779 DstsScalarPerVector::At(dst_i),
780 decltype(dst_thread_scratch_desc),
781 true>;
782 return DstThreadScratch{};
783 },
784 Number<nDst>{});
785 }
786
787 private:
// Tuple types holding one scratch tensor per source / destination.
788 using SrcThreadScratchTuple = decltype(MakeSrcThreadScratchTuple());
789 using DstThreadScratchTuple = decltype(MakeDstThreadScratchTuple());
790
// NOTE(review): line 791 is missing from this rendering. It declares src_thread_scratch_tuple_,
// which RunRead() and the scratch-to-scratch transfer index by thread_scratch_id; presumably
// StaticallyIndexedArray<SrcThreadScratchTuple, NumThreadScratch> — confirm.
792
// One scratch tensor per destination; a single slot (not indexed by thread_scratch_id).
793 DstThreadScratchTuple dst_thread_scratch_tuple_;
794
// Current coordinate of each source / destination tensor; initialized from MakeCoordinates() in
// the constructor (the SrcCoords/DstCoords aliases are on lines missing from this rendering).
795 SrcCoords src_coords_;
796 DstCoords dst_coords_;
// Applied to every element in the scratch-to-scratch transfer. Being a const member, it makes the
// transfer object non-copy-assignable.
797 const ElementwiseOperation element_op_;
798};
799
800} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
Definition utility/sequence.hpp:43
Definition static_tensor.hpp:93
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r2.hpp:99
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:492
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch(Number< ThreadScratchId > thread_scratch_id)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:286
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:702
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:631
static __device__ constexpr auto MakeSrcThreadScratchTuple()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:753
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:612
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_slice_origins, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:66
static __device__ constexpr auto MakeDstThreadScratchTuple()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:770
__device__ void SetDstSliceOrigins(const DstDescs &dst_descs, const Indices &dst_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:89
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:650
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:552
__device__ void SetSrcSliceOrigins(const SrcDescs &src_descs, const Indices &src_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:79
static constexpr auto MakeCoordinates(const Descs &descs, const Indices &indices)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:55
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers &dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r2.hpp:305
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33
Definition functional3.hpp:97