26template <
typename GridwiseGemm,
29 typename AGridDesc_B_K0_M0_M1_K1,
30 typename BGridDesc_B_K0_N0_N1_K1,
31 typename CGridDesc_M0_M10_M11_N0_N10_N11,
32 typename Block2CTileMap,
33 typename ComputePtrOffsetOfBatch,
34 bool HasMainKBlockLoop,
35 bool HasDoubleTailKBlockLoop>
37#if CK_USE_LAUNCH_BOUNDS
41 const FloatAB* __restrict__ p_a_grid,
42 const FloatAB* __restrict__ p_b_grid,
43 FloatC* __restrict__ p_c_grid,
45 const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
46 const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
47 const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
48 const Block2CTileMap block_2_ctile_map,
49 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
51#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx90a__) || defined(__gfx908__) || \
52 defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
53 const index_t num_blocks_per_batch =
54 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
64 __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() /
sizeof(FloatAB)];
66 GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
67 p_a_grid + a_batch_offset,
68 p_b_grid + b_batch_offset,
69 p_c_grid + c_batch_offset,
71 a_grid_desc_kbatch_k0_m0_m1_k1,
72 b_grid_desc_kbatch_k0_n0_n1_k1,
73 c_grid_desc_m0_m10_m11_n0_n10_n11,
82 ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
83 ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
84 ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
85 ignore = block_2_ctile_map;
86 ignore = compute_ptr_offset_of_batch;
98 typename InElementwiseOperation,
99 typename WeiElementwiseOperation,
100 typename OutElementwiseOperation,
110 typename M1N1ThreadClusterM1Xs,
111 typename M1N1ThreadClusterN1Xs,
112 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
113 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
114 typename ABlockTransferThreadClusterArrangeOrder,
115 typename ABlockTransferSrcAccessOrder,
116 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
117 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
118 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
119 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
120 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
121 typename BBlockTransferThreadClusterArrangeOrder,
122 typename BBlockTransferSrcAccessOrder,
123 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
124 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
125 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
126 typename CThreadTransferSrcDstAccessOrder,
127 index_t CThreadTransferSrcDstVectorDim,
128 index_t CThreadTransferDstScalarPerVector>
136 InElementwiseOperation,
137 WeiElementwiseOperation,
138 OutElementwiseOperation>
179 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
181 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
182 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
183 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
184 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
185 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
186 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
187 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
188 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
189 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
190 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
195 const index_t N = a_g_n_c_wis_lengths[
I1];
196 const index_t K = b_g_k_c_xs_lengths[
I1];
197 const index_t C = a_g_n_c_wis_lengths[
I2];
201 const index_t InLeftPadW = input_left_pads[
I0];
202 const index_t InRightPadW = input_right_pads[
I0];
203 const index_t ConvStrideW = conv_filter_strides[
I0];
204 const index_t ConvDilationW = conv_filter_dilations[
I0];
206 const auto InNStride = a_g_n_c_wis_strides[
I1];
207 const auto InCStride = a_g_n_c_wis_strides[
I2];
209 const auto WeiKStride = b_g_k_c_xs_strides[
I1];
210 const auto WeiCStride = b_g_k_c_xs_strides[
I2];
211 const auto OutKStride = e_g_n_k_wos_strides[
I2];
214 const index_t GemmKTotal = N * Wo;
215 const index_t GemmKBatch = batch_k;
220 if constexpr(ConvBackwardWeightSpecialization ==
227 const auto out_gemmkpad_gemmmpad_grid_desc =
229 out_gemmktotal_gemmm_grid_desc,
234 out_gemmkpad_gemmmpad_grid_desc,
245 const auto in_gemmkpad_gemmnpad_grid_desc =
247 in_gemmktotal_gemmn_grid_desc,
252 in_gemmkpad_gemmnpad_grid_desc,
263 const auto wei_gemmmpad_gemmnpad_grid_desc =
268 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
269 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
270 wei_gemmmpad_gemmnpad_grid_desc);
280 const auto out_gemmkpad_gemmmpad_grid_desc =
282 out_gemmktotal_gemmm_grid_desc,
287 out_gemmkpad_gemmmpad_grid_desc,
304 in_n_wip_c_grid_desc,
312 const auto in_gemmktotal_gemmn_grid_desc =
319 const auto in_gemmkpad_gemmnpad_grid_desc =
321 in_gemmktotal_gemmn_grid_desc,
326 in_gemmkpad_gemmnpad_grid_desc,
337 const auto wei_gemmmpad_gemmnpad_grid_desc =
342 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
343 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
344 wei_gemmmpad_gemmnpad_grid_desc);
348 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
350 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
351 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
352 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
353 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
354 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
355 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
356 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
357 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
358 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
359 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
364 const index_t N = a_g_n_c_wis_lengths[
I1];
365 const index_t K = b_g_k_c_xs_lengths[
I1];
366 const index_t C = a_g_n_c_wis_lengths[
I2];
374 const index_t InLeftPadH = input_left_pads[
I0];
375 const index_t InLeftPadW = input_left_pads[
I1];
376 const index_t InRightPadH = input_right_pads[
I0];
377 const index_t InRightPadW = input_right_pads[
I1];
378 const index_t ConvStrideH = conv_filter_strides[
I0];
379 const index_t ConvStrideW = conv_filter_strides[
I1];
380 const index_t ConvDilationH = conv_filter_dilations[
I0];
381 const index_t ConvDilationW = conv_filter_dilations[
I1];
383 const auto InNStride = a_g_n_c_wis_strides[
I1];
384 const auto InCStride = a_g_n_c_wis_strides[
I2];
387 const auto WeiKStride = b_g_k_c_xs_strides[
I1];
388 const auto WeiCStride = b_g_k_c_xs_strides[
I2];
389 const auto OutKStride = e_g_n_k_wos_strides[
I2];
392 const index_t GemmKTotal = N * Ho * Wo;
393 const index_t GemmKBatch = batch_k;
398 if constexpr(ConvBackwardWeightSpecialization ==
405 const auto out_gemmkpad_gemmmpad_grid_desc =
407 out_gemmktotal_gemmm_grid_desc,
412 out_gemmkpad_gemmmpad_grid_desc,
423 const auto in_gemmkpad_gemmnpad_grid_desc =
425 in_gemmktotal_gemmn_grid_desc,
430 in_gemmkpad_gemmnpad_grid_desc,
441 const auto wei_gemmmpad_gemmnpad_grid_desc =
446 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
447 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
448 wei_gemmmpad_gemmnpad_grid_desc);
458 const auto out_gemmkpad_gemmmpad_grid_desc =
460 out_gemmktotal_gemmm_grid_desc,
465 out_gemmkpad_gemmmpad_grid_desc,
474 in_n_hi_wi_c_grid_desc,
483 in_n_hip_wip_c_grid_desc,
492 const auto in_gemmktotal_gemmn_grid_desc =
499 const auto in_gemmkpad_gemmnpad_grid_desc =
501 in_gemmktotal_gemmn_grid_desc,
506 in_gemmkpad_gemmnpad_grid_desc,
517 const auto wei_gemmmpad_gemmnpad_grid_desc =
522 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
523 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
524 wei_gemmmpad_gemmnpad_grid_desc);
529 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
531 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
532 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
533 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
534 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
535 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
536 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
537 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
538 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
539 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
540 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
545 const index_t N = a_g_n_c_wis_lengths[
I1];
546 const index_t K = b_g_k_c_xs_lengths[
I1];
547 const index_t C = a_g_n_c_wis_lengths[
I2];
558 const index_t InLeftPadD = input_left_pads[
I0];
559 const index_t InLeftPadH = input_left_pads[
I1];
560 const index_t InLeftPadW = input_left_pads[
I2];
561 const index_t InRightPadD = input_right_pads[
I0];
562 const index_t InRightPadH = input_right_pads[
I1];
563 const index_t InRightPadW = input_right_pads[
I2];
564 const index_t ConvStrideD = conv_filter_strides[
I0];
565 const index_t ConvStrideH = conv_filter_strides[
I1];
566 const index_t ConvStrideW = conv_filter_strides[
I2];
567 const index_t ConvDilationD = conv_filter_dilations[
I0];
568 const index_t ConvDilationH = conv_filter_dilations[
I1];
569 const index_t ConvDilationW = conv_filter_dilations[
I2];
571 const auto InNStride = a_g_n_c_wis_strides[
I1];
572 const auto InCStride = a_g_n_c_wis_strides[
I2];
576 const auto WeiKStride = b_g_k_c_xs_strides[
I1];
577 const auto WeiCStride = b_g_k_c_xs_strides[
I2];
578 const auto OutKStride = e_g_n_k_wos_strides[
I2];
581 const index_t GemmKTotal = N * Do * Ho * Wo;
582 const index_t GemmKBatch = batch_k;
587 if constexpr(ConvBackwardWeightSpecialization ==
594 const auto out_gemmkpad_gemmmpad_grid_desc =
596 out_gemmktotal_gemmm_grid_desc,
601 out_gemmkpad_gemmmpad_grid_desc,
612 const auto in_gemmkpad_gemmnpad_grid_desc =
614 in_gemmktotal_gemmn_grid_desc,
619 in_gemmkpad_gemmnpad_grid_desc,
630 const auto wei_gemmmpad_gemmnpad_grid_desc =
635 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
636 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
637 wei_gemmmpad_gemmnpad_grid_desc);
645 make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));
648 const auto out_gemmkpad_gemmmpad_grid_desc =
650 out_gemmktotal_gemmm_grid_desc,
655 out_gemmkpad_gemmmpad_grid_desc,
664 in_n_di_hi_wi_c_grid_desc,
676 in_n_dip_hip_wip_c_grid_desc,
692 in_n_z_do_y_ho_x_wo_c_grid_desc,
698 const auto in_gemmkpad_gemmnpad_grid_desc =
700 in_gemmktotal_gemmn_grid_desc,
705 in_gemmkpad_gemmnpad_grid_desc,
716 const auto wei_gemmmpad_gemmnpad_grid_desc =
721 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
722 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
723 wei_gemmmpad_gemmnpad_grid_desc);
728 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
732 {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
735 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
739 {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
742 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
780 M1N1ThreadClusterM1Xs,
781 M1N1ThreadClusterN1Xs,
782 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
783 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
784 ABlockTransferThreadClusterArrangeOrder,
785 ABlockTransferSrcAccessOrder,
786 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
787 ABlockTransferSrcVectorTensorContiguousDimOrder,
788 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
789 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
790 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
791 BBlockTransferThreadClusterArrangeOrder,
792 BBlockTransferSrcAccessOrder,
793 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
794 BBlockTransferSrcVectorTensorContiguousDimOrder,
795 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
796 CThreadTransferSrcDstAccessOrder,
797 CThreadTransferSrcDstVectorDim,
798 CThreadTransferDstScalarPerVector>;
813 WeiDataType* p_wei_grid,
814 const OutDataType* p_out_grid,
815 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
816 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
817 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
818 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
819 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
820 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
821 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
822 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
823 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
824 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
825 InElementwiseOperation in_element_op,
826 WeiElementwiseOperation wei_element_op,
827 OutElementwiseOperation out_element_op,
859 conv_filter_dilations,
928 std::cout <<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
934 std::cout <<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
953 throw std::runtime_error(
954 "wrong! GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting");
960 auto launch_kernel = [&](
auto has_main_k_block_loop,
961 auto has_double_tail_k_block_loop) {
962 constexpr bool has_main_loop = has_main_k_block_loop.value;
963 constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
973 ComputePtrOffsetOfStridedBatch<>,
995 const bool has_double_tail_k_block_loop =
998 if(has_main_k_block_loop && has_double_tail_k_block_loop)
1003 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
1006 integral_constant<bool, false>{});
1008 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
1011 integral_constant<bool, true>{});
1016 integral_constant<bool, false>{});
1023 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
1040 if constexpr(!((NDimSpatial == 1 &&
1043 (NDimSpatial == 2 &&
1046 (NDimSpatial == 3 &&
1053 if constexpr(ConvBackwardWeightSpecialization ==
1057 for(
int i = 0; i < NDimSpatial; i++)
1070 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
1071 if(srcVectorLengths[
I2] != 1 || srcVectorLengths[
I3] != 1)
1075 if(K1 % srcVectorLengths[
I4] != 0 || K0PerBlock % srcVectorLengths[
I1] != 0)
1082 if(K % (srcVectorLengths[
I1] * srcVectorLengths[
I4]) != 0)
1090 auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
1091 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
1092 if(srcVectorLengths[
I1] != 1 || srcVectorLengths[
I4] != 1)
1096 if(srcLoadLenghts[
I2] % srcVectorLengths[
I2] != 0 ||
1097 srcLoadLenghts[
I3] % srcVectorLengths[
I3] != 0)
1104 if(C % (srcVectorLengths[
I2] * srcVectorLengths[
I3]) != 0)
1111 if(!(arg.
Conv_C_ % CThreadTransferDstScalarPerVector == 0))
1113 std::cout <<
"Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
1114 << arg.
Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
1130 WeiDataType* p_wei_grid,
1131 const OutDataType* p_out_grid,
1132 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1133 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1134 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1135 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1136 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1137 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1138 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1139 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1140 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1141 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1142 InElementwiseOperation in_element_op,
1143 WeiElementwiseOperation wei_element_op,
1144 OutElementwiseOperation out_element_op,
1150 a_g_n_c_wis_lengths,
1151 a_g_n_c_wis_strides,
1154 e_g_n_k_wos_lengths,
1155 e_g_n_k_wos_strides,
1156 conv_filter_strides,
1157 conv_filter_dilations,
1168 std::unique_ptr<BaseArgument>
1171 const void* p_out_grid,
1172 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1173 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1174 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1175 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1176 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1177 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1178 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1179 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1180 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1181 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1182 InElementwiseOperation in_element_op,
1183 WeiElementwiseOperation wei_element_op,
1184 OutElementwiseOperation out_element_op,
1187 return std::make_unique<Argument>(
static_cast<const InDataType*
>(p_in_grid),
1188 static_cast<WeiDataType*
>(p_wei_grid),
1189 static_cast<const OutDataType*
>(p_out_grid),
1190 a_g_n_c_wis_lengths,
1191 a_g_n_c_wis_strides,
1194 e_g_n_k_wos_lengths,
1195 e_g_n_k_wos_strides,
1196 conv_filter_strides,
1197 conv_filter_dilations,
1208 return std::make_unique<Invoker>(
Invoker{});
1213 auto str = std::stringstream();
1216 str <<
"DeviceGroupedConvBwdWeight_Dl"
1218 << BlockSize <<
", "
1219 << MPerBlock <<
", "
1220 << NPerBlock <<
", "
1221 << K0PerBlock <<
", "
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
constexpr bool is_NWGC_GKXC_NWGK()
Definition device_grouped_conv_utils.hpp:15
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
constexpr bool is_GNWC_GKXC_GNWK()
Definition device_grouped_conv_utils.hpp:23
__global__ void kernel_batched_gemm_dlops_bwd_weight(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const index_t batch_count, const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition device_grouped_conv_bwd_weight_dl.hpp:40
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition device_grouped_conv_utils.hpp:48
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_v1r3.hpp:612
ck::GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeCBlockClusterAdaptor __host__ static __device__ constexpr auto MakeCBlockClusterAdaptor(const CGridDesc_M_N &c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
Definition gridwise_gemm_dl_v1r3.hpp:774
ck::GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeAGridDescriptor_B_K0_M0_M1_K1 __host__ static __device__ constexpr auto MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1 &a_grid_desc_b_k0_m_k1)
Definition gridwise_gemm_dl_v1r3.hpp:698
ck::GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:683
ck::GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:742
ck::GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_B_K0_M_K1 &a_grid_desc_b_k0_m_k1, const BGridDesc_B_K0_N_K1 &b_grid_desc_b_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:648
ck::GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasDoubleTailKBlockLoop __host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:690
ck::GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeBGridDescriptor_B_K0_N0_N1_K1 __host__ static __device__ constexpr auto MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1 &b_grid_desc_b_k0_n_k1)
Definition gridwise_gemm_dl_v1r3.hpp:720
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_grouped_conv_bwd_weight_dl.hpp:811
AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_
Definition device_grouped_conv_bwd_weight_dl.hpp:893
InElementwiseOperation c_element_op_
Definition device_grouped_conv_bwd_weight_dl.hpp:906
std::array< ck::index_t, NDimSpatial+3 > filter_lengths_
Definition device_grouped_conv_bwd_weight_dl.hpp:913
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_dl.hpp:909
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_dl.hpp:901
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_dl.hpp:885
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_dl.hpp:911
AGridDesc_B_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_grouped_conv_bwd_weight_dl.hpp:889
index_t k_batch_
Definition device_grouped_conv_bwd_weight_dl.hpp:918
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_dl.hpp:910
BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_
Definition device_grouped_conv_bwd_weight_dl.hpp:894
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_dl.hpp:812
WeiElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_dl.hpp:905
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_dl.hpp:904
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_dl.hpp:891
const std::array< ck::index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_dl.hpp:917
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_grouped_conv_bwd_weight_dl.hpp:895
const std::array< ck::index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_dl.hpp:916
const std::array< ck::index_t, NDimSpatial > & conv_filter_dilations_
Definition device_grouped_conv_bwd_weight_dl.hpp:915
const std::array< ck::index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_dl.hpp:914
CDataType * p_c_grid_
Definition device_grouped_conv_bwd_weight_dl.hpp:887
Block2CTileMap block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_dl.hpp:898
BGridDesc_B_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_grouped_conv_bwd_weight_dl.hpp:890
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_dl.hpp:886
Definition device_grouped_conv_bwd_weight_dl.hpp:923
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_dl.hpp:924
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_dl.hpp:944
void ShowInfo(const Argument &arg)
Definition device_grouped_conv_bwd_weight_dl.hpp:926
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_dl.hpp:1020
Definition device_grouped_conv_bwd_weight_dl.hpp:139
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_B_K0_M_K1
Definition device_grouped_conv_bwd_weight_dl.hpp:760
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_dl.hpp:1033
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_dl.hpp:1129
GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_grouped_conv_bwd_weight_dl.hpp:764
static constexpr auto K1Number
Definition device_grouped_conv_bwd_weight_dl.hpp:162
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, const ck::index_t batch_k)
Definition device_grouped_conv_bwd_weight_dl.hpp:180
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_dl.hpp:762
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_dl.hpp:153
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)) Block2CTileMap
Definition device_grouped_conv_bwd_weight_dl.hpp:807
WeiDataType CDataType
Definition device_grouped_conv_bwd_weight_dl.hpp:144
static constexpr auto I5
Definition device_grouped_conv_bwd_weight_dl.hpp:158
static constexpr auto BBlockLdsN1PerBlock
Definition device_grouped_conv_bwd_weight_dl.hpp:175
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_grouped_conv_bwd_weight_dl.hpp:805
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_dl.hpp:147
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_dl.hpp:1211
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_dl.hpp:729
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_dl.hpp:154
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_dl.hpp:1027
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_B_K0_N_K1
Definition device_grouped_conv_bwd_weight_dl.hpp:761
static constexpr auto I3
Definition device_grouped_conv_bwd_weight_dl.hpp:156
WeiElementwiseOperation CElementwiseOperation
Definition device_grouped_conv_bwd_weight_dl.hpp:148
static constexpr auto GemmK1Number
Definition device_grouped_conv_bwd_weight_dl.hpp:163
static constexpr auto ElePerBank
Definition device_grouped_conv_bwd_weight_dl.hpp:167
DeviceGroupedConvBwdWeight_Dl DeviceOp
Definition device_grouped_conv_bwd_weight_dl.hpp:140
static constexpr auto BBlockLdsN0PerBlock
Definition device_grouped_conv_bwd_weight_dl.hpp:176
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_dl.hpp:155
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_dl.hpp:151
static constexpr auto ABlockLdsM1PerBlock
Definition device_grouped_conv_bwd_weight_dl.hpp:170
InDataType BDataType
Definition device_grouped_conv_bwd_weight_dl.hpp:143
static constexpr auto BBlockLdsN1Padding
Definition device_grouped_conv_bwd_weight_dl.hpp:177
static constexpr auto ABlockLdsM1Padding
Definition device_grouped_conv_bwd_weight_dl.hpp:172
decltype(GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{})) AGridDesc_B_K0_M0_M1_K1
Definition device_grouped_conv_bwd_weight_dl.hpp:801
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_dl.hpp:1206
static constexpr auto BankLength
Definition device_grouped_conv_bwd_weight_dl.hpp:166
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_dl.hpp:1166
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k) override
Definition device_grouped_conv_bwd_weight_dl.hpp:1169
static constexpr auto spatial_offset
Definition device_grouped_conv_bwd_weight_dl.hpp:160
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_dl.hpp:146
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_dl.hpp:142
static constexpr auto ABlockLdsM0PerBlock
Definition device_grouped_conv_bwd_weight_dl.hpp:171
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_dl.hpp:758
static constexpr auto I4
Definition device_grouped_conv_bwd_weight_dl.hpp:157
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_dl.hpp:1123
decltype(GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{})) BGridDesc_B_K0_N0_N1_K1
Definition device_grouped_conv_bwd_weight_dl.hpp:803
Definition device_grouped_conv_bwd_weight.hpp:29