gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp Source File

gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp Source File

Composable Kernel: gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp Source File
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// Default policy for GemmPipelineAGmemBGmemCRegV1
12// Default policy class should not be templated; put templates on member functions instead
14{
// Compile-time index constants; GetBlockGemm uses them with WarpTile::at() to
// pass the per-warp M/N/K tile extents to WarpGemmDispatcher.
15 static constexpr auto I0 = number<0>{};
16 static constexpr auto I1 = number<1>{};
17 static constexpr auto I2 = number<2>{};
18
// Builds the descriptor of A's LDS tile. GetSmemSizeA sizes A's shared-memory
// buffer from this descriptor's element-space size.
// NOTE(review): the signature (source line 21) is not shown in this listing.
19 // 3d + padding
20 template <typename Problem>
22 {
// NOTE(review): redundant -- this code is already inside namespace ck_tile.
23 using namespace ck_tile;
24
25 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
26 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
27
// Strides are ((kMPerBlock + 1) * 8, 8, 1), with a guaranteed last-dimension
// vector length of 8 and vector stride of 1. The "+ 1" pads the outer stride
// by one extra row of 8 elements -- presumably to avoid LDS bank conflicts;
// confirm. The lengths tuple (source line 29) is not shown in this listing.
28 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
30 make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
31 number<8>{},
32 number<1>{});
33
// Folds the padded 3-d layout back: the (kKPerBlock / 8, 8) pair is merged into
// a single K dimension. The remaining transform and the dimension mappings
// (source lines 36, 38-39) are not shown; presumably M is passed through.
34 constexpr auto a_lds_block_desc = transform_tensor_descriptor(
35 a_lds_block_desc_0,
37 make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
40
41 return a_lds_block_desc;
42 }
43
// Builds the descriptor of B's LDS tile; mirrors MakeALdsBlockDescriptor with
// kN in place of kM. GetSmemSizeB sizes B's shared-memory buffer from this
// descriptor's element-space size.
// NOTE(review): the signature (source line 46) is not shown in this listing.
44 // 3d + padding
45 template <typename Problem>
47 {
48 constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
49 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
50
// Strides are ((kNPerBlock + 1) * 8, 8, 1), with a guaranteed last-dimension
// vector length of 8 and vector stride of 1. The "+ 1" pads the outer stride
// by one row of 8 elements -- presumably to avoid LDS bank conflicts; confirm.
// The lengths tuple (source line 52) is not shown in this listing.
51 constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
53 make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
54 number<8>{},
55 number<1>{});
56
// Merges the (kKPerBlock / 8, 8) pair back into a single K dimension. The
// remaining transform and the dimension mappings (source lines 59, 61-62) are
// not shown; presumably N is passed through.
57 constexpr auto b_lds_block_desc = transform_tensor_descriptor(
58 b_lds_block_desc_0,
60 make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
63
64 return b_lds_block_desc;
65 }
66
// Bytes of shared memory for A's LDS tile:
// sizeof(ADataType) * (element-space size of MakeALdsBlockDescriptor) / PackedSize.
// The multiplication is evaluated before the division (left-to-right).
67 template <typename Problem>
69 {
// NOTE(review): PackedSize's initializer (source line 71) and the signature
// (source line 68) are not shown here; presumably PackedSize is the number of
// logical values stored per ADataType -- confirm.
70 constexpr index_t PackedSize =
72 constexpr index_t smem_size_a =
73 sizeof(typename Problem::ADataType) *
74 MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
75 return smem_size_a;
76 }
77
// Bytes of shared memory for B's LDS tile:
// sizeof(BDataType) * (element-space size of MakeBLdsBlockDescriptor) / PackedSize.
// The multiplication is evaluated before the division (left-to-right).
78 template <typename Problem>
80 {
// NOTE(review): PackedSize's initializer (source line 82) and the signature
// (source line 79) are not shown here; presumably PackedSize is the number of
// logical values stored per BDataType -- confirm.
81 constexpr index_t PackedSize =
83 constexpr index_t smem_size_b =
84 sizeof(typename Problem::BDataType) *
85 MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
86 return smem_size_b;
87 }
88
// Total shared-memory bytes for the pipeline: GetSmemSizeA + GetSmemSizeB.
// The two tiles are budgeted as separate regions (a sum, not a max).
// NOTE(review): the signature (source line 90) is not shown in this listing.
89 template <typename Problem>
91 {
92 constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
93 constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
94 constexpr index_t smem_size = smem_size_a + smem_size_b;
95
96 return smem_size;
97 }
98
// K-pack size for A; returns Problem::VectorLoadSize unchanged. It is consumed as
// KPack in MakeADramTileDistribution and MakeShuffledARegBlockDistribution.
// NOTE(review): MakeShuffledARegBlockDistribution divides VectorLoadSize by
// sizeof(ADataType) to get an element count, while it is returned undivided
// here and then used as an element count for K -- confirm the intended unit.
99 template <typename Problem>
100 CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
101 {
102 return Problem::VectorLoadSize;
103 }
104
// K-pack size for B; returns Problem::VectorLoadSize unchanged. It is consumed as
// KPack in MakeBDramTileDistribution and MakeShuffledBRegBlockDistribution.
// NOTE(review): MakeShuffledBRegBlockDistribution divides VectorLoadSize by
// sizeof(BDataType) to get an element count, while it is returned undivided
// here and then used as an element count for K -- confirm the intended unit.
105 template <typename Problem>
106 CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
107 {
108 return Problem::VectorLoadSize;
109 }
110
// Builds the tile distribution for A from Problem's block shape and A's layout.
// The name suggests this describes how the block's threads load A from global
// memory -- confirm against the pipeline that consumes it.
// NOTE(review): the signature (source line 112), the ALayout/ADataType aliases
// (source lines 114-115) and the distribution encodings are not shown here.
111 template <typename Problem>
113 {
116
117 constexpr index_t BlockSize = Problem::kBlockSize;
118
119 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
120 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
121
122 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
123 {
// Column-major A: M is split as M0 x M1 (M1 = Problem::VectorSizeA). K is split
// as K0 x K1 x K2 x K3, where total_pixels is the element count per thread,
// K3 = total_pixels / M1 and K2 = KPack / K3.
124 constexpr index_t M1 = Problem::VectorSizeA;
125 constexpr index_t M0 = MPerBlock / M1;
126 constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
127 static_assert(total_pixels % M1 == 0);
128 constexpr index_t K3 = total_pixels / M1;
129 constexpr index_t KPack = GetSmemPackA<Problem>();
130 static_assert(KPack % K3 == 0);
131 constexpr index_t K2 = KPack / K3;
// `>=` does not guarantee that K2 * M0 divides the warp size. The static_assert
// in this branch rejects configurations where the truncated K1 makes
// K0 * K1 * K2 * K3 differ from KPerBlock.
132 if constexpr(get_warp_size() >= (K2 * M0))
133 {
// K1 fills the rest of a warp; K0 is the number of warps per block.
134 constexpr index_t K1 = get_warp_size() / (K2 * M0);
135 constexpr index_t K0 = BlockSize / get_warp_size();
136 static_assert(KPerBlock == K0 * K1 * K2 * K3);
143 sequence<3, 1>>{});
144 }
145 else
146 {
// K2 * M0 exceeds one warp: K1 warps share the K2 split (K2 = K1 * K2_m),
// leaving K0 = warps per block / K1.
147 constexpr index_t K1 = (K2 * M0) / get_warp_size();
148 constexpr index_t K2_m = K2 / K1;
149 constexpr index_t K0 = BlockSize / get_warp_size() / K1;
150 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
157 sequence<3, 1>>{});
158 }
159 }
160 else
161 {
// Non-column-major A: K is split as K0 x K1, with K1 = 16 / sizeof(ADataType)
// (elements per 16 bytes). M is split into three factors, with
// M2 = warp_size / K0.
// NOTE(review): K1 uses a hard-coded 16 here, whereas the B counterpart uses
// Problem::VectorLoadSize / sizeof(BDataType) -- confirm which is intended.
162 constexpr index_t K1 = 16 / sizeof(ADataType);
163 constexpr index_t K0 = KPerBlock / K1;
164 constexpr index_t M2 = get_warp_size() / K0;
// NOTE(review): the condition below evaluates `% (M2 * K0)` before the
// `static_assert(M2 != 0, ...)` inside the branch can fire. M2 == 0 therefore
// fails as a non-constant expression rather than with that assert's message;
// consider hoisting the M2 check above the `if constexpr`.
165 // coalesce reading for each block
166 if constexpr(get_warp_size() % (M2 * K0) == 0)
167 {
// M1 = warps per block; M0 covers the remaining MPerBlock / (M2 * M1).
168 constexpr index_t M1 = BlockSize / get_warp_size();
169 static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
170 static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
171 constexpr index_t M0 = MPerBlock / (M2 * M1);
172 static_assert(M0 * M1 * M2 == MPerBlock,
173 "Incorrect M0, M2, M1 configuration! "
174 "M0, M1, M2 must cover whole MPerBlock!");
181 sequence<0, 1>>{});
182 }
183 else
184 {
// coalesce reading for each warp: M0 = warps per block; M1 covers the remainder.
185 constexpr index_t M0 = BlockSize / get_warp_size();
186 constexpr index_t M1 = MPerBlock / (M2 * M0);
187 static_assert(M0 * M1 * M2 == MPerBlock,
188 "Incorrect M0, M1, M2 configuration! "
189 "M0, M1, M2 must cover whole MPerBlock!");
196 sequence<1, 1>>{});
197 }
198 }
199 }
200
// Builds the tile distribution for B from Problem's block shape and B's layout.
// This is the B-side counterpart of MakeADramTileDistribution: the row-major B
// case here mirrors the column-major A case there.
// NOTE(review): the signature (source line 202), the BLayout/BDataType aliases
// (source lines 204-205) and the distribution encodings are not shown here.
201 template <typename Problem>
203 {
206
207 constexpr index_t BlockSize = Problem::kBlockSize;
208
209 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
210 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
211
212 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
213 {
// Row-major B: N is split as N0 x N1 (N1 = Problem::VectorSizeB). K is split
// as K0 x K1 x K2 x K3, where total_pixels is the element count per thread,
// K3 = total_pixels / N1 and K2 = KPack / K3.
214 constexpr index_t N1 = Problem::VectorSizeB;
215 constexpr index_t N0 = NPerBlock / N1;
216 constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
217 static_assert(total_pixels % N1 == 0);
218 constexpr index_t K3 = total_pixels / N1;
219 constexpr index_t KPack = GetSmemPackB<Problem>();
220 static_assert(KPack % K3 == 0);
221 constexpr index_t K2 = KPack / K3;
// `>=` does not guarantee divisibility. The static_assert in this branch rejects
// configurations where the truncated K1 breaks K0 * K1 * K2 * K3 == KPerBlock.
222 if constexpr(get_warp_size() >= (K2 * N0))
223 {
224 constexpr index_t K1 = get_warp_size() / (K2 * N0);
225 constexpr index_t K0 = BlockSize / get_warp_size();
226 static_assert(KPerBlock == K0 * K1 * K2 * K3);
233 sequence<3, 1>>{});
234 }
235 else
236 {
// K2 * N0 exceeds one warp: K1 warps share the K2 split (K2 = K1 * K2_m).
237 constexpr index_t K1 = (K2 * N0) / get_warp_size();
238 constexpr index_t K2_m = K2 / K1;
239 constexpr index_t K0 = BlockSize / get_warp_size() / K1;
240 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
247 sequence<3, 1>>{});
248 }
249 }
250 else
251 {
252
// Non-row-major B: K is split as K0 x K1, with K1 = VectorLoadSize /
// sizeof(BDataType) elements, and N2 = warp_size / K0.
// NOTE(review): as in the A version, `% (N2 * K0)` is evaluated before the
// `static_assert(N2 != 0, ...)` inside the branch can fire.
253 constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
254 constexpr index_t K0 = KPerBlock / K1;
255 constexpr index_t N2 = get_warp_size() / K0;
256 // coalesce reading for each block
257 if constexpr(get_warp_size() % (N2 * K0) == 0)
258 {
// N1 = warps per block; N0 covers the remaining NPerBlock / (N2 * N1).
259 constexpr index_t N1 = BlockSize / get_warp_size();
260 static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
261 static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
262 constexpr index_t N0 = NPerBlock / (N2 * N1);
263 static_assert(N0 * N1 * N2 == NPerBlock,
264 "Incorrect N0, N1, N2 configuration! "
265 "N0, N1, N2 must cover whole NPerBlock!");
266
273 sequence<0, 1>>{});
274 }
275 // coalesce reading for each warp
276 else
277 {
// N0 = warps per block; N1 covers the remainder.
278 constexpr index_t N0 = BlockSize / get_warp_size();
279 constexpr index_t N1 = NPerBlock / (N2 * N0);
280 static_assert(N0 * N1 * N2 == NPerBlock,
281 "Incorrect N0, N1, N2 configuration! "
282 "N0, N1, N2 must cover whole NPerBlock!");
289 sequence<1, 1>>{});
290 }
291 }
292 }
293
// Builds the "shuffled" register distribution for B; only row-major B is
// accepted (static_assert below). The factorisation parallels the row-major
// branch of MakeBDramTileDistribution.
// NOTE(review): the signature (source line 295), the BLayout/BDataType aliases
// (source lines 297-298) and the distribution encodings are not shown here.
294 template <typename Problem>
296 {
299 static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
300 constexpr index_t kBlockSize = Problem::kBlockSize;
301 constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
302 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
303
// NOTE(review): N1 is VectorLoadSize / sizeof(BDataType) here, but
// MakeBDramTileDistribution's row-major branch uses N1 = Problem::VectorSizeB.
// If the two differ, the factorisations diverge -- confirm they are meant to agree.
304 constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
305 constexpr index_t N0 = kNPerBlock / N1;
306 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
307 static_assert(total_pixels % N1 == 0);
308 constexpr index_t K3 = total_pixels / N1;
309 constexpr index_t kKPack = GetSmemPackB<Problem>();
310 static_assert(kKPack % K3 == 0);
311 constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
312 constexpr index_t warp_size = get_warp_size();
// NOTE(review): this tests divisibility, unlike the `>=` test in
// MakeBDramTileDistribution. If K2 * N0 is smaller than warp_size but does not
// divide it, the else branch computes K1 == 0, and `K2 / K1` divides by zero
// at compile time.
313 if constexpr(warp_size % (K2 * N0) == 0)
314 {
// NOTE(review): no KPerBlock static_assert is visible in this branch, unlike
// the else branch below.
315 constexpr index_t K1 = warp_size / (K2 * N0);
316 constexpr index_t K0 = kBlockSize / warp_size;
317
324 sequence<1, 3>>{});
325 }
326 else
327 {
328 constexpr index_t K1 = (K2 * N0) / get_warp_size();
329 constexpr index_t K2_m = K2 / K1;
330 constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
331 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
338 sequence<1, 3>>{});
339 }
340 }
341
// Builds the "shuffled" register distribution for A; only column-major A is
// accepted (static_assert below). The factorisation parallels the column-major
// branch of MakeADramTileDistribution.
// NOTE(review): the signature (source line 343), the ALayout/ADataType aliases
// (source lines 345-346) and the distribution encodings are not shown here.
342 template <typename Problem>
344 {
347 static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
348 constexpr index_t kBlockSize = Problem::kBlockSize;
349 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
350 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
351
// NOTE(review): M1 is VectorLoadSize / sizeof(ADataType) here, but
// MakeADramTileDistribution's column-major branch uses M1 = Problem::VectorSizeA
// -- confirm they are meant to agree.
352 constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
353 constexpr index_t M0 = kMPerBlock / M1;
354 constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
355 static_assert(total_pixels % M1 == 0);
356 constexpr index_t K3 = total_pixels / M1;
357 constexpr index_t kKPack = GetSmemPackA<Problem>();
358 static_assert(kKPack % K3 == 0);
359 constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
360 constexpr index_t warp_size = get_warp_size();
// NOTE(review): this tests divisibility, unlike the `>=` test in
// MakeADramTileDistribution. If K2 * M0 is smaller than warp_size but does not
// divide it, the else branch computes K1 == 0, and `K2 / K1` divides by zero
// at compile time.
361 if constexpr(warp_size % (K2 * M0) == 0)
362 {
// NOTE(review): no KPerBlock static_assert is visible in this branch, unlike
// the else branch below.
363 constexpr index_t K1 = warp_size / (K2 * M0);
364 constexpr index_t K0 = kBlockSize / warp_size;
365
372 sequence<1, 3>>{});
373 }
374 else
375 {
376 constexpr index_t K1 = (K2 * M0) / get_warp_size();
377 constexpr index_t K2_m = K2 / K1;
378 constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
379 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
386 sequence<1, 3>>{});
387 }
388 }
389
// Assembles the block-level GEMM types for this pipeline. The warp GEMM is
// selected by WarpGemmDispatcher from:
// - Problem::ComputeDataType (for both A and B);
// - a float accumulator;
// - the WarpTile M/N/K extents;
// - Problem::TransposeC.
// NOTE(review): the return statement (source line 410) is not shown in this
// listing; presumably it returns a block GEMM built from BlockGemmPolicy -- confirm.
390 template <typename Problem>
391 CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
392 {
// Accumulation type is fixed to float regardless of the Problem's data types.
393 using AccDataType = float;
394 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
395 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
396 using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
397 typename Problem::ComputeDataType,
398 AccDataType,
399 WarpTile::at(I0),
400 WarpTile::at(I1),
401 WarpTile::at(I2),
402 Problem::TransposeC>;
403
// The block policy keeps the Problem's storage types (A/B/C) while the warp
// GEMM above operates on ComputeDataType.
404 using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
405 typename Problem::BDataType,
406 typename Problem::CDataType,
407 BlockWarps,
408 WarpGemm>;
409
411 }
412};
413
414} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_universal_gemm_as_bs_cr.hpp:21
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:14
static constexpr auto I2
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:17
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:100
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBRegBlockDistribution()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:295
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackB()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:106
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeA()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:68
static constexpr auto I1
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:16
static CK_TILE_HOST_DEVICE constexpr auto MakeBDramTileDistribution()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:202
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:90
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:391
static CK_TILE_HOST_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:21
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledARegBlockDistribution()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:343
static CK_TILE_HOST_DEVICE constexpr auto MakeBLdsBlockDescriptor()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:46
static CK_TILE_HOST_DEVICE constexpr auto MakeADramTileDistribution()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:112
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeB()
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:79
static constexpr auto I0
Definition gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp:15
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192