gemm_group_quant_utils.hpp Source File

gemm_group_quant_utils.hpp Source File#

Composable Kernel: gemm_group_quant_utils.hpp Source File
gemm_group_quant_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck_tile {
9
10template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
11CK_TILE_HOST_DEVICE static constexpr auto GetABQGlobalVectorLoadSize()
12{
13 using I1 = number<1>;
14 constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
15
16 constexpr index_t BlockSize = Problem::kBlockSize;
17
18 // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps
19 constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps);
20 constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
21
22 // Define vector load candidates in descending order of priority
23 constexpr std::array<index_t, 5> candidates{
24 PackedSize * 32 / sizeof(DataType),
25 PackedSize * 16 / sizeof(DataType),
26 PackedSize * 8 / sizeof(DataType),
27 PackedSize * 4 / sizeof(DataType),
28 PackedSize * 2 / sizeof(DataType),
29 };
30
31 for(const auto vec_size : candidates)
32 {
33 if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0)
34 continue;
35 bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) &&
36 (elements_per_thread % vec_size == 0) && vec_size != candidates[4];
37 if(is_valid)
38 {
39 return vec_size;
40 }
41 }
42 return PackedSize; // Absolute fallback
43}
44
45// AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across
46// threads. Post mfma scales are shuffled across threads in the warp and applied to
47// accum registers.
// Template parameters as used by the lines visible below:
//   BlockGemmShape  - supplies BlockWarps (warp counts) and kM.
//   WarpGemm        - supplies kM.
//   BlockSize       - divided by the warp size to obtain the warp count.
//   YPerTile        - tile extent along Y, checked against the Y factors.
//   XPerTile        - tile extent along X, must be a multiple of VecSize.
//   PreshuffleQuant - selects between the two layout branches.
// KPerBlockAQ is not referenced in the lines visible here.
48template <typename BlockGemmShape,
49 typename WarpGemm,
50 index_t BlockSize,
51 index_t YPerTile,
52 index_t XPerTile,
53 index_t KPerBlockAQ,
54 index_t VecSize,
55 bool PreshuffleQuant>
57{
58 static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
59 static constexpr index_t warp_size = get_warp_size();
60 static constexpr index_t num_warps = BlockSize / get_warp_size();
61
// Warp counts read from BlockWarps at indices 0, 1 and 2 (named M, N and K here).
62 static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
63 static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
64 static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
65
// How many times MWarps * WarpGemm::kM fits into BlockGemmShape::kM (integer division).
66 static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
67
// The warp count derived from BlockSize must equal the product of the per-axis warp counts.
68 static_assert(num_warps == MWarps * NWarps * KWarps);
69
70 // KWarps > 1 isn't supported
71 static_assert(KWarps == 1);
72
74 {
75 if constexpr(PreshuffleQuant)
76 {
77 // # of elements per thread
// X is factored as X0 * X1 with X1 = warp_size, which requires XPerTile to be a
// positive multiple of warp_size.
78 static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
79 constexpr index_t X1 = warp_size;
80 constexpr index_t X0 = XPerTile / warp_size;
81
// Y is factored as Y0 * Y1 with Y1 = MWarps.
82 constexpr index_t Y1 = MWarps;
83 constexpr index_t Y0 = YPerTile / Y1;
91 }
92 else
93 {
94 // # of elements per thread
95 constexpr index_t X = XPerTile;
96
// Y is factored as Y0 * Y1 * Y2 * Y3. Y1 falls back to 1 when MIterPerWarp is 0,
// which happens when the integer division above rounds down to zero.
97 constexpr index_t Y0 = 1;
98 constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
99 constexpr index_t Y2 = MWarps;
100 constexpr index_t Y3 = WarpGemm::kM;
// NOTE(review): Y3 is defined as WarpGemm::kM on the line above, so this
// assertion holds trivially as written.
101 static_assert(Y3 >= WarpGemm::kM,
102 "Scales for all rows must be available within the warp.");
103 static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
104 "Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
111 sequence<1, 0>>{});
112 }
113 }
114};
115
// Template parameters as used by the lines visible below:
//   BlockGemmShape - supplies BlockWarps (warp counts) and kM.
//   WarpGemm       - supplies kM.
//   BlockSize      - divided by the warp size to obtain the warp count.
//   YPerTile       - tile extent along Y; must equal Y0 * Y1 * Y2.
//   XPerTile       - tile extent along X; must be a multiple of VecSize.
116template <typename BlockGemmShape,
117 typename WarpGemm,
118 index_t BlockSize,
119 index_t YPerTile,
120 index_t XPerTile,
121 index_t VecSize>
124{
125 // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
126 static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
127 static constexpr index_t warp_size = get_warp_size();
128 static constexpr index_t num_warps = BlockSize / get_warp_size();
129
// Warp counts read from BlockWarps at indices 0, 1 and 2 (named M, N and K here).
130 static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
131 static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
132 static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
133
// How many times MWarps * WarpGemm::kM fits into BlockGemmShape::kM (integer division).
134 static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
135
// The warp count derived from BlockSize must equal the product of the per-axis warp counts.
136 static_assert(num_warps == MWarps * NWarps * KWarps);
137
138 // KWarps > 1 isn't supported
139 static_assert(KWarps == 1);
140
141 // # of elements per thread
142 static constexpr index_t X = XPerTile;
// NOTE(review): XR is hard-coded to 2 and its use is not visible in this listing;
// presumably a replication factor along X - confirm against the distribution encoding.
143 static constexpr index_t XR = 2;
144
145 // Number of iters per warp
146 // MIters are indexed using (Y0, Y1)
147 static constexpr index_t Y0 = MIterPerWarp;
148
149 // # of warps in Y dim
150 static constexpr index_t Y1 = MWarps;
151
// Y2 is the M extent of one WarpGemm.
152 static constexpr index_t Y2 = WarpGemm::kM;
153
154 static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
155
166};
167
168// TODO: might need to update
// Template parameters as used by the lines visible below:
//   BlockGemmShape - supplies BlockWarps (warp counts) and kN.
//   WarpGemm       - supplies kN.
//   BlockSize      - divided by the warp size to obtain the warp count.
//   YPerTile       - tile extent along Y.
//   XPerTile       - tile extent along X.
//   XPerQ          - compared against WarpGemm::kN and WarpGemm::kN * NWarps to pick
//                    one of three layouts; the existing comments describe it as the
//                    quantization group size along X.
169template <typename BlockGemmShape,
170 typename WarpGemm,
171 index_t BlockSize,
172 index_t YPerTile,
173 index_t XPerTile,
174 index_t XPerQ>
176{
177 static constexpr index_t warp_size = get_warp_size();
178 static constexpr index_t num_warps = BlockSize / get_warp_size();
179
// Warp counts read from BlockWarps at indices 0, 1 and 2 (named M, N and K here).
180 static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
181 static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
182 static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
183
// How many times NWarps * WarpGemm::kN fits into BlockGemmShape::kN (integer division).
184 static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
185
// The warp count derived from BlockSize must equal the product of the per-axis warp
// counts, and only KWarps == 1 is accepted.
186 static_assert(num_warps == MWarps * NWarps * KWarps);
187 static_assert(KWarps == 1);
188
215 {
// Three compile-time branches, chosen by comparing XPerQ with WarpGemm::kN and
// with WarpGemm::kN * NWarps.
216 if constexpr(XPerQ < WarpGemm::kN)
217 {
218 // Case 1: Fine-grained - multiple quantization scales within a single warp
219 constexpr index_t Y = YPerTile; // Full Y dimension of tile
220 constexpr index_t YR = 1; // No Y replication needed
221 constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
222 constexpr index_t X1 = NWarps; // Number of warps in N-dim
223 constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
224 constexpr index_t XR = XPerQ; // Elements per quantization group
225
226 static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
227
234 sequence<0, 0>>{});
235 }
236 else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
237 {
238 // Case 2: Medium-grained - one quantization scale per warp
// All three factors below use integer division; no static_assert in the visible
// lines checks that they divide evenly.
239 constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
240 constexpr auto X1 = NWarps / XR; // Warps per unique scale
241 constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
248 sequence<0, 0>>{});
249 }
250 else // XPerQ > WarpGemm::kN * NWarps
251 {
252 // Case 3: Coarse-grained - quantization group spans all warps
253 // All warps in N-dimension share the same quantization scale
260 sequence<0, 0>>{});
261 }
262 }
263};
264
// Exposes entries 0, 1 and 2 of GroupSizes as the compile-time constants kM, kN and kK,
// plus a host-side name string built from them.
265template <typename GroupSizes>
267{
268 static constexpr index_t kM = GroupSizes::at(number<0>{});
269 static constexpr index_t kN = GroupSizes::at(number<1>{});
270 static constexpr index_t kK = GroupSizes::at(number<2>{});
271
// Builds the name from the literal "quant_group_shape" and the three sizes via two
// nested concat calls. Presumably the first argument of each call is the separator
// ('_' outside, 'x' inside) - confirm against concat's overloads.
// NOTE(review): the return type is `const std::string` by value, which prevents
// callers from moving from the returned temporary.
272 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
273 {
274 return concat('_', "quant_group_shape", concat('x', kM, kN, kK));
275 }
276};
277
278} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
Definition gemm_group_quant_utils.hpp:267
static constexpr index_t kM
Definition gemm_group_quant_utils.hpp:268
static constexpr index_t kK
Definition gemm_group_quant_utils.hpp:270
static constexpr index_t kN
Definition gemm_group_quant_utils.hpp:269
static CK_TILE_HOST const std::string GetName()
Definition gemm_group_quant_utils.hpp:272
Definition tile/core/container/sequence.hpp:49
Definition gemm_group_quant_utils.hpp:124
static constexpr index_t NWarps
Definition gemm_group_quant_utils.hpp:131
static constexpr index_t MWarps
Definition gemm_group_quant_utils.hpp:130
static constexpr index_t MIterPerWarp
Definition gemm_group_quant_utils.hpp:134
static constexpr index_t X
Definition gemm_group_quant_utils.hpp:142
static CK_TILE_HOST_DEVICE constexpr auto make_2d_static_tile_distribution()
Definition gemm_group_quant_utils.hpp:156
static constexpr index_t KWarps
Definition gemm_group_quant_utils.hpp:132
static constexpr index_t Y0
Definition gemm_group_quant_utils.hpp:147
static constexpr index_t num_warps
Definition gemm_group_quant_utils.hpp:128
static constexpr index_t Y2
Definition gemm_group_quant_utils.hpp:152
static constexpr index_t Y1
Definition gemm_group_quant_utils.hpp:150
static constexpr index_t XR
Definition gemm_group_quant_utils.hpp:143
static constexpr index_t warp_size
Definition gemm_group_quant_utils.hpp:127
Definition gemm_group_quant_utils.hpp:57
static constexpr index_t KWarps
Definition gemm_group_quant_utils.hpp:64
static constexpr index_t MWarps
Definition gemm_group_quant_utils.hpp:62
static CK_TILE_HOST_DEVICE constexpr auto make_2d_static_tile_distribution()
Definition gemm_group_quant_utils.hpp:73
static constexpr index_t warp_size
Definition gemm_group_quant_utils.hpp:59
static constexpr index_t NWarps
Definition gemm_group_quant_utils.hpp:63
static constexpr index_t num_warps
Definition gemm_group_quant_utils.hpp:60
static constexpr index_t MIterPerWarp
Definition gemm_group_quant_utils.hpp:66
Definition gemm_group_quant_utils.hpp:176
static constexpr index_t num_warps
Definition gemm_group_quant_utils.hpp:178
static constexpr index_t NWarps
Definition gemm_group_quant_utils.hpp:181
static constexpr index_t warp_size
Definition gemm_group_quant_utils.hpp:177
static constexpr index_t MWarps
Definition gemm_group_quant_utils.hpp:180
static CK_TILE_HOST_DEVICE constexpr auto make_2d_static_tile_distribution()
Creates a 2D tile distribution for BQ (B-matrix quantization scales).
Definition gemm_group_quant_utils.hpp:214
static constexpr index_t KWarps
Definition gemm_group_quant_utils.hpp:182
static constexpr index_t NIterPerWarp
Definition gemm_group_quant_utils.hpp:184
Definition static_encoding_pattern.hpp:108
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192