mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File

mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File

Composable Kernel: mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File
mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck_tile {
9
// Build-time request for the BUFFER_LOAD_LDS code paths in this policy
// (e.g. -DCKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE=1). Defaults to 0.
// Guarded with #ifndef so that a definition supplied by the build (or an
// earlier header) is respected. An unconditional #define would redefine it
// and trigger a macro-redefinition diagnostic whenever the values differ.
#ifndef CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE
#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 0
#endif

// 1 only when the compiler predefines __gfx950__ (i.e. targeting gfx950),
// otherwise 0.
#if defined(__gfx950__)
#define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 1
#else
#define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 0
#endif

// Effective switch tested by the `#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS`
// branches of this policy. It is enabled only when the path was requested
// AND the target architecture supports it.
#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS \
    (CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && \
     CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4)
21
23{
24 static constexpr auto I0 = number<0>{};
25 static constexpr auto I1 = number<1>{};
26 static constexpr auto I2 = number<2>{};
27
28 static constexpr index_t KBPerLoad = 32;
29 static constexpr index_t N_Pack = 2; // it's fixed for fp4
30 static constexpr index_t K_Pack = 2; // it's fixed for fp4
31
32 template <typename Problem, typename NativeADramTensorView>
33 CK_TILE_HOST_DEVICE static constexpr auto
34 TransformF16xF4_ATensorView(const NativeADramTensorView& a_dram_view)
35 {
36#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
37 constexpr int DynamicTileOffsetFlag = 0;
38
39 constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
40 constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
41
42 static_assert(MPerXdl == 16 && NPerXdl == 16);
43
44 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
45 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
46 constexpr index_t KPack = GetSmemPackA<Problem>();
47
48 constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
49
50 // implement swizzle pattern on global side
51 // because we can't adjust the ds_write pattern of BUFFER_LOAD_LDS.
52 auto swizzle_a_dram_view_1 = transform_tensor_view(
53 a_dram_view,
55 // M-dim is not affected by swizzle pattern
58 // K-dim is the swizzle dimension
60 number<KPerBlock / KPack>{},
61 number<KPack>{}))),
64
65 auto swizzle_a_dram_view_2 = transform_tensor_view(
66 swizzle_a_dram_view_1,
74
76 swizzle_a_dram_view_2,
81 number<KPerBlock / KPack>{},
82 number<KPack>{}))),
85#else
86 return a_dram_view;
87#endif
88 }
89
90 template <typename Problem>
92 {
93 constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
94 constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
95
96 static_assert(MPerXdl == 16 && NPerXdl == 16);
97
98 /*reduce transform layers,compare with old ck*/
99 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
100 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
101 constexpr index_t KPack = GetSmemPackA<Problem>();
102
103 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
107 number<1>{});
108
109 constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
110
111 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
112 a_lds_block_desc_0,
118
119 constexpr auto a_lds_block_desc = transform_tensor_descriptor(
120 a_lds_block_desc_permuted,
126
127 return a_lds_block_desc;
128 }
129
130 template <typename Problem>
132 {
133#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
134 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
135 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
136 constexpr index_t KPack = GetSmemPackA<Problem>();
140 number<1>{});
141#else
143#endif
144 }
145
146 template <typename Problem>
148 {
149 using TileShape = typename Problem::BlockGemmShape;
150
151 static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
152 static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
153
154 constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
155 constexpr int M0 = TileShape::WarpTile::at(I0);
156
157 constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
158
159 constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
160 constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
161 constexpr int K0 = K_Lane; // 4
162
169 sequence<2>>{});
170 }
171
172 template <typename Problem>
174 {
175 using TileShape = typename Problem::BlockGemmShape;
176
177 static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
178
179 constexpr index_t BlockSize = Problem::kBlockSize;
180 constexpr index_t WaveSize = get_warp_size();
181 constexpr index_t WaveNum = BlockSize / WaveSize;
182
183 constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
184 constexpr index_t KWavePerBlk = 1;
185
186 constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
187
188 constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
189
194 // direction
196 // wave in blk, // thd in wave
197 // <M, K> // <M, K>
198 tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
199 tuple<sequence<0, 0, 0>, sequence<1>>, // which index
200 // <repeat, vec_load>
202 sequence<2>>{});
203 }
204
205 template <typename Problem>
207 {
208 using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
209
210 constexpr index_t BlockSize = Problem::kBlockSize;
211 constexpr index_t WaveSize = get_warp_size();
212 [[maybe_unused]] constexpr index_t WaveNum = BlockSize / WaveSize;
213
214 constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
215
216 [[maybe_unused]] constexpr index_t XDLPerBlock =
217 TileShape::kK / TileShape::WarpTile::at(I2);
218 constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
219 constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
220
221 constexpr index_t NWavePerBlk = N_Warp;
222
225 sequence<>, // ?
226 tuple<sequence<NWavePerBlk>, // second direction
228 // direction
229 // wave in blk, // thd in wave
230 // <M, K> // <M, K>
231 tuple<sequence<1>, sequence<2, 2>>, // which direction
232 tuple<sequence<0>, sequence<0, 1>>, // which index
233 // <repeat, vec_load>
235 sequence<2>>{});
236 }
237};
238
239} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:23
static CK_TILE_HOST_DEVICE constexpr auto TransformF16xF4_ATensorView(const NativeADramTensorView &a_dram_view)
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:34
static CK_TILE_HOST_DEVICE constexpr auto MakeF16xF4_WriteALdsBlockDescriptor()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:131
static CK_TILE_HOST_DEVICE constexpr auto MakeF16xF4_ReadALdsBlockDescriptor()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:91
static CK_TILE_HOST_DEVICE constexpr auto MakeFp4BFlatDramTileDistribution()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:173
static CK_TILE_HOST_DEVICE constexpr auto MakeF16xF4_ALDS_TileDistribution()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:147
static constexpr index_t N_Pack
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:29
static constexpr auto I0
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:24
static constexpr index_t KBPerLoad
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:28
static constexpr index_t K_Pack
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:30
static constexpr auto I1
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:25
static CK_TILE_HOST_DEVICE constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:206
static constexpr auto I2
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:26
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:233
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192