block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp Source File

block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp Source File

Composable Kernel: block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp Source File
block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
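12// In this pipeline Q is loaded once and kept in registers (the QK block GEMM reads A = Q from registers and B = K from LDS); K, and per the qr_ks_vs naming also V, is staged through LDS.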
14 : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
15 /* AsyncCopy = */ false,
16 /* NumPrefetchK = */ 1,
17 /* NumPrefetchV = */ 1>
18{
19 template <typename Problem>
20 CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
21 {
22 using GemmProblem =
23 BlockGemmProblem<typename Problem::QDataType,
24 typename Problem::KDataType,
25 typename Problem::SaccDataType,
26 Problem::kNumGemm0Warps * get_warp_size(),
27 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
28 Problem::BlockFmhaShape::kN0,
29 Problem::BlockFmhaShape::kK0>,
30 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
31 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
32
33 constexpr auto warp_gemm = []() {
34 if constexpr(get_warp_size() == 64 &&
35 std::is_same_v<typename Problem::QDataType, fp8_t> &&
36 std::is_same_v<typename Problem::KDataType, fp8_t> &&
37 std::is_same_v<typename Problem::SaccDataType, float>)
38 {
39 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
40 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
41 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
42
43 // TODO: hard coded here. Otherwise, it produces incorrect results
44 constexpr index_t swizzle_factor = 4;
46 swizzle_factor>{};
47 }
48 else
49 {
50 constexpr bool SwizzleA =
51 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
52 return WarpGemmDispatcher<typename Problem::QDataType,
53 typename Problem::KDataType,
54 typename Problem::SaccDataType,
55 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
56 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
57 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
58 true, // TransposeC
59 SwizzleA>{};
60 }
61 }();
62
63 using BlockGemmPolicy =
64 BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
65 typename Problem::KDataType,
66 typename Problem::SaccDataType,
67 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
68 decltype(warp_gemm)>;
69
70 if constexpr(1 < Problem::kNumGemm0Warps)
71 {
72 if constexpr(128 >= Problem::BlockFmhaShape::kK0)
74 else
76 }
77 else
79 }
80};
81
82} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8< WGAttrCtlEnum::Default_ >, 2, swizzle_factor > > WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution
Definition warp_gemm.hpp:394
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp:18
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp:20
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:266
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:16
Definition block_gemm_areg_bsmem_creg_v2_custom_policy.hpp:16
Definition block_gemm_areg_bsmem_creg_v2.hpp:16
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:16
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49