device_grouped_conv_bwd_weight_dl.hpp Source File

device_grouped_conv_bwd_weight_dl.hpp Source File

Composable Kernel: device_grouped_conv_bwd_weight_dl.hpp Source File
device_grouped_conv_bwd_weight_dl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <numeric>
8#include <sstream>
9
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
// Grouped (batched) GEMM kernel for backward-weight convolution; Invoker::Run
// instantiates it as kernel_batched_gemm_dlops_bwd_weight.
// The 1-D launch grid is split evenly across `batch_count` groups: each block
// derives its group index from its block id, offsets the A/B/C base pointers by
// that group's offsets from compute_ptr_offset_of_batch, and calls
// GridwiseGemm::Run with a statically sized __shared__ scratch buffer.
// HasMainKBlockLoop / HasDoubleTailKBlockLoop select the compile-time variant of
// GridwiseGemm::Run for the shape of the K loop.
// NOTE(review): this is a documentation extract; the launch-bounds attribute, the
// kernel-name line and the tail of the Run(...) call are not shown here.
26template <typename GridwiseGemm,
27 typename FloatAB,
28 typename FloatC,
29 typename AGridDesc_B_K0_M0_M1_K1,
30 typename BGridDesc_B_K0_N0_N1_K1,
31 typename CGridDesc_M0_M10_M11_N0_N10_N11,
32 typename Block2CTileMap,
33 typename ComputePtrOffsetOfBatch,
34 bool HasMainKBlockLoop,
35 bool HasDoubleTailKBlockLoop>
36__global__ void
37#if CK_USE_LAUNCH_BOUNDS
39#endif
41 const FloatAB* __restrict__ p_a_grid,
42 const FloatAB* __restrict__ p_b_grid,
43 FloatC* __restrict__ p_c_grid,
44 const index_t batch_count,
45 const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
46 const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
47 const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
48 const Block2CTileMap block_2_ctile_map,
49 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
50{
// The real body is only compiled for the gfx targets listed below; on any other
// target the #else branch just marks the parameters as used.
51#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx90a__) || defined(__gfx908__) || \
52 defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
// Assumes the grid size is an exact multiple of batch_count (the host launches
// CalculateGridSize(...) * Conv_G_ blocks and passes Conv_G_ as batch_count), so
// each group owns a contiguous run of num_blocks_per_batch blocks.
// readfirstlane broadcasts lane 0's value so these indices are wave-uniform.
53 const index_t num_blocks_per_batch =
54 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
55 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
56
// Element offsets of this group's A, B and C data, taken from the per-group
// strides stored in compute_ptr_offset_of_batch.
57 const long_index_t a_batch_offset =
58 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
59 const long_index_t b_batch_offset =
60 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
61 const long_index_t c_batch_offset =
62 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
63
// Block-shared (LDS) scratch, sized at compile time by GridwiseGemm.
64 __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
65
66 GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
67 p_a_grid + a_batch_offset,
68 p_b_grid + b_batch_offset,
69 p_c_grid + c_batch_offset,
70 p_shared,
71 a_grid_desc_kbatch_k0_m0_m1_k1,
72 b_grid_desc_kbatch_k0_n0_n1_k1,
73 c_grid_desc_m0_m10_m11_n0_n10_n11,
74 block_2_ctile_map,
77#else
// Unsupported target: the kernel does no work. The `ignore = x` assignments are
// presumably there to silence unused-parameter warnings.
78 ignore = p_a_grid;
79 ignore = p_b_grid;
80 ignore = p_c_grid;
81 ignore = batch_count;
82 ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
83 ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
84 ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
85 ignore = block_2_ctile_map;
86 ignore = compute_ptr_offset_of_batch;
87#endif
88}
89
// Device operation for grouped backward-weight convolution with 1, 2 or 3 spatial
// dimensions (see the NDim == 1/2/3 descriptor overloads), implemented on the DL
// gridwise GEMM. The problem is mapped to a GEMM with A = output tensor,
// B = input tensor and C = weight tensor. Many of the block/thread tuning
// parameters below are forwarded unchanged to the GridwiseGemm instantiation.
// NOTE(review): the struct declaration line (name and base class) is not shown in
// this listing extract.
90template <ck::index_t NDimSpatial,
91 typename InLayout,
92 typename WeiLayout,
93 typename OutLayout,
94 typename InDataType,
95 typename WeiDataType,
96 typename OutDataType,
97 typename AccDataType,
98 typename InElementwiseOperation,
99 typename WeiElementwiseOperation,
100 typename OutElementwiseOperation,
101 ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
102 ck::index_t BlockSize,
103 ck::index_t MPerBlock,
104 ck::index_t NPerBlock,
105 ck::index_t K0PerBlock,
106 ck::index_t K1,
107 index_t M1PerThread,
108 index_t N1PerThread,
109 index_t KPerThread,
110 typename M1N1ThreadClusterM1Xs,
111 typename M1N1ThreadClusterN1Xs,
112 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
113 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
114 typename ABlockTransferThreadClusterArrangeOrder,
115 typename ABlockTransferSrcAccessOrder,
116 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
117 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
118 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
119 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
120 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
121 typename BBlockTransferThreadClusterArrangeOrder,
122 typename BBlockTransferSrcAccessOrder,
123 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
124 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
125 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
126 typename CThreadTransferSrcDstAccessOrder,
127 index_t CThreadTransferSrcDstVectorDim,
128 index_t CThreadTransferDstScalarPerVector>
130 InLayout,
131 WeiLayout,
132 OutLayout,
133 InDataType,
134 WeiDataType,
135 OutDataType,
136 InElementwiseOperation,
137 WeiElementwiseOperation,
138 OutElementwiseOperation>
139{
141
// GEMM operand roles: A = output tensor, B = input tensor, C = weight tensor.
142 using ADataType = OutDataType;
143 using BDataType = InDataType;
144 using CDataType = WeiDataType;
145
146 using AElementwiseOperation = OutElementwiseOperation;
147 using BElementwiseOperation = InElementwiseOperation;
148 using CElementwiseOperation = WeiElementwiseOperation;
149
150 // TODO make A/B datatype different
// A single element type is used for both A and B, so this op effectively
// assumes InDataType and OutDataType are the same type.
151 using ABDataType = InDataType;
152
153 static constexpr auto I0 = Number<0>{};
154 static constexpr auto I1 = Number<1>{};
155 static constexpr auto I2 = Number<2>{};
156 static constexpr auto I3 = Number<3>{};
157 static constexpr auto I4 = Number<4>{};
158 static constexpr auto I5 = Number<5>{};
159
// Index of the first spatial entry in the {G, N, C, spatial...} (or
// {G, K, C, spatial...}) length/stride arrays.
160 static constexpr auto spatial_offset = I3;
161
162 static constexpr auto K1Number = Number<K1>{};
163 static constexpr auto GemmK1Number = K1Number;
164
165 // Bytes spanned by 32 LDS banks of 4 bytes each: 32 * 4 = 128
166 static constexpr auto BankLength = 128;
// Number of ADataType elements that fit in BankLength bytes.
// NOTE(review): also used for the B-side constants below, which assumes A and B
// elements have the same size.
167 static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
168
169 // M1 & M0: split MPerBlock so that M1 * K1 is (at most) one bank row of elements.
// The padding value is presumably consumed by LDS layout code not shown here.
170 static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
171 static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
172 static constexpr auto ABlockLdsM1Padding = 4;
173
174 // N1 & N0: same split for NPerBlock on the B side.
175 static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
176 static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
177 static constexpr auto BBlockLdsN1Padding = 4;
178
// Builds the GEMM-view grid descriptors for a 1-D backward-weight convolution
// (this overload is selected by enable_if when NDim == 1). Mapping:
//   A = output tensor, (GemmK x GemmM) with GemmM = K
//   B = input tensor,  (GemmK x GemmN)
//   C = weight tensor, (GemmM x GemmN) with GemmN = X * C
// where GemmK = N * Wo. The A/B views are handed to a padding helper (on a line not
// shown in this listing) together with the multiples (K1 * K0PerBlock * batch_k,
// MPerBlock / NPerBlock). The *_pad names suggest each dimension is rounded up to
// those multiples. GemmK is then unmerged into (GemmKBatch, GemmK0, GemmK1).
// Returns make_tuple(A kbatch/k0/m/k1 descriptor, B kbatch/k0/n/k1 descriptor,
//                    padded C m/n descriptor).
179 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
181 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
182 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
183 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
184 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
185 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
186 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
187 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
188 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
189 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
190 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
191 const ck::index_t batch_k)
192 {
193 using namespace ck;
194
195 const index_t N = a_g_n_c_wis_lengths[I1];
196 const index_t K = b_g_k_c_xs_lengths[I1];
197 const index_t C = a_g_n_c_wis_lengths[I2];
198 const index_t Wi = a_g_n_c_wis_lengths[spatial_offset];
199 const index_t Wo = e_g_n_k_wos_lengths[spatial_offset];
200 const index_t X = b_g_k_c_xs_lengths[spatial_offset];
201 const index_t InLeftPadW = input_left_pads[I0];
202 const index_t InRightPadW = input_right_pads[I0];
203 const index_t ConvStrideW = conv_filter_strides[I0];
204 const index_t ConvDilationW = conv_filter_dilations[I0];
205
206 const auto InNStride = a_g_n_c_wis_strides[I1];
207 const auto InCStride = a_g_n_c_wis_strides[I2];
208 const auto InWStride = a_g_n_c_wis_strides[spatial_offset];
209 const auto WeiKStride = b_g_k_c_xs_strides[I1];
210 const auto WeiCStride = b_g_k_c_xs_strides[I2];
211 const auto OutKStride = e_g_n_k_wos_strides[I2];
212 const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];
213
// GemmK0 = ceil(GemmKTotal / (K1 * K0PerBlock * GemmKBatch)) * K0PerBlock, i.e. the
// per-k-batch K0 length rounded up to whole K0PerBlock tiles.
214 const index_t GemmKTotal = N * Wo;
215 const index_t GemmKBatch = batch_k;
216 const index_t GemmK0 =
217 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
218 K0PerBlock;
219
// Specialized path. The enumerator compared against is on a line not shown here;
// presumably it is the 1x1 / stride-1 / pad-0 filter case that IsSupportedArgument
// checks. Input and output are addressed directly as 2-D (N * W, channel) matrices,
// with no pad/embed transforms.
// NOTE(review): both naive descriptors merge N with W using only the W stride, so
// they assume N is laid out densely after W (InNStride is unused in this branch).
// The B view's K length is N * Wi while GemmKTotal is N * Wo; the two agree only
// when Wi == Wo.
220 if constexpr(ConvBackwardWeightSpecialization ==
222 {
223 // A: output tensor
224 const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
225 make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
226
227 const auto out_gemmkpad_gemmmpad_grid_desc =
229 out_gemmktotal_gemmm_grid_desc,
230 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
232
233 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
234 out_gemmkpad_gemmmpad_grid_desc,
236 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
237 make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
240
241 // B: input tensor
242 const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
243 make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));
244
245 const auto in_gemmkpad_gemmnpad_grid_desc =
247 in_gemmktotal_gemmn_grid_desc,
248 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
250
251 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
252 in_gemmkpad_gemmnpad_grid_desc,
254 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
255 make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
258
259 // C: weight tensor, viewed as (K, X * C); the merged X * C dimension uses
// WeiCStride, which assumes X and C are densely laid out with C innermost.
260 const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
261 make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
262
263 const auto wei_gemmmpad_gemmnpad_grid_desc =
265 make_tuple(MPerBlock, NPerBlock),
267
268 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
269 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
270 wei_gemmmpad_gemmnpad_grid_desc);
271 }
272 else
273 {
// General path: the input is padded along W, and W is then expanded into
// (filter tap X, output position Wo) via an embed transform with the dilation
// and stride, so the input can be read as a (GemmK, GemmN) matrix.
274 const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
275 make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
276 const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
277 make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));
278
279 // A: output tensor
280 const auto out_gemmkpad_gemmmpad_grid_desc =
282 out_gemmktotal_gemmm_grid_desc,
283 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
285
286 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
287 out_gemmkpad_gemmmpad_grid_desc,
289 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
290 make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
293
294 // B: input tensor
295 const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
296 in_n_wi_c_grid_desc,
298 make_pad_transform(Wi, InLeftPadW, InRightPadW),
302
303 const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
304 in_n_wip_c_grid_desc,
307 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
311
312 const auto in_gemmktotal_gemmn_grid_desc =
313 transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
318
319 const auto in_gemmkpad_gemmnpad_grid_desc =
321 in_gemmktotal_gemmn_grid_desc,
322 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
324
325 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
326 in_gemmkpad_gemmnpad_grid_desc,
328 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
329 make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
332
333 // C: weight tensor
334 const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
335 make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
336
337 const auto wei_gemmmpad_gemmnpad_grid_desc =
339 make_tuple(MPerBlock, NPerBlock),
341
342 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
343 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
344 wei_gemmmpad_gemmnpad_grid_desc);
345 }
346
347 } // function end
// Builds the GEMM-view grid descriptors for a 2-D backward-weight convolution
// (this overload is selected by enable_if when NDim == 2). Mapping:
//   A = output tensor, (GemmK x GemmM) with GemmM = K
//   B = input tensor,  (GemmK x GemmN)
//   C = weight tensor, (GemmM x GemmN) with GemmN = Y * X * C
// where GemmK = N * Ho * Wo. The A/B views are handed to a padding helper (on a line
// not shown in this listing) together with the multiples (K1 * K0PerBlock * batch_k,
// MPerBlock / NPerBlock), and GemmK is then unmerged into (GemmKBatch, GemmK0, GemmK1).
// Returns make_tuple(A kbatch/k0/m/k1 descriptor, B kbatch/k0/n/k1 descriptor,
//                    padded C m/n descriptor).
348 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
350 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
351 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
352 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
353 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
354 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
355 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
356 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
357 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
358 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
359 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
360 const ck::index_t batch_k)
361 {
362 using namespace ck;
363
364 const index_t N = a_g_n_c_wis_lengths[I1];
365 const index_t K = b_g_k_c_xs_lengths[I1];
366 const index_t C = a_g_n_c_wis_lengths[I2];
367 const index_t Hi = a_g_n_c_wis_lengths[spatial_offset];
368 const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1];
369 const index_t Ho = e_g_n_k_wos_lengths[spatial_offset];
370 const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1];
371 const index_t Y = b_g_k_c_xs_lengths[spatial_offset];
372 const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1];
373
374 const index_t InLeftPadH = input_left_pads[I0];
375 const index_t InLeftPadW = input_left_pads[I1];
376 const index_t InRightPadH = input_right_pads[I0];
377 const index_t InRightPadW = input_right_pads[I1];
378 const index_t ConvStrideH = conv_filter_strides[I0];
379 const index_t ConvStrideW = conv_filter_strides[I1];
380 const index_t ConvDilationH = conv_filter_dilations[I0];
381 const index_t ConvDilationW = conv_filter_dilations[I1];
382
383 const auto InNStride = a_g_n_c_wis_strides[I1];
384 const auto InCStride = a_g_n_c_wis_strides[I2];
385 const auto InHStride = a_g_n_c_wis_strides[spatial_offset];
386 const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1];
387 const auto WeiKStride = b_g_k_c_xs_strides[I1];
388 const auto WeiCStride = b_g_k_c_xs_strides[I2];
389 const auto OutKStride = e_g_n_k_wos_strides[I2];
390 const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];
391
// GemmK0 = ceil(GemmKTotal / (K1 * K0PerBlock * GemmKBatch)) * K0PerBlock, i.e. the
// per-k-batch K0 length rounded up to whole K0PerBlock tiles.
392 const index_t GemmKTotal = N * Ho * Wo;
393 const index_t GemmKBatch = batch_k;
394 const index_t GemmK0 =
395 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
396 K0PerBlock;
397
// Specialized path. The enumerator is on a line not shown; presumably it is the
// 1x1 / stride-1 / pad-0 case checked in IsSupportedArgument.
// NOTE(review): N, H and W are merged into one dimension strided only by the W
// stride, which assumes a dense N-H-W layout (InNStride and InHStride are unused
// here). The B K-length N * Hi * Wi matches GemmKTotal = N * Ho * Wo only when
// Hi == Ho and Wi == Wo.
398 if constexpr(ConvBackwardWeightSpecialization ==
400 {
401 // A: output tensor
402 const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
403 make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
404
405 const auto out_gemmkpad_gemmmpad_grid_desc =
407 out_gemmktotal_gemmm_grid_desc,
408 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
410
411 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
412 out_gemmkpad_gemmmpad_grid_desc,
414 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
415 make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
418
419 // B: input tensor
420 const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
421 make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));
422
423 const auto in_gemmkpad_gemmnpad_grid_desc =
425 in_gemmktotal_gemmn_grid_desc,
426 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
428
429 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
430 in_gemmkpad_gemmnpad_grid_desc,
432 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
433 make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
436
437 // C: weight tensor, viewed as (K, Y * X * C); the merged dimension uses
// WeiCStride, which assumes Y, X and C are densely laid out with C innermost.
438 const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
439 make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
440
441 const auto wei_gemmmpad_gemmnpad_grid_desc =
443 make_tuple(MPerBlock, NPerBlock),
445
446 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
447 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
448 wei_gemmmpad_gemmnpad_grid_desc);
449 }
450 else
451 {
// General path: the input is padded along H and W, and each of those dimensions
// is expanded into (filter tap, output position) via embed transforms with the
// dilation and stride. (N, Ho, Wo) are then merged into the GEMM K side.
452 const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
453 make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
454 const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
455 make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride));
456
457 // A: output tensor
458 const auto out_gemmkpad_gemmmpad_grid_desc =
460 out_gemmktotal_gemmm_grid_desc,
461 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
463
464 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
465 out_gemmkpad_gemmmpad_grid_desc,
467 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
468 make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
471
472 // B: input tensor
473 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
474 in_n_hi_wi_c_grid_desc,
476 make_pad_transform(Hi, InLeftPadH, InRightPadH),
477 make_pad_transform(Wi, InLeftPadW, InRightPadW),
481
482 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
483 in_n_hip_wip_c_grid_desc,
486 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
487 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
491
492 const auto in_gemmktotal_gemmn_grid_desc =
493 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
495 make_merge_transform(make_tuple(N, Ho, Wo))),
498
499 const auto in_gemmkpad_gemmnpad_grid_desc =
501 in_gemmktotal_gemmn_grid_desc,
502 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
504
505 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
506 in_gemmkpad_gemmnpad_grid_desc,
508 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
509 make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
512
513 // C: weight tensor
514 const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
515 make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
516
517 const auto wei_gemmmpad_gemmnpad_grid_desc =
519 make_tuple(MPerBlock, NPerBlock),
521
522 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
523 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
524 wei_gemmmpad_gemmnpad_grid_desc);
525 }
526
527 } // function end
528
// Builds the GEMM-view grid descriptors for a 3-D backward-weight convolution
// (this overload is selected by enable_if when NDim == 3). Mapping:
//   A = output tensor, (GemmK x GemmM) with GemmM = K
//   B = input tensor,  (GemmK x GemmN)
//   C = weight tensor, (GemmM x GemmN) with GemmN = Z * Y * X * C
// where GemmK = N * Do * Ho * Wo. The A/B views are handed to a padding helper (on a
// line not shown in this listing) together with the multiples
// (K1 * K0PerBlock * batch_k, MPerBlock / NPerBlock), and GemmK is then unmerged
// into (GemmKBatch, GemmK0, GemmK1).
// Returns make_tuple(A kbatch/k0/m/k1 descriptor, B kbatch/k0/n/k1 descriptor,
//                    padded C m/n descriptor).
529 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
531 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
532 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
533 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
534 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
535 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
536 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
537 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
538 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
539 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
540 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
541 const ck::index_t batch_k)
542 {
543 using namespace ck;
544
545 const index_t N = a_g_n_c_wis_lengths[I1];
546 const index_t K = b_g_k_c_xs_lengths[I1];
547 const index_t C = a_g_n_c_wis_lengths[I2];
548 const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0];
549 const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1];
550 const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2];
551 const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0];
552 const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1];
553 const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2];
554 const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0];
555 const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1];
556 const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2];
557
558 const index_t InLeftPadD = input_left_pads[I0];
559 const index_t InLeftPadH = input_left_pads[I1];
560 const index_t InLeftPadW = input_left_pads[I2];
561 const index_t InRightPadD = input_right_pads[I0];
562 const index_t InRightPadH = input_right_pads[I1];
563 const index_t InRightPadW = input_right_pads[I2];
564 const index_t ConvStrideD = conv_filter_strides[I0];
565 const index_t ConvStrideH = conv_filter_strides[I1];
566 const index_t ConvStrideW = conv_filter_strides[I2];
567 const index_t ConvDilationD = conv_filter_dilations[I0];
568 const index_t ConvDilationH = conv_filter_dilations[I1];
569 const index_t ConvDilationW = conv_filter_dilations[I2];
570
571 const auto InNStride = a_g_n_c_wis_strides[I1];
572 const auto InCStride = a_g_n_c_wis_strides[I2];
573 const auto InDStride = a_g_n_c_wis_strides[spatial_offset];
574 const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1];
575 const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2];
576 const auto WeiKStride = b_g_k_c_xs_strides[I1];
577 const auto WeiCStride = b_g_k_c_xs_strides[I2];
578 const auto OutKStride = e_g_n_k_wos_strides[I2];
579 const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];
580
// GemmK0 = ceil(GemmKTotal / (K1 * K0PerBlock * GemmKBatch)) * K0PerBlock, i.e. the
// per-k-batch K0 length rounded up to whole K0PerBlock tiles.
581 const index_t GemmKTotal = N * Do * Ho * Wo;
582 const index_t GemmKBatch = batch_k;
583 const index_t GemmK0 =
584 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
585 K0PerBlock;
586
// Specialized path. The enumerator is on a line not shown; presumably it is the
// 1x1 / stride-1 / pad-0 case checked in IsSupportedArgument.
// NOTE(review): N, D, H and W are merged into one dimension strided only by the W
// stride, which assumes a dense N-D-H-W layout (InNStride, InDStride and InHStride
// are unused here). The B K-length N * Di * Hi * Wi matches GemmKTotal only when
// the input and output spatial sizes are equal.
587 if constexpr(ConvBackwardWeightSpecialization ==
589 {
590 // A: output tensor
591 const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
592 make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
593
594 const auto out_gemmkpad_gemmmpad_grid_desc =
596 out_gemmktotal_gemmm_grid_desc,
597 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
599
600 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
601 out_gemmkpad_gemmmpad_grid_desc,
603 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
604 make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
607
608 // B: input tensor
609 const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
610 make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));
611
612 const auto in_gemmkpad_gemmnpad_grid_desc =
614 in_gemmktotal_gemmn_grid_desc,
615 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
617
618 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
619 in_gemmkpad_gemmnpad_grid_desc,
621 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
622 make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
625
626 // C: weight tensor, viewed as (K, Z * Y * X * C); the merged dimension uses
// WeiCStride, which assumes Z, Y, X and C are densely laid out with C innermost.
627 const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
628 make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
629
630 const auto wei_gemmmpad_gemmnpad_grid_desc =
632 make_tuple(MPerBlock, NPerBlock),
634
635 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
636 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
637 wei_gemmmpad_gemmnpad_grid_desc);
638 }
639 else
640 {
// General path: the input is padded along D, H and W, and each of those
// dimensions is expanded into (filter tap, output position) via embed
// transforms with the dilation and stride. (N, Do, Ho, Wo) are then merged into
// the GEMM K side.
641 const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
642 make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
643 const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
644 make_tuple(N, Di, Hi, Wi, C),
645 make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));
646
647 // A: output tensor
648 const auto out_gemmkpad_gemmmpad_grid_desc =
650 out_gemmktotal_gemmm_grid_desc,
651 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
653
654 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
655 out_gemmkpad_gemmmpad_grid_desc,
657 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
658 make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
661
662 // B: input tensor
663 const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
664 in_n_di_hi_wi_c_grid_desc,
666 make_pad_transform(Di, InLeftPadD, InRightPadD),
667 make_pad_transform(Hi, InLeftPadH, InRightPadH),
668 make_pad_transform(Wi, InLeftPadW, InRightPadW),
674
675 const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
676 in_n_dip_hip_wip_c_grid_desc,
679 make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
680 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
681 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
689 Sequence<7>{}));
690
691 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
692 in_n_z_do_y_ho_x_wo_c_grid_desc,
694 make_merge_transform(make_tuple(N, Do, Ho, Wo))),
697
698 const auto in_gemmkpad_gemmnpad_grid_desc =
700 in_gemmktotal_gemmn_grid_desc,
701 make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
703
704 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
705 in_gemmkpad_gemmnpad_grid_desc,
707 make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
708 make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
711
712 // C: weight tensor
713 const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
714 make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
715
716 const auto wei_gemmmpad_gemmnpad_grid_desc =
718 make_tuple(MPerBlock, NPerBlock),
720
721 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
722 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
723 wei_gemmmpad_gemmnpad_grid_desc);
724 }
725
726 } // function end
727
// 1-D overload: builds the A/B/C descriptors from placeholder arguments. Each
// braced `{1}` fills the first array element with 1 and value-initializes the
// rest to 0, so this is not a real problem size. It is presumably used only to
// name the descriptor types (e.g. via decltype). The forwarded call is on a line
// not shown in this listing.
728 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
729 static auto GetABCGridDesc()
730 {
732 {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
733 }
734
// 2-D overload: builds the A/B/C descriptors from placeholder arguments (braced
// lists fill the leading array elements with 1 and value-initialize the rest to 0).
// It is presumably used only to name the descriptor types. The forwarded call is
// on a line not shown in this listing.
735 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
736 static auto GetABCGridDesc()
737 {
739 {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
740 }
741
// 3-D overload: builds the A/B/C descriptors from placeholder arguments (braced
// lists fill the leading array elements with 1 and value-initialize the rest to 0).
// It is presumably used only to name the descriptor types. The forwarded call, and
// possibly the first argument, are on a line not shown in this listing.
742 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
743 static auto GetABCGridDesc()
744 {
746 {1, 1, 1},
747 {1, 1, 1},
748 {1, 1, 1},
749 {1, 1, 1},
750 {1, 1, 1},
751 {1, 1, 1},
752 {1, 1, 1},
753 {1, 1, 1},
754 {1, 1, 1},
755 1);
756 }
757
759
// Tail of the gridwise-GEMM instantiation; its opening lines are not shown in this
// listing. Invoker::Run refers to this type as GridwiseGemm, and its error message
// names GridwiseGemmDl_bkm_bkn_mn_v1r3. ADataType is passed as the GEMM input
// element type and CDataType as the output type; the tuning parameters of this
// device op are forwarded unchanged.
763
766 ADataType,
767 AccDataType,
768 CDataType,
773 MPerBlock,
774 NPerBlock,
775 K0PerBlock,
776 K1,
777 M1PerThread,
778 N1PerThread,
779 KPerThread,
780 M1N1ThreadClusterM1Xs,
781 M1N1ThreadClusterN1Xs,
782 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
783 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
784 ABlockTransferThreadClusterArrangeOrder,
785 ABlockTransferSrcAccessOrder,
786 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
787 ABlockTransferSrcVectorTensorContiguousDimOrder,
788 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
789 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
790 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
791 BBlockTransferThreadClusterArrangeOrder,
792 BBlockTransferSrcAccessOrder,
793 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
794 BBlockTransferSrcVectorTensorContiguousDimOrder,
795 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
796 CThreadTransferSrcDstAccessOrder,
797 CThreadTransferSrcDstVectorDim,
798 CThreadTransferDstScalarPerVector>;
799
800 // Argument
809
// Runtime argument bundle for one grouped backward-weight call. It maps the
// convolution onto GEMM operands (A = output, B = input, C = weight), builds their
// grid descriptors, records per-group (G) strides that the kernel uses to offset
// each group's base pointers, and keeps the convolution parameters read later by
// IsSupportedArgument().
810 struct Argument : public BaseArgument
811 {
812 Argument(const InDataType* p_in_grid,
813 WeiDataType* p_wei_grid,
814 const OutDataType* p_out_grid,
815 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
816 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
817 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
818 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
819 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
820 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
821 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
822 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
823 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
824 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
825 InElementwiseOperation in_element_op,
826 WeiElementwiseOperation wei_element_op,
827 OutElementwiseOperation out_element_op,
828 ck::index_t split_k)
// GEMM operand mapping: A <- output, B <- input, C <- weight.
829 : p_a_grid_{p_out_grid},
830 p_b_grid_{p_in_grid},
831 p_c_grid_{p_wei_grid},
// NOTE(review): the class-level aliases map B to the input op and C to the weight op
// (BElementwiseOperation = InElementwiseOperation), but here b_element_op_ holds
// wei_element_op and c_element_op_ holds in_element_op. The member types declared
// below match this initialization, so it compiles; confirm which op each reader of
// b_/c_element_op_ expects.
837 a_element_op_{out_element_op},
838 b_element_op_{wei_element_op},
839 c_element_op_{in_element_op},
840 Conv_G_{a_g_n_c_wis_lengths[I0]},
841 Conv_K_{b_g_k_c_xs_lengths[I1]},
842 Conv_C_{a_g_n_c_wis_lengths[I2]},
843 filter_lengths_{b_g_k_c_xs_lengths},
// The next four members are references (see their declarations below) bound
// directly to the constructor's parameters; they are not copies.
844 conv_filter_strides_{conv_filter_strides},
845 conv_filter_dilations_{conv_filter_dilations},
846 input_left_pads_{input_left_pads},
847 input_right_pads_{input_right_pads},
848 k_batch_{split_k}
849 {
// Build the A/B/C GEMM grid descriptors for this problem. The callee's name is on
// a line not shown in this listing; descs[I2] is the (padded) C descriptor.
850 const auto descs =
852 a_g_n_c_wis_lengths, // input
853 a_g_n_c_wis_strides,
854 b_g_k_c_xs_lengths, // weight
855 b_g_k_c_xs_strides,
856 e_g_n_k_wos_lengths, // output
857 e_g_n_k_wos_strides,
858 conv_filter_strides,
859 conv_filter_dilations,
860 input_left_pads,
861 input_right_pads,
862 k_batch_);
863
866 c_grid_desc_m_n_ = descs[I2];
867
// M01 / N01 are presumably consumed when building block_2_ctile_map_ on lines not
// shown in this listing.
874 ck::index_t M01 = 1;
875 ck::index_t N01 = 1;
878
879 // Per-group (G, index I0) strides: A = output, B = input, C = weight.
880 compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0];
881 compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0];
882 compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0];
883 }
884
888
892
896
897 // DefaultBlock2CTileMap block_2_ctile_map_;
899
900 // for computing batch offset
901 ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
902
903 // element-wise op
904 OutElementwiseOperation a_element_op_;
905 WeiElementwiseOperation b_element_op_;
906 InElementwiseOperation c_element_op_;
907
908 // for checking IsSupportedArgument()
912
913 std::array<ck::index_t, NDimSpatial + 3> filter_lengths_;
// NOTE(review): these four members are references to the arrays passed to the
// constructor, not copies. If the caller's arrays, or temporaries bound to the
// constructor's const& parameters, are destroyed before IsSupportedArgument() or
// any other reader runs, these references dangle. Storing them by value, as
// filter_lengths_ is, would remove that lifetime requirement.
914 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
915 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
916 const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
917 const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
919 };
920
921 // Invoker
// Validates an Argument against the gridwise GEMM and launches the grouped
// backward-weight kernel. The kernel variant is chosen from the main-loop /
// double-tail flags derived from K0.
922 struct Invoker : public BaseInvoker
923 {
925
// Prints the lengths of the A/B/C grid descriptors to std::cout.
// NOTE(review): this reads a_grid_desc_kbatch_k0_m_k1_ / b_grid_desc_kbatch_k0_n_k1_,
// while Run() reads a_grid_desc_kbatch_k0_m0_m1_k1_. The member declarations are not
// visible in this listing; confirm both names exist on Argument.
926 void ShowInfo(const Argument& arg)
927 {
928 std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
929 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
930 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
931 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
932 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
933
934 std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
935 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
936 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
937 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
938 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
939
940 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
941 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
942 }
943
// Launches the kernel and returns launch_and_time_kernel's result.
// Throws std::runtime_error if the gridwise GEMM rejects the descriptors.
944 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
945 {
946
// NOTE(review): called unconditionally, so every Run() writes to stdout (and
// std::endl flushes each line); consider gating this on a log level.
947 ShowInfo(arg);
// Reject configurations that the gridwise GEMM reports as invalid (the opening
// lines of this check are not shown in this listing).
948
951 arg.c_grid_desc_m_n_))
952 {
953 throw std::runtime_error(
954 "wrong! GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting");
955 }
956
// One full set of GEMM tiles per group: the kernel divides the grid by
// batch_count (= Conv_G_) to recover each block's group index.
957 const index_t grid_size =
958 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
959
// Instantiates the kernel for the given compile-time loop flags and launches it.
// The 0 passed after the block size is presumably the dynamic LDS size; the
// kernel's LDS buffer is declared statically.
960 auto launch_kernel = [&](auto has_main_k_block_loop,
961 auto has_double_tail_k_block_loop) {
962 constexpr bool has_main_loop = has_main_k_block_loop.value;
963 constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
964
965 const auto kernel = kernel_batched_gemm_dlops_bwd_weight<
966 GridwiseGemm,
967 ADataType, // TODO: distinguish A/B datatype
968 CDataType,
973 ComputePtrOffsetOfStridedBatch<>,
974 has_main_loop,
975 has_double_loop>;
976
977 return launch_and_time_kernel(stream_config,
978 kernel,
979 dim3(grid_size),
980 dim3(BlockSize),
981 0,
982 arg.p_a_grid_,
983 arg.p_b_grid_,
984 arg.p_c_grid_,
985 arg.Conv_G_,
991 };
992
// K0 is dimension 1 of the A descriptor; it determines which compile-time
// K-loop variant of the kernel is launched.
993 const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1);
994 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
995 const bool has_double_tail_k_block_loop =
997
998 if(has_main_k_block_loop && has_double_tail_k_block_loop)
999 {
1000 return launch_kernel(integral_constant<bool, true>{},
1002 }
1003 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
1004 {
1005 return launch_kernel(integral_constant<bool, true>{},
1006 integral_constant<bool, false>{});
1007 }
1008 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
1009 {
1010 return launch_kernel(integral_constant<bool, false>{},
1011 integral_constant<bool, true>{});
1012 }
1013 else
1014 {
1015 return launch_kernel(integral_constant<bool, false>{},
1016 integral_constant<bool, false>{});
1017 }
1018 }
1019
// Type-erased entry point.
// NOTE(review): dynamic_cast yields nullptr if p_arg is not this op's Argument, and
// dereferencing that is undefined behavior. Callers must pass an Argument created
// for this device op.
1020 float Run(const BaseArgument* p_arg,
1021 const StreamConfig& stream_config = StreamConfig{}) override
1022 {
1023 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1024 }
1025 };
1026
static constexpr bool IsValidCompilationParameter()
{
    // Compile-time parameter validation is not implemented yet; every instantiation
    // is reported as valid.
    // TODO: properly implement this check
    constexpr bool all_parameters_valid = true;
    return all_parameters_valid;
}
1032
1033 static bool IsSupportedArgument(const Argument& arg)
1034 {
1035
1036 // DL version only supports split_k equal to 1
1037 if(arg.k_batch_ != 1)
1038 return false;
1039
1040 if constexpr(!((NDimSpatial == 1 &&
1043 (NDimSpatial == 2 &&
1046 (NDimSpatial == 3 &&
1049 {
1050 return false;
1051 }
1052
1053 if constexpr(ConvBackwardWeightSpecialization ==
1055 {
1056 // check if it's 1x1, stride=1 pad = 0 conv
1057 for(int i = 0; i < NDimSpatial; i++)
1058 {
1059 if(!(arg.filter_lengths_[spatial_offset + i] == 1 &&
1060 arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 &&
1061 arg.input_right_pads_[i] == 0))
1062 {
1063 return false;
1064 }
1065 }
1066 }
1067
1068 // matrix A
1069 {
1070 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
1071 if(srcVectorLengths[I2] != 1 || srcVectorLengths[I3] != 1)
1072 {
1073 return false;
1074 }
1075 if(K1 % srcVectorLengths[I4] != 0 || K0PerBlock % srcVectorLengths[I1] != 0)
1076 {
1077 return false;
1078 }
1079
1080 const index_t K = arg.Conv_K_;
1081
1082 if(K % (srcVectorLengths[I1] * srcVectorLengths[I4]) != 0)
1083 {
1084 return false;
1085 }
1086 }
1087
1088 // matrix B
1089 {
1090 auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
1091 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
1092 if(srcVectorLengths[I1] != 1 || srcVectorLengths[I4] != 1)
1093 {
1094 return false;
1095 }
1096 if(srcLoadLenghts[I2] % srcVectorLengths[I2] != 0 ||
1097 srcLoadLenghts[I3] % srcVectorLengths[I3] != 0)
1098 {
1099 return false;
1100 }
1101
1102 const index_t C = arg.Conv_K_;
1103
1104 if(C % (srcVectorLengths[I2] * srcVectorLengths[I3]) != 0)
1105 {
1106 return false;
1107 }
1108 }
1109
1110 // vector store C matrix into global memory
1111 if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
1112 {
1113 std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
1114 << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
1115 return false;
1116 }
1117
1118 // Gridwise GEMM size
1121 }
1122
1123 bool IsSupportedArgument(const BaseArgument* p_arg) override
1124 {
1125 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1126 }
1127
1128 static auto
1129 MakeArgument(const InDataType* p_in_grid,
1130 WeiDataType* p_wei_grid,
1131 const OutDataType* p_out_grid,
1132 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
1133 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1134 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
1135 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1136 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
1137 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1138 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1139 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1140 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1141 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1142 InElementwiseOperation in_element_op,
1143 WeiElementwiseOperation wei_element_op,
1144 OutElementwiseOperation out_element_op,
1145 ck::index_t split_k)
1146 {
1147 return Argument{p_in_grid,
1148 p_wei_grid,
1149 p_out_grid,
1150 a_g_n_c_wis_lengths, // input
1151 a_g_n_c_wis_strides,
1152 b_g_k_c_xs_lengths, // weight
1153 b_g_k_c_xs_strides,
1154 e_g_n_k_wos_lengths, // output
1155 e_g_n_k_wos_strides,
1156 conv_filter_strides,
1157 conv_filter_dilations,
1158 input_left_pads,
1159 input_right_pads,
1160 in_element_op,
1161 wei_element_op,
1162 out_element_op,
1163 split_k};
1164 }
1165
1166 static auto MakeInvoker() { return Invoker{}; }
1167
1168 std::unique_ptr<BaseArgument>
1169 MakeArgumentPointer(const void* p_in_grid,
1170 void* p_wei_grid,
1171 const void* p_out_grid,
1172 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
1173 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1174 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
1175 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1176 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
1177 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1178 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1179 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1180 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1181 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1182 InElementwiseOperation in_element_op,
1183 WeiElementwiseOperation wei_element_op,
1184 OutElementwiseOperation out_element_op,
1185 ck::index_t split_k) override
1186 {
1187 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
1188 static_cast<WeiDataType*>(p_wei_grid),
1189 static_cast<const OutDataType*>(p_out_grid),
1190 a_g_n_c_wis_lengths, // input
1191 a_g_n_c_wis_strides,
1192 b_g_k_c_xs_lengths, // weight
1193 b_g_k_c_xs_strides,
1194 e_g_n_k_wos_lengths, // output
1195 e_g_n_k_wos_strides,
1196 conv_filter_strides,
1197 conv_filter_dilations,
1198 input_left_pads,
1199 input_right_pads,
1200 in_element_op,
1201 wei_element_op,
1202 out_element_op,
1203 split_k);
1204 }
1205
1206 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1207 {
1208 return std::make_unique<Invoker>(Invoker{});
1209 }
1210
1211 std::string GetTypeString() const override
1212 {
1213 auto str = std::stringstream();
1214
1215 // clang-format off
1216 str << "DeviceGroupedConvBwdWeight_Dl"
1217 << "<"
1218 << BlockSize << ", "
1219 << MPerBlock << ", "
1220 << NPerBlock << ", "
1221 << K0PerBlock << ", "
1222 << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
1223 << K1
1224 << ">";
1225 // clang-format on
1226
1227 return str.str();
1228 }
1229};
1230
1231} // namespace device
1232} // namespace tensor_operation
1233} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
constexpr bool is_NWGC_GKXC_NWGK()
Definition device_grouped_conv_utils.hpp:15
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
constexpr bool is_GNWC_GKXC_GNWK()
Definition device_grouped_conv_utils.hpp:23
__global__ void kernel_batched_gemm_dlops_bwd_weight(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const index_t batch_count, const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition device_grouped_conv_bwd_weight_dl.hpp:40
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition device_grouped_conv_utils.hpp:48
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_v1r3.hpp:612
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_grouped_conv_bwd_weight_dl.hpp:811
AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_
Definition device_grouped_conv_bwd_weight_dl.hpp:893
InElementwiseOperation c_element_op_
Definition device_grouped_conv_bwd_weight_dl.hpp:906
std::array< ck::index_t, NDimSpatial+3 > filter_lengths_
Definition device_grouped_conv_bwd_weight_dl.hpp:913
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_dl.hpp:909
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_dl.hpp:901
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_dl.hpp:885
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_dl.hpp:911
AGridDesc_B_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_grouped_conv_bwd_weight_dl.hpp:889
index_t k_batch_
Definition device_grouped_conv_bwd_weight_dl.hpp:918
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_dl.hpp:910
BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_
Definition device_grouped_conv_bwd_weight_dl.hpp:894
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_dl.hpp:812
WeiElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_dl.hpp:905
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_dl.hpp:904
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_dl.hpp:891
const std::array< ck::index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_dl.hpp:917
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_grouped_conv_bwd_weight_dl.hpp:895
const std::array< ck::index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_dl.hpp:916
const std::array< ck::index_t, NDimSpatial > & conv_filter_dilations_
Definition device_grouped_conv_bwd_weight_dl.hpp:915
const std::array< ck::index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_dl.hpp:914
CDataType * p_c_grid_
Definition device_grouped_conv_bwd_weight_dl.hpp:887
Block2CTileMap block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_dl.hpp:898
BGridDesc_B_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_grouped_conv_bwd_weight_dl.hpp:890
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_dl.hpp:886
Definition device_grouped_conv_bwd_weight_dl.hpp:923
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_dl.hpp:924
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_dl.hpp:944
void ShowInfo(const Argument &arg)
Definition device_grouped_conv_bwd_weight_dl.hpp:926
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_dl.hpp:1020
Definition device_grouped_conv_bwd_weight_dl.hpp:139
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_B_K0_M_K1
Definition device_grouped_conv_bwd_weight_dl.hpp:760
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_dl.hpp:1033
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_dl.hpp:1129
GridwiseGemmDl_bkm_bkn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_B_K0_M_K1, BGridDesc_B_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_grouped_conv_bwd_weight_dl.hpp:764
static constexpr auto K1Number
Definition device_grouped_conv_bwd_weight_dl.hpp:162
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, const ck::index_t batch_k)
Definition device_grouped_conv_bwd_weight_dl.hpp:180
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_dl.hpp:762
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_dl.hpp:153
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)) Block2CTileMap
Definition device_grouped_conv_bwd_weight_dl.hpp:807
WeiDataType CDataType
Definition device_grouped_conv_bwd_weight_dl.hpp:144
static constexpr auto I5
Definition device_grouped_conv_bwd_weight_dl.hpp:158
static constexpr auto BBlockLdsN1PerBlock
Definition device_grouped_conv_bwd_weight_dl.hpp:175
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_grouped_conv_bwd_weight_dl.hpp:805
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_dl.hpp:147
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_dl.hpp:1211
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_dl.hpp:729
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_dl.hpp:154
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_dl.hpp:1027
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_B_K0_N_K1
Definition device_grouped_conv_bwd_weight_dl.hpp:761
static constexpr auto I3
Definition device_grouped_conv_bwd_weight_dl.hpp:156
WeiElementwiseOperation CElementwiseOperation
Definition device_grouped_conv_bwd_weight_dl.hpp:148
static constexpr auto GemmK1Number
Definition device_grouped_conv_bwd_weight_dl.hpp:163
static constexpr auto ElePerBank
Definition device_grouped_conv_bwd_weight_dl.hpp:167
DeviceGroupedConvBwdWeight_Dl DeviceOp
Definition device_grouped_conv_bwd_weight_dl.hpp:140
static constexpr auto BBlockLdsN0PerBlock
Definition device_grouped_conv_bwd_weight_dl.hpp:176
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_dl.hpp:155
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_dl.hpp:151
static constexpr auto ABlockLdsM1PerBlock
Definition device_grouped_conv_bwd_weight_dl.hpp:170
InDataType BDataType
Definition device_grouped_conv_bwd_weight_dl.hpp:143
static constexpr auto BBlockLdsN1Padding
Definition device_grouped_conv_bwd_weight_dl.hpp:177
static constexpr auto ABlockLdsM1Padding
Definition device_grouped_conv_bwd_weight_dl.hpp:172
decltype(GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{})) AGridDesc_B_K0_M0_M1_K1
Definition device_grouped_conv_bwd_weight_dl.hpp:801
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_dl.hpp:1206
static constexpr auto BankLength
Definition device_grouped_conv_bwd_weight_dl.hpp:166
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_dl.hpp:1166
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k) override
Definition device_grouped_conv_bwd_weight_dl.hpp:1169
static constexpr auto spatial_offset
Definition device_grouped_conv_bwd_weight_dl.hpp:160
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_dl.hpp:146
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_dl.hpp:142
static constexpr auto ABlockLdsM0PerBlock
Definition device_grouped_conv_bwd_weight_dl.hpp:171
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_dl.hpp:758
static constexpr auto I4
Definition device_grouped_conv_bwd_weight_dl.hpp:157
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_dl.hpp:1123
decltype(GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{})) BGridDesc_B_K0_N0_N1_K1
Definition device_grouped_conv_bwd_weight_dl.hpp:803
Definition device_grouped_conv_bwd_weight.hpp:29