block_fmha_fwd_appendkv_pipeline_default_policy.hpp Source File

block_fmha_fwd_appendkv_pipeline_default_policy.hpp Source File

Composable Kernel: block_fmha_fwd_appendkv_pipeline_default_policy.hpp Source File
block_fmha_fwd_appendkv_pipeline_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7namespace ck_tile {
8
9// This pipeline is qkv all located in LDS
11{
12 template <typename Problem>
13 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
14 {
16
17 return 16 / sizeof(QDataType);
18 }
19
20 template <typename Problem>
21 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
22 {
24
25 return 16 / sizeof(KDataType);
26 }
27
// GetAlignmentV: compile-time alignment value chosen for V.
//  - VLayout is gemm::RowMajor: 4 when (kN0 * kN1 / kBlockSize) > 4, otherwise 2.
//    The original author marks this heuristic as not correct (see the TODO below).
//  - any other layout: 16 / sizeof(VDataType).
28 template <typename Problem>
29 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
30 {
// NOTE(review): listing lines 31-32 are absent from this extract. They presumably declare the
// VLayout and VDataType aliases used below -- confirm against the original header.
33 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
34 {
35 constexpr index_t kBlockSize = Problem::kBlockSize;
36 constexpr index_t kNPerBlock = Problem::kN0;
37 constexpr index_t kKPerBlock = Problem::kN1;
// Tile element count divided by kBlockSize (integer division).
38 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
39
40 // TODO: not correct!
41 if constexpr(total_pixels > 4)
42 return 4;
43 else
44 return 2;
45 }
46 else
47 {
48 return 16 / sizeof(VDataType);
49 }
50 }
51
// GetQNumElemsPerRead: yields 16 / sizeof(Problem::QDataType).
// Both visible branches (HALF_ROTATED and everything else) return the same expression.
// NOTE(review): listing line 53 (the signature) is absent from this extract; the cross-reference
// index at the end of this page gives it as
// `static CK_TILE_HOST_DEVICE constexpr auto GetQNumElemsPerRead()`.
52 template <typename Problem>
54 {
55 using DataType = typename Problem::QDataType;
56
57 if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
58 {
// NOTE(review): listing line 59 is absent from this extract; its content cannot be determined here.
60 return 16 / sizeof(DataType);
61 }
62 else
63 {
64 return 16 / sizeof(DataType);
65 }
66 }
67
68 template <typename Problem>
70 {
71 static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
72
73 if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
74 {
75 constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
76 static_assert(Problem::kK0 % KPerThread == 0);
77 constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
78 index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
79
80 return make_tuple(start_pos, start_pos + KPerThread);
81 }
82 else
83 {
84 constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
85 static_assert(Problem::kK0 % KPerThread == 0);
86 constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
87 index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
88
89 return make_tuple(start_pos, start_pos + KPerThread);
90 }
91 }
92
// MakeQDramTileDistribution: derives the per-thread / per-warp split of a kM0 x kK0 tile and
// builds a distribution object from it.
// NOTE(review): listing line 94 (the signature) is absent from this extract; the cross-reference
// index at the end of this page gives it as
// `static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()`.
93 template <typename Problem>
95 {
96 constexpr index_t kBlockSize = Problem::kBlockSize;
97 constexpr index_t kMPerBlock = Problem::kM0;
98 constexpr index_t kKPerBlock = Problem::kK0;
99
// Along K: elements per thread, then how many threads are needed to span kK0.
100 constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
101 constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
// Along M: what is left of a warp after the K split, the warp count, and rows per thread.
102 constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock;
103 constexpr index_t NumWarps = kBlockSize / get_warp_size();
104 constexpr index_t MPerThread = kMPerBlock / (NumWarps * MThreadPerWarp);
105
// NOTE(review): listing lines 106-112 are absent from this extract. Only the tail of the return
// expression survives below; presumably it is a make_static_tile_distribution(...) call built
// from the constants above -- confirm against the original header.
113 sequence<0, 1>>{});
114 }
115
// GetKnewNumElemsPerRead: yields 16 / sizeof(Problem::KDataType).
// Both visible branches (HALF_ROTATED and everything else) return the same expression.
// NOTE(review): listing line 117 (the signature) is absent from this extract; the cross-reference
// index at the end of this page gives it as
// `static CK_TILE_HOST_DEVICE constexpr auto GetKnewNumElemsPerRead()`.
116 template <typename Problem>
118 {
119 using DataType = typename Problem::KDataType;
120
121 if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
122 {
// NOTE(review): listing line 123 is absent from this extract; its content cannot be determined here.
124 return 16 / sizeof(DataType);
125 }
126 else
127 {
128 return 16 / sizeof(DataType);
129 }
130 }
131
132 template <typename Problem>
134 {
135 static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
136
137 if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
138 {
139 constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
140 constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
141 index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
142
143 return make_tuple(start_pos, start_pos + KPerThread);
144 }
145 else
146 {
147 constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
148 constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
149 index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
150
151 return make_tuple(start_pos, start_pos + KPerThread);
152 }
153 }
154
// MakeKnewDramTileDistribution: derives the per-thread / per-warp split of a kN0 x kK0 tile and
// builds a distribution object from it.
// NOTE(review): listing line 156 (the signature) is absent from this extract; the cross-reference
// index at the end of this page gives it as
// `static CK_TILE_HOST_DEVICE constexpr auto MakeKnewDramTileDistribution()`.
155 template <typename Problem>
157 {
158 constexpr index_t kBlockSize = Problem::kBlockSize;
159 constexpr index_t kNPerBlock = Problem::kN0;
160 constexpr index_t kKPerBlock = Problem::kK0;
161
// Along K: elements per thread, then how many threads are needed to span kK0.
162 constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
163 constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
// Along N: what is left of a warp after the K split, the warp count, and rows per thread.
164 constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
165 constexpr index_t NumWarps = kBlockSize / get_warp_size();
166 constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
167
// NOTE(review): listing lines 168-174 are absent from this extract. Only the tail of the return
// expression survives below; presumably it is a make_static_tile_distribution(...) call built
// from the constants above -- confirm against the original header.
175 sequence<0, 1>>{});
176 }
177
// GetSmemKPackV: yields 16 / sizeof(VDataType).
178 template <typename Problem>
179 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
180 {
181 // TODO: this is for 3d layout
// NOTE(review): listing line 182 is absent from this extract; it presumably declares the
// VDataType alias used below -- confirm against the original header.
183 return 16 / sizeof(VDataType);
184 }
185
// MakeVnewDramTileDistribution: derives the per-thread / per-warp split of a tile whose N extent
// is Problem::kN1 and whose K extent is Problem::kN0, with a different split per VLayout.
// NOTE(review): listing line 187 (the signature) is absent from this extract; the cross-reference
// index at the end of this page gives it as
// `static CK_TILE_HOST_DEVICE constexpr auto MakeVnewDramTileDistribution()`.
186 template <typename Problem>
188 {
// NOTE(review): listing lines 189-190 are absent from this extract; they presumably declare the
// VLayout and VDataType aliases used below -- confirm against the original header.
191
192 constexpr index_t kBlockSize = Problem::kBlockSize;
193 constexpr index_t kNPerBlock = Problem::kN1;
194 constexpr index_t kKPerBlock = Problem::kN0;
195
196 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
197 {
198
// Row-major: the fixed 16 / sizeof(VDataType) chunk is taken along N and the remaining
// threads/warps are spread along K.
199 constexpr index_t NPerThread = 16 / sizeof(VDataType);
200 constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread;
201 constexpr index_t KThreadPerWarp = get_warp_size() / NThreadPerBlock;
202 constexpr index_t NumWarps = kBlockSize / get_warp_size();
203 constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp);
204
// NOTE(review): listing lines 205-211 are absent from this extract; only the tail of the return
// expression survives below.
212 sequence<1, 0>>{});
213 }
214 else
215 {
// Other layouts: the fixed 16 / sizeof(VDataType) chunk is taken along K and the remaining
// threads/warps are spread along N.
216 constexpr index_t KPerThread = 16 / sizeof(VDataType);
217 constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
218 constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
219 constexpr index_t NumWarps = kBlockSize / get_warp_size();
220 constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
221
// NOTE(review): listing lines 222-228 are absent from this extract; only the tail of the return
// expression survives below.
229 sequence<0, 1>>{});
230 }
231 }
232
// GetRotaryCosSinTileSize: compile-time tile size as a tuple of number<> constants.
// The first entry is Problem::kM0 when IsRotaryCosSinForQ is true and Problem::kN0 otherwise.
// NOTE(review): listing line 234 (the signature) is absent from this extract; the cross-reference
// index at the end of this page gives it as
// `static CK_TILE_HOST_DEVICE constexpr auto GetRotaryCosSinTileSize()`.
233 template <typename Problem, bool IsRotaryCosSinForQ>
235 {
236 constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
237
238 if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
239 {
// NOTE(review): listing line 240 -- the whole body of the HALF_ROTATED branch -- is absent from
// this extract, so the value returned in this case cannot be read from here.
241 }
242 else
243 {
// Every other rotary mode: the second entry is half of kK0.
244 return make_tuple(number<height>{}, number<Problem::kK0 / 2>{});
245 }
246 }
247
// MakeRotaryCosSinTileDistribution: derives the per-thread / per-warp split of the tile described
// by TileSize and builds a distribution object from it. The element type is Problem::QDataType
// when IsRotaryCosSinForQ is true and Problem::KDataType otherwise.
// NOTE(review): listing line 249 (the signature) is absent from this extract; the cross-reference
// index at the end of this page gives it as
// `static CK_TILE_HOST_DEVICE constexpr auto MakeRotaryCosSinTileDistribution()`.
248 template <typename Problem, bool IsRotaryCosSinForQ>
250 {
251 using DataType = std::conditional_t<IsRotaryCosSinForQ,
252 typename Problem::QDataType,
253 typename Problem::KDataType>;
254
// NOTE(review): listing line 255 is absent from this extract; it presumably declares TileSize
// (likely from GetRotaryCosSinTileSize<Problem, IsRotaryCosSinForQ>()) -- confirm against the
// original header.
256
257 constexpr index_t kBlockSize = Problem::kBlockSize;
258 constexpr index_t kNPerBlock = TileSize[number<0>{}];
259 constexpr index_t kKPerBlock = TileSize[number<1>{}];
260
// Elements per thread along K: 16 / sizeof(DataType) for HALF_ROTATED, 8 / sizeof(DataType)
// otherwise. Computed by an immediately invoked lambda so `if constexpr` can select the value.
261 constexpr index_t KPerThread = []() {
262 if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
263 {
// NOTE(review): listing line 264 is absent from this extract; its content cannot be determined here.
265 return 16 / sizeof(DataType);
266 }
267 else
268 {
269 return 8 / sizeof(DataType);
270 }
271 }();
272 constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
273 constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
274 constexpr index_t NumWarps = kBlockSize / get_warp_size();
275 constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
276
// NOTE(review): listing lines 277-283 are absent from this extract; only the tail of the return
// expression survives below.
284 sequence<0, 1>>{});
285 }
286};
287
288} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
@ INTERLEAVED
Definition block_rotary_embedding.hpp:14
@ HALF_ROTATED
Definition block_rotary_embedding.hpp:15
@ NONE
Definition block_rotary_embedding.hpp:13
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE index_t get_thread_id()
Definition arch.hpp:117
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:11
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:21
static CK_TILE_HOST_DEVICE constexpr auto MakeKnewDramTileDistribution()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:156
static CK_TILE_DEVICE auto GetKnewThreadRangeAlongK()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:133
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:179
static CK_TILE_HOST_DEVICE constexpr auto GetQNumElemsPerRead()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:53
static CK_TILE_HOST_DEVICE constexpr auto MakeRotaryCosSinTileDistribution()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:249
static CK_TILE_HOST_DEVICE constexpr auto GetRotaryCosSinTileSize()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:234
static CK_TILE_DEVICE auto GetQThreadRangeAlongK()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:69
static CK_TILE_HOST_DEVICE constexpr auto MakeVnewDramTileDistribution()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:187
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto GetKnewNumElemsPerRead()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:117
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:29
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:94
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192