device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp Source File

device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp Source File#

Composable Kernel: device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp Source File
device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <functional>
7#include <iostream>
8#include <iterator>
9#include <numeric>
10#include <sstream>
11
25
26namespace ck {
27namespace tensor_operation {
28namespace device {
29
30namespace {
31
32struct ComputePtrOffsetOfStridedBatch
33{
34 ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC)
35 : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
36 {
37 }
38
39 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
40 {
41 return g_idx * static_cast<long_index_t>(BatchStrideA_);
42 }
43
44 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
45 {
46 return g_idx * static_cast<long_index_t>(BatchStrideB_);
47 }
48
49 __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
50 {
51 return g_idx * static_cast<long_index_t>(BatchStrideC_);
52 }
53
54 index_t BatchStrideA_;
55 index_t BatchStrideB_;
56 index_t BatchStrideC_;
57};
58
59/*
60 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
61 *
62 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B and C
63 * matrices for a given batch. For example, ComputePtrOffsetOfStridedBatch computes the offsets of
64 * evenly strided batches, but it can easily be extended to other layouts. The returned offset can
65 * be either \p index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the
66 * 2GB limitation.
67 *
68 * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in the id of a workgroup and
69 * returns the 2D index of the tile that it computes. \see
70 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
71 *
72 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
73 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
74 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
75 * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
76 * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
77 * pointer offset into \p ComputePtrOffsetOfStridedBatch.
78 *
79 * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
80 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
81 * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
82 *
83 */
84template <typename GridwiseGemm,
85 typename ABDataType,
86 typename CDataType,
87 typename AGridDesc_K0_M0_M1_K1,
88 typename BGridDesc_K0_N0_N1_K1,
89 typename CGridDesc_M0_M10_M11_N0_N10_N11,
90 typename Block2CTileMap,
91 typename ComputePtrOffsetOfBatch,
92 bool HasMainKBlockLoop,
93 bool HasDoubleTailKBlockLoop>
94__global__ void
95#if CK_USE_LAUNCH_BOUNDS
97#endif
98 kernel_grouped_conv_fwd_dl(
99 const ABDataType* __restrict__ p_a_grid,
100 const ABDataType* __restrict__ p_b_grid,
101 CDataType* __restrict__ p_c_grid,
102 const index_t batch_count,
103 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
104 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
105 const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
106 const Block2CTileMap block_2_ctile_map,
107 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
108{
// The real body is only compiled for the targets named in the condition below; for any
// other target the #else branch leaves the kernel without work.
109#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
110 // offset base pointer for each work-group
// The 1D grid is divided evenly among batch_count groups: every run of
// num_blocks_per_batch consecutive work-groups serves one group, and g_idx is the group
// this work-group belongs to. Assumes the grid size is a multiple of batch_count and
// batch_count is non-zero (the host launcher multiplies the grid size by the group count).
111 const index_t num_blocks_per_batch =
112 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
113 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
114
// Per-group offsets applied to the A/B/C base pointers further down.
// NOTE(review): the offsets are long_index_t, but __builtin_amdgcn_readfirstlane is
// believed to take a 32-bit operand in clang - confirm that offsets above 2^31 are not
// truncated here.
115 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
116 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
117 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
118 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
119 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
120 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
121
// Shared-memory scratch buffer: the byte count reported by GridwiseGemm, expressed as a
// number of ABDataType elements.
122 constexpr index_t shared_block_size =
123 GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
124
125 __shared__ ABDataType p_shared[shared_block_size];
126
// Run the gridwise GEMM on this group's slices of A, B and C; the two loop flags are
// forwarded as compile-time constants.
127 GridwiseGemm::Run(p_a_grid + a_batch_offset,
128 p_b_grid + b_batch_offset,
129 p_c_grid + c_batch_offset,
130 p_shared,
131 a_grid_desc_k0_m0_m1_k1,
132 b_grid_desc_k0_n0_n1_k1,
133 c_grid_desc_m0_m10_m11_n0_n10_n11,
134 block_2_ctile_map,
135 integral_constant<bool, HasMainKBlockLoop>{},
136 integral_constant<bool, HasDoubleTailKBlockLoop>{});
137#else
// Other targets: no computation. Each argument is assigned to `ignore` so that it counts
// as used.
138 ignore = p_a_grid;
139 ignore = p_b_grid;
140 ignore = p_c_grid;
141 ignore = batch_count;
142 ignore = a_grid_desc_k0_m0_m1_k1;
143 ignore = b_grid_desc_k0_n0_n1_k1;
144 ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
145 ignore = compute_ptr_offset_of_batch;
146 ignore = block_2_ctile_map;
147
// The results of these calls are discarded.
// NOTE(review): presumably they only keep the offset getters referenced in this branch - confirm.
148 compute_ptr_offset_of_batch.GetAPtrOffset(0);
149 compute_ptr_offset_of_batch.GetBPtrOffset(0);
150 compute_ptr_offset_of_batch.GetCPtrOffset(0);
151#endif
152}
153
154} // namespace
155
156//
157// @brief Device Convolution operation.
158//
159// Supports:
160// @li Forward convolution with up to 3 spatial dimensions
161// @li Input tensor in GNWC data format
162// @li Weight tensor in GKXC data format
163// @li Output tensor in GNWK data format
164//
165// 1D:
166// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
167// 2D:
168// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
169// 3D:
170// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
171//
172template <
173 index_t NDimSpatial,
174 typename ADataType,
175 typename BDataType,
176 typename CDataType,
177 typename AccDataType,
178 typename ALayout,
179 typename BLayout,
180 typename CLayout,
181 typename AElementwiseOperation,
182 typename BElementwiseOperation,
183 typename CElementwiseOperation,
184 ConvolutionForwardSpecialization ConvForwardSpecialization,
185 GemmSpecialization GemmSpec,
186 index_t BlockSize,
187 index_t MPerBlock,
188 index_t NPerBlock,
189 index_t K0PerBlock,
190 index_t K1,
191 index_t M1PerThread,
192 index_t N1PerThread,
193 index_t KPerThread,
194 typename M1N1ThreadClusterM1Xs,
195 typename M1N1ThreadClusterN1Xs,
196 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
197 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
198 typename ABlockTransferThreadClusterArrangeOrder,
199 typename ABlockTransferSrcAccessOrder,
200 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
201 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
202 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
203 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
204 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
205 typename BBlockTransferThreadClusterArrangeOrder,
206 typename BBlockTransferSrcAccessOrder,
207 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
208 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
209 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
210 typename CThreadTransferSrcDstAccessOrder,
211 index_t CThreadTransferSrcDstVectorDim,
212 index_t CThreadTransferDstScalarPerVector,
217 bool> = false>
219 ALayout,
220 BLayout,
221 CLayout,
222 ADataType,
223 BDataType,
224 CDataType,
225 AElementwiseOperation,
226 BElementwiseOperation,
227 CElementwiseOperation>
228{
230
231 static constexpr auto I0 = Number<0>{};
232 static constexpr auto I1 = Number<1>{};
233 static constexpr auto I2 = Number<2>{};
234 static constexpr auto I3 = Number<3>{};
235
237
238 static constexpr auto matrix_padder =
239 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
240
241 template <typename ALay>
242 static auto
244 {
245 const auto in_gemmmraw_gemmkraw_desc =
246 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
247
248 const auto in_gemmm_gemmk_desc =
249 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
250
251 const auto M = in_gemmm_gemmk_desc.GetLength(I0);
252 const auto K = in_gemmm_gemmk_desc.GetLength(I1);
253
254 const auto AK0 = K / K1;
255
257 in_gemmm_gemmk_desc,
261 }
262
263 template <typename BLay>
264 static auto
266 {
267 const auto wei_gemmnraw_gemmkraw_desc =
268 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
269
270 const auto wei_gemmn_gemmk_desc =
271 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
272
273 const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
274 const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
275
276 const auto BK0 = K / K1;
277
279 wei_gemmn_gemmk_desc,
283 }
284
285 template <typename CLay>
286 static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
287 {
288 const auto out_gemmmraw_gemmnraw_desc =
289 conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
290
291 const auto out_gemmm_gemmn_desc =
292 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
293
294 return out_gemmm_gemmn_desc;
295 }
296
297 // desc for problem definition
305
306 // GridwiseGemm
309 ADataType,
310 AccDataType,
311 CDataType,
316 MPerBlock,
317 NPerBlock,
318 K0PerBlock,
319 K1,
320 M1PerThread,
321 N1PerThread,
322 KPerThread,
323 M1N1ThreadClusterM1Xs,
324 M1N1ThreadClusterN1Xs,
325 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
326 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
327 ABlockTransferThreadClusterArrangeOrder,
328 ABlockTransferSrcAccessOrder,
329 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
330 ABlockTransferSrcVectorTensorContiguousDimOrder,
331 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
332 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
333 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
334 BBlockTransferThreadClusterArrangeOrder,
335 BBlockTransferSrcAccessOrder,
336 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
337 BBlockTransferSrcVectorTensorContiguousDimOrder,
338 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
339 CThreadTransferSrcDstAccessOrder,
340 CThreadTransferSrcDstVectorDim,
341 CThreadTransferDstScalarPerVector>;
342
351
352 // Argument
353 struct Argument : public BaseArgument
354 {
// Captures one grouped forward-convolution problem.
// p_a / p_b / p_c are type-erased pointers that are cast to ADataType / BDataType / CDataType.
// The *_lengths / *_strides arrays describe the three tensors; element [0] of the A lengths is
// used as the group count, and element [0] of each strides array as that tensor's per-group stride.
// All shape/stride/padding arrays and the element-wise ops are also copied into members so that
// IsSupportedArgument() can inspect them later.
355 Argument(const void* p_a,
356 const void* p_b,
357 void* p_c,
358 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
359 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
360 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
361 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
362 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
363 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
364 const std::array<index_t, NDimSpatial>& conv_filter_strides,
365 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
366 const std::array<index_t, NDimSpatial>& input_left_pads,
367 const std::array<index_t, NDimSpatial>& input_right_pads,
368 const AElementwiseOperation& a_element_op,
369 const BElementwiseOperation& b_element_op,
370 const CElementwiseOperation& c_element_op)
371 : p_a_grid_{static_cast<const ADataType*>(p_a)},
372 p_b_grid_{static_cast<const BDataType*>(p_b)},
373 p_c_grid_{static_cast<CDataType*>(p_c)},
374 num_group_{a_g_n_c_wis_lengths[0]},
// The output tensor is passed to this constructor as c_g_n_k_wos_*; no e_g_n_k_wos_* name
// exists in this scope, so the transformer must be fed the c_* arrays.
375 conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
376 a_g_n_c_wis_strides,
377 b_g_k_c_xs_lengths,
378 b_g_k_c_xs_strides,
379 c_g_n_k_wos_lengths,
380 c_g_n_k_wos_strides,
381 conv_filter_strides,
382 conv_filter_dilations,
383 input_left_pads,
384 input_right_pads},
396 a_g_n_c_wis_strides[0], b_g_k_c_xs_strides[0], c_g_n_k_wos_strides[0]},
397 a_element_op_{a_element_op},
398 b_element_op_{b_element_op},
399 c_element_op_{c_element_op},
400 a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
401 a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
402 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
403 b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
404 c_g_n_k_wos_lengths_{c_g_n_k_wos_lengths},
405 c_g_n_k_wos_strides_{c_g_n_k_wos_strides},
406 conv_filter_strides_{conv_filter_strides},
407 conv_filter_dilations_{conv_filter_dilations},
408 input_left_pads_{input_left_pads},
409 input_right_pads_{input_right_pads}
410 {
411 // A/B/C batch stride: the group (leading) stride of each tensor
412 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
413 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
414 compute_ptr_offset_of_batch_.BatchStrideC_ = c_g_n_k_wos_strides[0];
415
416 // populate desc for Ds/E
419 {
420
427
429 }
430 }
431
432 void Print() const
433 {
434 std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
435 std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
436 std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
437 std::cout << "num_group: " << num_group_ << std::endl;
438
439 std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl;
440 std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl;
441 std::cout << "A[m0, m10, m11, n0, n10, n11]: " << c_grid_desc_m0_m10_m11_n0_n10_n11_
442 << std::endl;
443 }
444
445 // private:
446 // pointers
447 const ADataType* p_a_grid_;
448 const BDataType* p_b_grid_;
449 CDataType* p_c_grid_;
450
451 // tensor descriptors for problem definiton
453
455
459
460 // tensor descriptors for block/thread-wise copy
464
465 // block-to-e-tile map
467
468 // for computing batch offset
469 ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
470
471 // element-wise op
472 AElementwiseOperation a_element_op_;
473 BElementwiseOperation b_element_op_;
474 CElementwiseOperation c_element_op_;
475
476 // for checking IsSupportedArgument()
477 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
478 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
479 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
480 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
481 std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths_;
482 std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_strides_;
483 std::array<index_t, NDimSpatial> conv_filter_strides_;
484 std::array<index_t, NDimSpatial> conv_filter_dilations_;
485 std::array<index_t, NDimSpatial> input_left_pads_;
486 std::array<index_t, NDimSpatial> input_right_pads_;
487 };
488
489 // Invoker
490 struct Invoker : public BaseInvoker
491 {
493
// Launches kernel_grouped_conv_fwd_dl for `arg` on `stream_config` and returns the float
// produced by launch_and_time_kernel. Throws std::runtime_error when the argument check
// near the top of the function fails.
494 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
495 {
// NOTE(review): the log-level guard below is commented out, so arg.Print() runs on every
// call - confirm whether the guard should be restored.
496 // if(stream_config.log_level_ > 0)
497 {
498 arg.Print();
499 }
500
503 {
504 throw std::runtime_error(
505 "wrong! DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK has invalid setting");
506 }
507
// The grid size computed below is multiplied by the number of groups.
508 const index_t grid_size =
510 arg.c_grid_desc_m_n_.GetLength(I1)) *
511 arg.num_group_;
512
// Generic lambda: the two flags arrive as integral_constant-like objects so they can be
// used as template arguments when the kernel is instantiated.
513 auto launch_kernel = [&](auto has_main_k_block_loop,
514 auto has_double_tail_k_block_loop) {
515 constexpr bool has_main_loop = has_main_k_block_loop.value;
516 constexpr bool has_double_loop = has_double_tail_k_block_loop;
517
518 const auto kernel =
519 kernel_grouped_conv_fwd_dl<GridwiseGemm,
520 ADataType, // TODO: distinguish A/B datatype
521 CDataType,
526 ComputePtrOffsetOfStridedBatch,
527 has_main_loop,
528 has_double_loop>;
529
// One 1D grid of grid_size work-groups with BlockSize threads each; the 0 is the
// lds_byte argument of launch_and_time_kernel.
530 return launch_and_time_kernel(stream_config,
531 kernel,
532 dim3(grid_size),
533 dim3(BlockSize),
534 0,
535 arg.p_a_grid_,
536 arg.p_b_grid_,
537 arg.p_c_grid_,
538 arg.a_g_n_c_wis_lengths_[0], // Group count
544 };
545
// Pick the kernel instantiation from two runtime properties of K0; each of the four
// combinations maps to its own compile-time specialization.
546 const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
547 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
548 const bool has_double_tail_k_block_loop =
550
551 if(has_main_k_block_loop && has_double_tail_k_block_loop)
552 {
553 return launch_kernel(integral_constant<bool, true>{},
554 integral_constant<bool, true>{});
555 }
556 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
557 {
558 return launch_kernel(integral_constant<bool, true>{},
559 integral_constant<bool, false>{});
560 }
561 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
562 {
563 return launch_kernel(integral_constant<bool, false>{},
564 integral_constant<bool, true>{});
565 }
566 else
567 {
568 return launch_kernel(integral_constant<bool, false>{},
569 integral_constant<bool, false>{});
570 }
571 }
572
573 float Run(const BaseArgument* p_arg,
574 const StreamConfig& stream_config = StreamConfig{}) override
575 {
576 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
577 }
578 };
579
// Returns whether this instance can run the problem described by `arg`.
// Returns false as soon as one of the checks below fails: the device check, the filter
// constraints of the selected ConvolutionForwardSpecialization, or the vector-access
// divisibility checks for A, B and C. A failing specialization check also logs the
// offending dimension to std::cout.
580 static bool IsSupportedArgument(const Argument& arg)
581 {
582 namespace ctc = tensor_layout::convolution;
583
584 // check device
585 if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
587 {
588 return false;
589 }
590
591 // check ConvolutionForwardSpecialization
592 if constexpr(ConvForwardSpecialization ==
594 {
595 // check if it's 1x1, stride=1 conv
596 for(index_t i = 0; i < NDimSpatial; ++i)
597 {
598 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
599 const index_t ConvStride = arg.conv_filter_strides_[i];
600 const index_t LeftPad = arg.input_left_pads_[i];
601 const index_t RightPad = arg.input_right_pads_[i];
602
603 if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
604 {
605 std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X
606 << " ConvStride = " << ConvStride << " LeftPad = " << LeftPad
607 << " RightPad = " << RightPad << std::endl;
608 return false;
609 }
610 }
611 }
612 else if constexpr(ConvForwardSpecialization ==
614 {
615 // check if it's 1x1 conv
616 for(index_t i = 0; i < NDimSpatial; ++i)
617 {
618 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
619 const index_t LeftPad = arg.input_left_pads_[i];
620 const index_t RightPad = arg.input_right_pads_[i];
621
622 if(!(X == 1 && LeftPad == 0 && RightPad == 0))
623 {
// This branch does not constrain the stride, so the message names the
// Filter1x1Pad0 check (it used to repeat the Stride1 label of the branch above).
624 std::cout << "Filter1x1Pad0 check: i = " << i << " X = " << X
625 << " LeftPad = " << LeftPad << " RightPad = " << RightPad
626 << std::endl;
627 return false;
628 }
629 }
630 }
631
632 // check vector access of A
633 // FIXME: layout
639 {
640 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
// Source vectors may only span the K0 and K1 dimensions.
641 if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
642 {
643 return false;
644 }
645 if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
646 {
647 return false;
648 }
649
650 const index_t C = arg.a_g_n_c_wis_lengths_[2];
651
// The channel count must be divisible by the product of the K0 and K1 vector lengths.
652 if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
653 {
654 return false;
655 }
656 }
657 else
658 {
659 return false;
660 }
661
662 // check vector access of B
663 // FIXME: layout
669
670 {
671 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
672 if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
673 {
674 return false;
675 }
676 if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
677 {
678 return false;
679 }
680
681 const index_t C = arg.b_g_k_c_xs_lengths_[2];
682
683 if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
684 {
685 return false;
686 }
687 }
688 else
689 {
690 return false;
691 }
692
693 // check vector access of C
699 {
700 const index_t K = arg.c_g_n_k_wos_lengths_[2];
701
// K must be divisible by the store vector width, and the vectorized dimension must be index 5.
702 if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5))
703 {
704 return false;
705 }
706 }
707 else
708 {
709 return false;
710 }
711 // check Gridwise GEMM
714 }
715
716 bool IsSupportedArgument(const BaseArgument* p_arg) override
717 {
718 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
719 }
720
721 static auto MakeArgument(const void* p_a,
722 const void* p_b,
723 void* p_c,
724 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
725 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
726 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
727 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
728 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
729 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
730 const std::array<index_t, NDimSpatial>& conv_filter_strides,
731 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
732 const std::array<index_t, NDimSpatial>& input_left_pads,
733 const std::array<index_t, NDimSpatial>& input_right_pads,
734 const AElementwiseOperation& a_element_op,
735 const BElementwiseOperation& b_element_op,
736 const CElementwiseOperation& c_element_op)
737 {
738 return Argument{p_a,
739 p_b,
740 p_c,
741 a_g_n_c_wis_lengths,
742 a_g_n_c_wis_strides,
743 b_g_k_c_xs_lengths,
744 b_g_k_c_xs_strides,
745 c_g_n_k_wos_lengths,
746 c_g_n_k_wos_strides,
747 conv_filter_strides,
748 conv_filter_dilations,
749 input_left_pads,
750 input_right_pads,
751 a_element_op,
752 b_element_op,
753 c_element_op};
754 }
755
756 static auto MakeInvoker() { return Invoker{}; }
757
758 std::unique_ptr<BaseArgument>
759 MakeArgumentPointer(const void* p_a,
760 const void* p_b,
761 void* p_c,
762 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
763 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
764 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
765 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
766 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
767 const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
768 const std::array<index_t, NDimSpatial>& conv_filter_strides,
769 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
770 const std::array<index_t, NDimSpatial>& input_left_pads,
771 const std::array<index_t, NDimSpatial>& input_right_pads,
772 const AElementwiseOperation& a_element_op,
773 const BElementwiseOperation& b_element_op,
774 const CElementwiseOperation& c_element_op) override
775 {
776 return std::make_unique<Argument>(p_a,
777 p_b,
778 p_c,
779 a_g_n_c_wis_lengths,
780 a_g_n_c_wis_strides,
781 b_g_k_c_xs_lengths,
782 b_g_k_c_xs_strides,
783 c_g_n_k_wos_lengths,
784 c_g_n_k_wos_strides,
785 conv_filter_strides,
786 conv_filter_dilations,
787 input_left_pads,
788 input_right_pads,
789 a_element_op,
790 b_element_op,
791 c_element_op);
792 }
793
794 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
795 {
796 return std::make_unique<Invoker>(Invoker{});
797 }
798
799 std::string GetTypeString() const override
800 {
801 auto str = std::stringstream();
802
803 // clang-format off
804 str << "DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK"
805 << "<"
806 << BlockSize << ", "
807 << MPerBlock << ", "
808 << NPerBlock << ", "
809 << K0PerBlock << ", "
810 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
811 << K1 << ", "
812 << MPerXDL << ", "
813 << NPerXDL << ", "
814 << MXdlPerWave << ", "
815 << NXdlPerWave << ", "
816 << ABlockTransferSrcScalarPerVector << ", "
817 << ABlockTransferDstScalarPerVector_K1 << ", "
818 << BBlockTransferSrcScalarPerVector << ", "
819 << BBlockTransferDstScalarPerVector_K1 << ", "
820 << CShuffleMXdlPerWavePerShuffle << ", "
821 << CShuffleNXdlPerWavePerShuffle << ", "
822 << CBlockTransferScalarPerVector_NWaveNPerXdl
823 << ">";
824 // clang-format on
825
826 return str.str();
827 }
828};
829
830} // namespace device
831} // namespace tensor_operation
832} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_v1r3.hpp:93
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
Definition device_base.hpp:197
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:354
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:457
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:461
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:485
void Print() const
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:432
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:447
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:480
CDataType * p_c_grid_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:449
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:462
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:478
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:456
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:466
CElementwiseOperation c_element_op_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:474
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:469
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:472
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:477
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:448
std::array< index_t, NDimSpatial+3 > c_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:481
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:473
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:479
index_t num_group_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:452
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:458
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:463
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:454
Argument(const void *p_a, const void *p_b, void *p_c, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:355
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:486
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:483
std::array< index_t, NDimSpatial+3 > c_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:482
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:484
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:491
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:494
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:492
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:573
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:228
static auto MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:265
static constexpr auto I0
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:231
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:794
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1< BLayout >( dummy_conv_to_gemm_transformer))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:301
static constexpr auto I2
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:233
static constexpr auto I1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:232
remove_cvref_t< decltype(MakeCGridDescriptor_M_N< CLayout >(dummy_conv_to_gemm_transformer))> CGridDesc_M_N
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:303
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK DeviceOp
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:229
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:236
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:307
static auto MakeInvoker()
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:756
static constexpr auto I3
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:234
static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:286
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:716
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:799
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op) override
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:759
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:298
static auto MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:243
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})) BGridDesc_K0_N0_N1_K1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:345
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:347
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:349
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1< ALayout >( dummy_conv_to_gemm_transformer))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:299
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:580
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:238
static auto MakeArgument(const void *p_a, const void *p_b, void *p_c, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &c_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op)
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:721
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})) AGridDesc_K0_M0_M1_K1
Definition device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp:343
Definition device_grouped_conv_fwd.hpp:31
Definition matrix_padder.hpp:180