device_batched_contraction_multiple_d_xdl_cshuffle.hpp Source File

device_batched_contraction_multiple_d_xdl_cshuffle.hpp Source File#

Composable Kernel: device_batched_contraction_multiple_d_xdl_cshuffle.hpp Source File
device_batched_contraction_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22
23template <typename GridwiseGemm,
24 typename FloatAB,
25 typename FloatDsPointer,
26 typename FloatE,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename AGridDesc_AK0_M_AK1,
31 typename BGridDesc_BK0_N_BK1,
32 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
33 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
34 typename ComputePtrOffsetOfBatch,
35 typename Block2ETileMap,
36 bool HasMainKBlockLoop>
37__global__ void
38#if CK_USE_LAUNCH_BOUNDS
40#endif
42 const FloatAB* __restrict__ p_a_grid,
43 const FloatAB* __restrict__ p_b_grid,
44 FloatDsPointer p_ds_grid,
45 FloatE* __restrict__ p_e_grid,
46 const index_t batch_count,
47 const AElementwiseOperation a_element_op,
48 const BElementwiseOperation b_element_op,
49 const CDEElementwiseOperation cde_element_op,
50 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
51 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
52 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
53 ds_grid_desc_mblock_mperblock_nblock_nperblock,
54 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
55 e_grid_desc_mblock_mperblock_nblock_nperblock,
56 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
57 const Block2ETileMap block_2_etile_map)
58{
59#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
60 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
61 {
62 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
63
64 const index_t num_blocks_per_batch =
65 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
66 const index_t g_idx =
67 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
68
69 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
70 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
71 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
72 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
73 const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
74 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
75
76 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
77
78 FloatDsPointer p_ds_grid_grp;
79
80 static constexpr index_t NumDTensor =
81 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
82
84 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
85
86 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
87 p_a_grid + a_batch_offset,
88 p_b_grid + b_batch_offset,
89 p_ds_grid_grp,
90 p_e_grid + e_batch_offset,
91 p_shared,
92 a_element_op,
93 b_element_op,
94 cde_element_op,
95 a_grid_desc_ak0_m_ak1,
96 b_grid_desc_bk0_n_bk1,
97 ds_grid_desc_mblock_mperblock_nblock_nperblock,
98 e_grid_desc_mblock_mperblock_nblock_nperblock,
99 block_2_etile_map);
100 }
101#else
102 ignore = p_a_grid;
103 ignore = p_b_grid;
104 ignore = p_ds_grid;
105 ignore = p_e_grid;
106 ignore = batch_count;
107 ignore = a_element_op;
108 ignore = b_element_op;
109 ignore = cde_element_op;
110 ignore = a_grid_desc_ak0_m_ak1;
111 ignore = b_grid_desc_bk0_n_bk1;
112 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
113 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
114 ignore = block_2_etile_map;
115 ignore = compute_ptr_offset_of_batch;
116#endif
117}
118
119} // namespace ck
120
121namespace ck {
122namespace tensor_operation {
123namespace device {
124
125// Tensor Contraction:
126// input : A
127// input : B
128// input : D0, D1, ...
129// output : E
130// C = a_op(A) * b_op(B)
131// E = cde_op(C, D0, D1, ...)
132// Assume:
133// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
134// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
135// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
136// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
137
138// NOTE: A TensorSpecialization::Packed tensor is "packed" only in the sense that the inner
139// dimensions of each dimension group (e.g. [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous
140// and ordered, not in the sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
141// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
142// TensorSpecialization::Default with NumDimG/M/N/K = 1.
143//
144// Detail- Packed tensor satisfies
145// stride_0 = 1
146// stride_i = stride_{i - 1} * extent_{i - 1}
147// So tensor
148// [G0, G1, G2, M, N]
149// transposed into tensor
150// [G0, G2, G1, M, N]
151// with strides
152// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1]
153// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some
154// strides from input tensor extents so finer dimension information is lost. Merging dimensions is
155// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
156//
157// Might need to expose dimension order to the interface to fully support
158// TensorSpecialization::Packed in a traditional sense of "packed" tensor
159template <index_t NumDimG,
160 index_t NumDimM,
161 index_t NumDimN,
162 index_t NumDimK,
163 typename ADataType,
164 typename BDataType,
165 typename AccDataType,
166 typename CShuffleDataType,
167 typename DsDataType,
168 typename EDataType,
169 typename AElementwiseOperation,
170 typename BElementwiseOperation,
171 typename CDEElementwiseOperation,
172 GemmSpecialization GemmSpec,
176 index_t NumGemmKPrefetchStage,
177 index_t BlockSize,
178 index_t MPerBlock,
179 index_t NPerBlock,
180 index_t KPerBlock,
181 index_t AK1,
182 index_t BK1,
183 index_t MPerXDL,
184 index_t NPerXDL,
185 index_t MXdlPerWave,
186 index_t NXdlPerWave,
187 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
188 typename ABlockTransferThreadClusterArrangeOrder,
189 typename ABlockTransferSrcAccessOrder,
190 index_t ABlockTransferSrcVectorDim,
191 index_t ABlockTransferSrcScalarPerVector,
192 index_t ABlockTransferDstScalarPerVector_AK1,
193 bool ABlockLdsExtraM,
194 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
195 typename BBlockTransferThreadClusterArrangeOrder,
196 typename BBlockTransferSrcAccessOrder,
197 index_t BBlockTransferSrcVectorDim,
198 index_t BBlockTransferSrcScalarPerVector,
199 index_t BBlockTransferDstScalarPerVector_BK1,
200 bool BBlockLdsExtraN,
201 index_t CShuffleMXdlPerWavePerShuffle,
202 index_t CShuffleNXdlPerWavePerShuffle,
203 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
204 index_t CDEBlockTransferScalarPerVector_NPerBlock,
207 : public DeviceBatchedContractionMultipleD<NumDimG,
208 NumDimM,
209 NumDimN,
210 NumDimK,
211 ADataType,
212 BDataType,
213 DsDataType,
214 EDataType,
215 AElementwiseOperation,
216 BElementwiseOperation,
217 CDEElementwiseOperation>
218{
220
222 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
223 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
224
225 static constexpr index_t NumDTensor = DsDataType::Size();
226
227 static constexpr auto I0 = Number<0>{};
228 static constexpr auto I1 = Number<1>{};
229 static constexpr auto I2 = Number<2>{};
230 static constexpr auto I3 = Number<3>{};
231
232 static constexpr auto matrix_padder =
233 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
234
235 // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
236 static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
237 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
238 {
239 assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
240 a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);
241
242 const auto to_tuple = [&](auto& vec, auto start, auto end) {
243 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
244 };
245
246 const auto a_ms_ks_lengths = to_tuple(
247 a_gs_ms_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
248 const auto a_ms_ks_strides = to_tuple(
249 a_gs_ms_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
250
251 // dimension Ids for M0, M1, ...
252 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
253
254 // dimension Ids for K0, K1, ...
255 constexpr auto kDimIds =
257
258 // lengths for M0, M1, ...
259 const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
260
261 // lengths for K0, K1, ...
262 const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
263
264 if constexpr(ASpec == TensorSpecialization::Packed)
265 {
266 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
267 auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
268 const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
269 make_tuple(M, K),
270 make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
271 a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
272 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
273 }
274 else
275 {
276 // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
277 const auto a_grid_desc_ms_ks =
278 make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
279
280 // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
281 const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
282 a_grid_desc_ms_ks,
284 make_tuple(mDimIds, kDimIds),
286
287 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
288 }
289 }
290
291 // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
292 static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
293 const std::vector<index_t>& b_gs_ns_ks_strides_vec)
294 {
295 assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
296 b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
297
298 const auto to_tuple = [&](auto& vec, auto start, auto end) {
299 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
300 };
301
302 const auto b_ns_ks_lengths = to_tuple(
303 b_gs_ns_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
304 const auto b_ns_ks_strides = to_tuple(
305 b_gs_ns_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
306
307 // dimension Ids for N0, N1, ...
308 constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
309
310 // dimension Ids for K0, K1, ...
311 constexpr auto kDimIds =
313
314 // lengths for K0, K1, ...
315 const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
316
317 // lengths for N0, N1, ...
318 const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
319
320 if constexpr(BSpec == TensorSpecialization::Packed)
321 {
322 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
323 auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
324 const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
325 make_tuple(N, K),
326 make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
327 b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
328 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
329 }
330 else
331 {
332 // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
333 const auto b_grid_desc_ns_ks =
334 make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
335
336 // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
337 const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
338 b_grid_desc_ns_ks,
340 make_tuple(nDimIds, kDimIds),
342
343 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
344 }
345 }
346
347 // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
348 static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
349 const std::vector<index_t>& e_gs_ms_ns_strides_vec)
350 {
351 assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
352 e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
353
354 const auto to_tuple = [&](auto& vec, auto start, auto end) {
355 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
356 };
357
358 const auto e_ms_ns_lengths = to_tuple(
359 e_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
360 const auto e_ms_ns_strides = to_tuple(
361 e_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
362
363 // dimension Ids for M0, M1, ...
364 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
365
366 // dimension Ids for N0, N1, ...
367 constexpr auto nDimIds =
369
370 // lengths for M0, M1, ...
371 const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
372
373 // lengths for N0, N1, ...
374 const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
375
376 if constexpr(DESpec == TensorSpecialization::Packed)
377 {
378 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
379 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
380 const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
381 make_tuple(M, N),
382 make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
383 e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));
384 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
385 }
386 else
387 {
388 // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
389 const auto e_grid_desc_ms_ns =
390 make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
391
392 // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
393 const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
394 e_grid_desc_ms_ns,
396 make_tuple(mDimIds, nDimIds),
398
399 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
400 }
401 }
402
403 // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
404 static auto MakeEGridDescriptor_G_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
405 const std::vector<index_t>& e_gs_ms_ns_strides_vec)
406 {
407 assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
408 e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
409
410 const auto to_tuple = [&](auto& vec, auto start, auto end) {
411 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
412 };
413
414 const auto e_gs_ms_ns_lengths =
415 to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
416 const auto e_gs_ms_ns_strides =
417 to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
418
419 // dimension Ids for G0, G1, ...
420 constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
421
422 // dimension Ids for M0, M1, ...
423 constexpr auto mDimIds =
425
426 // dimension Ids for N0, N1, ...
427 constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
428 NumDimG + NumDimM + NumDimN,
429 1>::type{};
430
431 // lengths for G0, G1, ...
432 const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds);
433
434 // lengths for M0, M1, ...
435 const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds);
436
437 // lengths for N0, N1, ...
438 const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds);
439
440 if constexpr(DESpec == TensorSpecialization::Packed)
441 {
442 auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
443 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
444 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
445 const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
446 make_tuple(G, M, N),
447 make_tuple(e_gs_ms_ns_strides[Number<NumDimG - 1>{}],
448 e_gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
449 e_gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
450 // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
451 return e_grid_desc_g_mraw_nraw;
452 }
453 else
454 {
455 // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
456 const auto e_grid_desc_gs_ms_ns =
457 make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
458
459 // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
460 // N2 * ...]
461 const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
462 e_grid_desc_gs_ms_ns,
464 make_merge_transform(mLengths),
465 make_merge_transform(nLengths)),
466 make_tuple(gDimIds, mDimIds, nDimIds),
468
469 // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
470 return e_grid_desc_g_mraw_nraw;
471 }
472 }
473
475 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
476 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
477 {
478 return generate_tuple(
479 [&](auto i) {
480 return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i],
481 ds_gs_ms_ns_strides_vec[i]);
482 },
484 }
485
487 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
488 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
489 {
490 return generate_tuple(
491 [&](auto i) {
492 return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i],
493 ds_gs_ms_ns_strides_vec[i]);
494 },
496 }
497
498 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
499 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
501 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
502
505
507 {
509 index_t batch_stride_B,
510 DsGridDesc_G_M_N ds_grid_desc_g_m_n,
511 EGridDesc_G_M_N e_grid_desc_g_m_n)
512 : batch_stride_A_(batch_stride_A),
513 batch_stride_B_(batch_stride_B),
514 ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n),
515 e_grid_desc_g_m_n_(e_grid_desc_g_m_n)
516 {
517 }
518
// Offset of batch `g_idx` within the A tensor: g_idx * batch_stride_A_.
// g_idx is cast to long_index_t before the multiplication, so the product is
// evaluated in long_index_t rather than index_t. The kernel adds the result
// directly to the typed A pointer.
519 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
520 {
521 return static_cast<long_index_t>(g_idx) * batch_stride_A_;
522 }
523
// Offset of batch `g_idx` within the B tensor: g_idx * batch_stride_B_.
// g_idx is cast to long_index_t before the multiplication, so the product is
// evaluated in long_index_t rather than index_t. The kernel adds the result
// directly to the typed B pointer.
524 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
525 {
526 return static_cast<long_index_t>(g_idx) * batch_stride_B_;
527 }
528
// Per-D-tensor offsets of batch `g_idx`, returned as an array with one entry
// per D tensor (NumDTensor entries).
// Unlike A/B there is no stored batch stride: for each D the stride is taken
// as the descriptor offset of the multi-index (1, 0, 0), i.e. one step along
// the first (G) dimension of the [G, M, N] descriptor, and is then scaled by
// g_idx in long_index_t.
529 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
530 {
531 std::array<long_index_t, NumDTensor> ds_offset;
532
533 static_for<0, NumDTensor, 1>{}([&](auto i) {
534 ds_offset[i] = static_cast<long_index_t>(g_idx) *
535 ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
536 });
537
538 return ds_offset;
539 }
540
// Offset of batch `g_idx` within the E tensor.
// The batch stride is read from the [G, M, N] descriptor as the offset of the
// multi-index (1, 0, 0) and multiplied by g_idx after widening g_idx to
// long_index_t. The kernel adds the result directly to the typed E pointer.
541 __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
542 {
543 return static_cast<long_index_t>(g_idx) *
544 e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
545 }
546
547 private:
548 index_t batch_stride_A_;
549 index_t batch_stride_B_;
550 DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
551 EGridDesc_G_M_N e_grid_desc_g_m_n_;
552 };
553
554 using ComputeDataType = ADataType;
555
556 // GridwiseGemm
557 template <index_t NXdlPerWave_>
559 ADataType,
560 BDataType,
562 AccDataType,
563 CShuffleDataType,
564 DsDataType,
565 EDataType,
566 AElementwiseOperation,
567 BElementwiseOperation,
568 CDEElementwiseOperation,
569 NumGemmKPrefetchStage,
570 BlockSize,
571 MPerBlock,
572 NPerBlock,
573 KPerBlock,
574 AK1,
575 BK1,
576 MPerXDL,
577 NPerXDL,
578 MXdlPerWave,
579 NXdlPerWave_,
580 ABlockTransferThreadClusterLengths_AK0_M_AK1,
581 ABlockTransferThreadClusterArrangeOrder,
582 ABlockTransferSrcAccessOrder,
583 ABlockTransferSrcVectorDim,
584 ABlockTransferSrcScalarPerVector,
585 ABlockTransferDstScalarPerVector_AK1,
586 false,
587 ABlockLdsExtraM,
588 BBlockTransferThreadClusterLengths_BK0_N_BK1,
589 BBlockTransferThreadClusterArrangeOrder,
590 BBlockTransferSrcAccessOrder,
591 BBlockTransferSrcVectorDim,
592 BBlockTransferSrcScalarPerVector,
593 BBlockTransferDstScalarPerVector_BK1,
594 false,
595 BBlockLdsExtraN,
596 CShuffleMXdlPerWavePerShuffle,
597 CShuffleNXdlPerWavePerShuffle,
598 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
599 CDEBlockTransferScalarPerVector_NPerBlock,
600 LoopSched>;
603
604 // desc for blockwise copy
607 AGridDesc_M_K{}))>;
610 BGridDesc_N_K{}))>;
613 DsGridDesc_M_N{}))>;
616 EGridDesc_M_N{}))>;
617
618 // block-to-e-tile map
621
622 // Argument
623 struct Argument : public BaseArgument
624 {
625 template <typename GridwiseGemm>
627 {
628 if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
633 {
635 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
637
639 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
641 }
642 }
643
644 Argument(const void* p_a_grid,
645 const void* p_b_grid,
646 std::array<const void*, NumDTensor> p_ds_grid,
647 void* p_e_grid,
648 const std::vector<index_t>& a_gs_ms_ns_lengths,
649 const std::vector<index_t>& a_gs_ms_ks_strides,
650 const std::vector<index_t>& b_gs_ns_ks_lengths,
651 const std::vector<index_t>& b_gs_ns_ks_strides,
652 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
653 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
654 const std::vector<index_t>& e_gs_ms_ns_lengths,
655 const std::vector<index_t>& e_gs_ms_ns_strides,
656 AElementwiseOperation a_element_op,
657 BElementwiseOperation b_element_op,
658 CDEElementwiseOperation cde_element_op)
659 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
660 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
661 p_ds_grid_{},
662 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
664 DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)},
666 DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
669 DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
671 DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
673 DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
675 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
677 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
680 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
681 a_element_op_{a_element_op},
682 b_element_op_{b_element_op},
683 cde_element_op_{cde_element_op},
684 a_mz_stride_{},
685 a_kz_stride_{},
686 b_nz_stride_{},
687 b_kz_stride_{},
689 e_nz_stride_{},
690 a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]},
691 b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]},
694 {
695 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, "");
696
697 // populate pointer, batch stride, desc for Ds
698 static_for<0, NumDTensor, 1>{}([&](auto i) {
699 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
700
701 // D pointer
702 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
703
704 // D desc
705 ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i],
706 ds_gs_ms_ns_strides[i]);
707 });
708
709 // populate desc for Ds/E
710 if(get_warp_size() == 64)
711 {
712 if constexpr(NXdlPerWave64 > 0)
713 {
715 }
716 }
717 else
718 {
719 if constexpr(NXdlPerWave32 > 0)
720 {
722 }
723 }
724
725 // for sanity check of vector memory access
726 a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
727 a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1];
728 b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1];
729 b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1];
730
731 for(index_t i = 0; i < NumDTensor; ++i)
732 {
733 ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1];
734 }
735
736 e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1];
737 }
738
739 void Print() const
740 {
741 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
742 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
744 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
745 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
746 }
747
748 // private:
749 // pointers
750 const ADataType* p_a_grid_;
751 const BDataType* p_b_grid_;
753 EDataType* p_e_grid_;
754
755 // tensor descriptors for problem definition
760
763
764 // tensor descriptors for block/thread-wise copy
770
771 // block-to-e-tile map
773
774 // element-wise op
775 AElementwiseOperation a_element_op_;
776 BElementwiseOperation b_element_op_;
777 CDEElementwiseOperation cde_element_op_;
778
779 // Strides for the last M/N/K dimensions of A/B/Ds/E
780 // for sanity check of vector load/store
785 std::array<index_t, NumDTensor> ds_nz_stride_;
788
791
793 };
794
795 // Invoker
796 struct Invoker : public BaseInvoker
797 {
799
800 template <typename GridwiseGemm>
801 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
802 {
803 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
808 {
809 throw std::runtime_error(
810 "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
811 }
812
813 const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
814
815 const index_t grid_size =
816 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
817
818 const auto K =
819 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
820
821 auto launch_kernel = [&](auto has_main_k_block_loop) {
822 constexpr bool has_main_loop = has_main_k_block_loop.value;
823
825 GridwiseGemm,
826 ADataType, // TODO: distinguish A/B datatype
827 typename GridwiseGemm::DsGridPointer,
828 EDataType,
829 AElementwiseOperation,
830 BElementwiseOperation,
831 CDEElementwiseOperation,
836 ComputePtrOffsetOfStridedBatch,
838 has_main_loop>;
839
840 return launch_and_time_kernel(stream_config,
841 kernel,
842 dim3(grid_size),
843 dim3(BlockSize),
844 0,
845 arg.p_a_grid_,
846 arg.p_b_grid_,
847 arg.p_ds_grid_,
848 arg.p_e_grid_,
849 G,
850 arg.a_element_op_,
851 arg.b_element_op_,
852 arg.cde_element_op_,
859 };
860
861 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
862 {
863 return launch_kernel(integral_constant<bool, true>{});
864 }
865 else
866 {
867 return launch_kernel(integral_constant<bool, false>{});
868 }
869 }
871
872 // polymorphic
// Type-erased entry point: downcasts `p_arg` to this operation's Argument and
// forwards to the Run overload that takes `const Argument&`, returning its
// float result.
// The forwarded-to overload is not visible in this listing; presumably it is
// produced by the INVOKER_RUN_IMPL macro -- TODO confirm.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so passing a BaseArgument that is not an Argument dereferences a null
// pointer.
873 float Run(const BaseArgument* p_arg,
874 const StreamConfig& stream_config = StreamConfig{}) override
875 {
876 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
877 }
878 };
879
880 static bool IsSupportedArgument(const Argument& arg)
881 {
883 {
884 return false;
885 }
886
887 bool valid = false;
888 if(get_warp_size() == 64)
889 {
890 if constexpr(NXdlPerWave64 > 0)
891 {
897 }
898 }
899 else
900 {
901 if constexpr(NXdlPerWave32 > 0)
902 {
908 }
909 }
910 if(!valid)
911 {
912 return false;
913 }
914
915 // check vector access
916 static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
917 (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
918 "wrong!");
919
920 // vector memory access of A: could be on M or AK1 dimension
921 if constexpr(ABlockTransferSrcVectorDim == 1)
922 {
923 if(!(arg.a_mz_stride_ == 1 &&
924 arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
925 {
926 return false;
927 }
928 }
929 else
930 {
931 if(!(arg.a_kz_stride_ == 1 &&
932 arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
933 {
934 return false;
935 }
936 }
937
938 // vector memory access of B: could be on N or BK1 dimension
939 if constexpr(BBlockTransferSrcVectorDim == 1)
940 {
941 if(!(arg.b_nz_stride_ == 1 &&
942 arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
943 {
944 return false;
945 }
946 }
947 else
948 {
949 if(!(arg.b_kz_stride_ == 1 &&
950 arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
951 {
952 return false;
953 }
954 }
955
956 // vector memory access of Ds: always on NPerBlock dimension
957 bool valid_d_access = true;
958
959 static_for<0, NumDTensor, 1>{}([&](auto i) {
960 if(!(arg.ds_nz_stride_[i] == 1 &&
962 CDEBlockTransferScalarPerVector_NPerBlock ==
963 0))
964 {
965 valid_d_access = false;
966 }
967 });
968
969 if(valid_d_access == false)
970 {
971 return false;
972 }
973
974 // vector memory access of E: always on NPerBlock dimension
975 if(!((arg.e_nz_stride_ == 1 &&
977 CDEBlockTransferScalarPerVector_NPerBlock ==
978 0) ||
979 CDEBlockTransferScalarPerVector_NPerBlock == 1))
980 {
981 return false;
982 }
983
984 return true;
985 }
986
987 // polymorphic
// Type-erased overload: downcasts `p_arg` to Argument and delegates to the
// static IsSupportedArgument(const Argument&).
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so an argument of a different concrete type dereferences a null pointer
// instead of yielding `false`.
988 bool IsSupportedArgument(const BaseArgument* p_arg) override
989 {
990 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
991 }
992
// Builds an Argument by value, forwarding every parameter unchanged and in the
// same order to the Argument constructor.
//   p_a, p_b, p_ds, p_e   : untyped pointers to the A, B, D0..Dn and E buffers.
//   *_lengths / *_strides : per-tensor extents and strides, ordered
//                           [G..., M..., K...] for A, [G..., N..., K...] for B
//                           and [G..., M..., N...] for Ds/E.
//   a/b/cde_element_op    : element-wise operators stored in the Argument.
// NOTE(review): `a_gs_ms_ns_lengths` is handed to MakeAGridDescriptor_M_K,
// which asserts a size of NumDimG + NumDimM + NumDimK, so the name looks like
// a slip for `a_gs_ms_ks_lengths`.
993 static auto
994 MakeArgument(const void* p_a,
995 const void* p_b,
996 std::array<const void*, NumDTensor> p_ds,
997 void* p_e,
998 const std::vector<index_t>& a_gs_ms_ns_lengths,
999 const std::vector<index_t>& a_gs_ms_ks_strides,
1000 const std::vector<index_t>& b_gs_ns_ks_lengths,
1001 const std::vector<index_t>& b_gs_ns_ks_strides,
1002 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
1003 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
1004 const std::vector<index_t>& e_gs_ms_ns_lengths,
1005 const std::vector<index_t>& e_gs_ms_ns_strides,
1006 AElementwiseOperation a_element_op,
1007 BElementwiseOperation b_element_op,
1008 CDEElementwiseOperation cde_element_op)
1009 {
1010 return Argument{p_a,
1011 p_b,
1012 p_ds,
1013 p_e,
1014 a_gs_ms_ns_lengths,
1015 a_gs_ms_ks_strides,
1016 b_gs_ns_ks_lengths,
1017 b_gs_ns_ks_strides,
1018 ds_gs_ms_ns_lengths,
1019 ds_gs_ms_ns_strides,
1020 e_gs_ms_ns_lengths,
1021 e_gs_ms_ns_strides,
1022 a_element_op,
1023 b_element_op,
1024 cde_element_op};
1025 }
1026
// Returns a value-initialized Invoker by value.
1027 static auto MakeInvoker() { return Invoker{}; }
1028
1029 // polymorphic
// Heap-allocating counterpart of MakeArgument: constructs an Argument from the
// same parameters, forwarded unchanged and in the same order, and returns it
// as std::unique_ptr<BaseArgument>.
// NOTE(review): `a_gs_ms_ns_lengths` carries A's lengths and is paired with
// `a_gs_ms_ks_strides`; the name looks like a slip for `a_gs_ms_ks_lengths`.
1030 std::unique_ptr<BaseArgument>
1031 MakeArgumentPointer(const void* p_a,
1032 const void* p_b,
1033 std::array<const void*, NumDTensor> p_ds,
1034 void* p_e,
1035 const std::vector<index_t>& a_gs_ms_ns_lengths,
1036 const std::vector<index_t>& a_gs_ms_ks_strides,
1037 const std::vector<index_t>& b_gs_ns_ks_lengths,
1038 const std::vector<index_t>& b_gs_ns_ks_strides,
1039 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
1040 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
1041 const std::vector<index_t>& e_gs_ms_ns_lengths,
1042 const std::vector<index_t>& e_gs_ms_ns_strides,
1043 AElementwiseOperation a_element_op,
1044 BElementwiseOperation b_element_op,
1045 CDEElementwiseOperation cde_element_op) override
1046 {
1047 return std::make_unique<Argument>(p_a,
1048 p_b,
1049 p_ds,
1050 p_e,
1051 a_gs_ms_ns_lengths,
1052 a_gs_ms_ks_strides,
1053 b_gs_ns_ks_lengths,
1054 b_gs_ns_ks_strides,
1055 ds_gs_ms_ns_lengths,
1056 ds_gs_ms_ns_strides,
1057 e_gs_ms_ns_lengths,
1058 e_gs_ms_ns_strides,
1059 a_element_op,
1060 b_element_op,
1061 cde_element_op);
1062 }
1063
1064 // polymorphic
// Heap-allocating counterpart of MakeInvoker: returns a new Invoker owned by a
// std::unique_ptr<BaseInvoker>.
// The heap object is constructed from a temporary `Invoker{}` passed to
// std::make_unique rather than being default-constructed in place.
1065 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1066 {
1067 return std::make_unique<Invoker>(Invoker{});
1068 }
1069
1070 // polymorphic
// Returns a textual description of this instantiation, of the form
//   DeviceBatchedContractionMultipleD_Xdl_CShuffle<NumDimG, NumDimM, NumDimN,
//   NumDimK, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1,
//   ABlockTransferSrcVectorDim, BBlockTransferSrcVectorDim>
// with each name replaced by its compile-time value. Only these template
// parameters are included; the remaining ones do not appear in the string.
1071 std::string GetTypeString() const override
1072 {
1073 auto str = std::stringstream();
1074
1075 // clang-format off
1076 str << "DeviceBatchedContractionMultipleD_Xdl_CShuffle"
1077 << "<"
1078 << NumDimG << ", "
1079 << NumDimM << ", "
1080 << NumDimN << ", "
1081 << NumDimK << ", "
1082 << BlockSize << ", "
1083 << MPerBlock << ", "
1084 << NPerBlock << ", "
1085 << KPerBlock << ", "
1086 << AK1 << ", "
1087 << BK1 << ", "
1088 << ABlockTransferSrcVectorDim << ", "
1089 << BBlockTransferSrcVectorDim
1090 << ">";
1091 // clang-format on
1092
1093 return str.str();
1094 }
1095};
1096
1097} // namespace device
1098} // namespace tensor_operation
1099} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
TensorSpecialization
Definition tensor_specialization.hpp:11
@ Packed
Definition tensor_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
__global__ void kernel_contraction_multiple_d_xdl_cshuffle(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatDsPointer p_ds_grid, FloatE *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:41
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition utility/sequence.hpp:43
Definition utility/sequence.hpp:256
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/integral_constant.hpp:20
Definition utility/math.hpp:34
Definition functional2.hpp:33
Definition device_base.hpp:197
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:519
ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, index_t batch_stride_B, DsGridDesc_G_M_N ds_grid_desc_g_m_n, EGridDesc_G_M_N e_grid_desc_g_m_n)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:508
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:541
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:524
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:529
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:624
EGridDesc_M_N e_grid_desc_m_n_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:759
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:769
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:752
BElementwiseOperation b_element_op_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:776
BGridDesc_N_K b_grid_desc_n_k_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:757
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:758
const BDataType * p_b_grid_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:751
EDataType * p_e_grid_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:753
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:792
index_t a_kz_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:782
index_t b_nz_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:783
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:768
index_t b_batch_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:790
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:765
EGridDesc_G_M_N e_grid_desc_g_m_n_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:762
void init_ds_e_grid_desc()
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:626
index_t a_mz_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:781
std::array< index_t, NumDTensor > ds_nz_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:785
DsGridDesc_G_M_N ds_grid_desc_g_m_n_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:761
Block2ETileMap block_2_etile_map_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:772
const ADataType * p_a_grid_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:750
index_t e_mz_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:786
index_t a_batch_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:789
AGridDesc_M_K a_grid_desc_m_k_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:756
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:766
index_t b_kz_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:784
CDEElementwiseOperation cde_element_op_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:777
index_t e_nz_stride_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:787
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, const std::vector< index_t > &a_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:644
AElementwiseOperation a_element_op_
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:775
void Print() const
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:739
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:797
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:801
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:873
DeviceOp::Argument Argument
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:798
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:218
static constexpr auto NXdlPerWave32
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:223
static constexpr auto matrix_padder
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:232
decltype(MakeBGridDescriptor_N_K({}, {})) BGridDesc_N_K
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:499
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:619
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:236
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:608
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:614
remove_cvref_t< decltype(MakeDsGridDescriptor_G_M_N({}, {}))> DsGridDesc_G_M_N
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:503
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_gs_ms_ns_lengths_vec, const std::vector< index_t > &e_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:348
DeviceBatchedContractionMultipleD_Xdl_CShuffle DeviceOp
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:219
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:1065
static constexpr index_t NumDTensor
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:225
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:880
static auto MakeDsGridDescriptor_G_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:486
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:1031
decltype(MakeAGridDescriptor_M_K({}, {})) AGridDesc_M_K
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:498
static constexpr auto I3
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:230
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:988
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:611
decltype(MakeEGridDescriptor_M_N({}, {})) EGridDesc_M_N
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:501
static constexpr auto I2
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:229
static auto MakeBGridDescriptor_N_K(const std::vector< index_t > &b_gs_ns_ks_lengths_vec, const std::vector< index_t > &b_gs_ns_ks_strides_vec)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:292
decltype(MakeEGridDescriptor_G_M_N({}, {})) EGridDesc_G_M_N
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:504
static auto MakeInvoker()
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:1027
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:602
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:994
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:222
static auto MakeEGridDescriptor_G_M_N(const std::vector< index_t > &e_gs_ms_ns_lengths_vec, const std::vector< index_t > &e_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:404
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:601
std::string GetTypeString() const override
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:1071
static constexpr auto I0
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:227
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))> DsGridDesc_M_N
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:500
static constexpr auto I1
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:228
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:558
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:474
ADataType ComputeDataType
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:554
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:605
Definition device_batched_contraction_multiple_d.hpp:39
Definition matrix_padder.hpp:180