13 template <
typename Problem>
18 typename Problem::KDataType,
19 typename Problem::AccDataType,
22 Problem::BlockFmhaShape::kN0,
23 Problem::BlockFmhaShape::kK0>,
24 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
25 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
27 constexpr auto SwizzleA =
false;
29 typename Problem::QDataType,
30 typename Problem::KDataType,
31 typename Problem::AccDataType,
32 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}),
33 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{}),
34 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<2>{}),
38 using BlockGemmPolicy =
40 typename Problem::KDataType,
41 typename Problem::AccDataType,
42 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
48 template <
typename Problem>
54 template <
typename Problem>
59 typename Problem::VDataType,
60 typename Problem::AccDataType,
63 Problem::BlockFmhaShape::kN0,
64 Problem::BlockFmhaShape::kK2>,
65 typename Problem::BlockFmhaShape::Gemm2BlockWarps,
66 typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
68 constexpr auto SwizzleA =
false;
70 typename Problem::OGradDataType,
71 typename Problem::VDataType,
72 typename Problem::AccDataType,
73 Problem::BlockFmhaShape::Gemm2WarpTile::at(
number<0>{}),
74 Problem::BlockFmhaShape::Gemm2WarpTile::at(
number<1>{}),
75 Problem::BlockFmhaShape::Gemm2WarpTile::at(
number<2>{}),
79 using BlockGemmPolicy =
81 typename Problem::VDataType,
82 typename Problem::AccDataType,
83 typename Problem::BlockFmhaShape::Gemm2BlockWarps,
89 template <
typename Problem>
95 template <
typename Problem>
98 using BlockFmhaShape =
typename Problem::BlockFmhaShape;
100 typename Problem::GemmDataType,
101 typename Problem::KDataType,
102 typename Problem::AccDataType,
106 typename BlockFmhaShape::Gemm4BlockWarps,
107 typename BlockFmhaShape::Gemm4WarpTile>>;
110 typename Problem::GemmDataType,
111 typename Problem::KDataType,
112 typename Problem::AccDataType,
113 BlockFmhaShape::Gemm4WarpTile::at(
number<0>{}),
114 BlockFmhaShape::Gemm4WarpTile::at(
number<1>{}),
115 BlockFmhaShape::Gemm4WarpTile::at(
number<2>{}),
119 (Problem::BlockFmhaShape::Gemm4WarpTile::at(
number<2>{}) == 32)
120 ? WGAttrNumAccessEnum ::Double
121 : WGAttrNumAccessEnum ::Single>;
123 using BlockGemmPolicy =
125 typename Problem::KDataType,
126 typename Problem::AccDataType,
127 typename BlockFmhaShape::Gemm4BlockWarps,
134 template <
typename Problem,
typename T>
137 return 16 /
sizeof(T);
139 template <
typename Problem>
144 template <
typename Problem>
149 template <
typename Problem>
154 template <
typename Problem>
159 template <
typename Problem>
164 template <
typename Problem>
170 template <
typename Problem>
176 template <
typename Problem>
183 template <
typename T>
186 return 8 /
sizeof(T);
188 template <
typename Problem>
194 template <
typename Problem>
200 template <
typename Problem>
203 constexpr index_t kBlockSize = Problem::kBlockSize;
204 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
205 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
207 constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
212 template <
typename Problem>
216 return 16 /
sizeof(AccDataType);
219 template <
typename Problem>
231 template <
typename T,
typename TensorView>
234 if constexpr(std::is_same_v<TensorView, ck_tile::null_tensor_view>)
240 const auto transformed_desc =
242 return tensor_view<
typename TensorView::buffer_view,
244 TensorView::DstInMemOp>{naive_view.buf_, transformed_desc};
247 template <
typename T,
typename... TD_TS>
253 constexpr auto ndims = from_desc_t::get_num_of_dimension();
254 static_assert(ndims == 2,
"XDram descriptor must have 2 dimensions");
260 constexpr index_t Dwordx4Bytes = 16;
261 constexpr index_t K2 = Dwordx4Bytes /
sizeof(T);
288 template <
typename Problem,
typename T, index_t RowsPerBlock, index_t ColsPerBlock>
291 constexpr index_t kBlockSize = Problem::kBlockSize;
296 constexpr index_t K_remain = ColsPerBlock / K2 / K3;
298 constexpr index_t K0 = K_remain / K1;
299 static_assert((K0 * K1 * K2 * K3 == ColsPerBlock) &&
301 "ColsPerBlock notdivisible");
305 constexpr index_t N0 = RowsPerBlock / N1 / N2;
306 static_assert((N0 * N1 * N2 == RowsPerBlock) && (K1 * N1 == kWarps) &&
308 "RowsPerBlock not divisible");
319 template <
typename Problem>
323 typename Problem::KDataType,
324 Problem::BlockFmhaShape::kN0,
325 Problem::BlockFmhaShape::kQKHeaddim>();
328 template <
typename Problem>
332 typename Problem::VDataType,
333 Problem::BlockFmhaShape::kN0,
334 Problem::BlockFmhaShape::kVHeaddim>();
337 template <
typename Problem>
341 typename Problem::QDataType,
342 Problem::BlockFmhaShape::kM0,
343 Problem::BlockFmhaShape::kQKHeaddim>();
346 template <
typename Problem>
350 typename Problem::OGradDataType,
351 Problem::BlockFmhaShape::kM0,
352 Problem::BlockFmhaShape::kVHeaddim>();
355 template <
typename Problem>
359 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
360 constexpr index_t MWarp = config.template at<1>();
361 constexpr index_t NWarp = config.template at<2>();
363 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
365 constexpr index_t N0 = MWarp * NWarp;
367 constexpr index_t M1 = kMPerBlock;
370 "M1 must be a factor of warp size");
381 template <
typename Problem>
387 template <
typename DataType, index_t MPerBlock, index_t KPerBlock>
390 constexpr index_t K1 = 16 /
sizeof(DataType);
391 constexpr index_t K0 = KPerBlock / K1;
394 constexpr index_t M0 = MPerBlock / M1;
405 template <
typename Problem>
410 constexpr index_t kBlockSize = Problem::kBlockSize;
411 constexpr index_t kKPerBlock = Problem::kVHeaddim;
416 template <
typename Problem>
421 constexpr index_t kBlockSize = Problem::kBlockSize;
422 constexpr index_t kKPerBlock = Problem::kVHeaddim;
427 template <
typename Problem>
432 constexpr index_t kBlockSize = Problem::kBlockSize;
433 constexpr index_t kMPerBlock = Problem::kM0;
434 constexpr index_t kKPerBlock = Problem::kQKHeaddim;
436 constexpr index_t K1 = 16 /
sizeof(AccDataType);
437 constexpr index_t K0 = kKPerBlock / K1;
441 constexpr index_t M0 = kMPerBlock / (M1 * M2);
452 template <
typename Problem>
457 constexpr index_t kBlockSize = Problem::kBlockSize;
458 constexpr index_t kMPerBlock = Problem::kM0;
459 constexpr index_t kKPerBlock = Problem::kQKHeaddim;
461 constexpr index_t K1 = 16 /
sizeof(AccDataType);
462 constexpr index_t K0 = kKPerBlock / K1;
466 constexpr index_t M0 = kMPerBlock / (M1 * M2);
477 template <
typename Problem>
483 template <
typename Problem>
489 template <
typename Problem>
493 using WarpGemm =
typename BlockGemm::WarpGemm;
495 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<0>{});
496 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<1>{});
498 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
499 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
501 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
502 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
513 kt_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
517 decltype(kt_block_dstr_encode),
518 typename Problem::KDataType>::TransposedDstrEncode{});
523 template <
typename T, index_t MNPerBlock, index_t KPerBlock>
538 template <
typename Problem>
542 Problem::BlockFmhaShape::kN0,
543 Problem::BlockFmhaShape::kQKHeaddim>();
545 template <
typename Problem>
549 Problem::BlockFmhaShape::kN0,
550 Problem::BlockFmhaShape::kVHeaddim>();
552 template <
typename Problem>
556 Problem::BlockFmhaShape::kM0,
557 Problem::BlockFmhaShape::kQKHeaddim>();
559 template <
typename Problem>
563 Problem::BlockFmhaShape::kM0,
564 Problem::BlockFmhaShape::kQKHeaddim>();
566 template <
typename Problem>
572 template <
typename Problem,
bool Transposed = false>
577 using WarpGemm =
typename BlockGemm::WarpGemm;
579 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
580 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
582 constexpr index_t M2 = WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
583 constexpr index_t M1 = WarpGemm::WarpGemmAttribute::Impl::kCMLane;
584 static_assert(WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane == 1,
"kCM0PerLane must be 1");
585 constexpr index_t M0 = kMPerBlock / (M1 * M2);
587 constexpr index_t N1 = WarpGemm::WarpGemmAttribute::Impl::kCNLane;
588 constexpr index_t N0 = kNPerBlock / N1;
593 constexpr index_t M1_0 = 2, M1_1 = 2;
594 constexpr index_t N1_0 = 2, N1_1 = 8;
595 static_assert(M1_0 * M1_1 == M1,
"M1_0 * M1_1 must equal M1");
596 static_assert(N1_0 * N1_1 == N1,
"N1_0 * N1_1 must equal N1");
629 constexpr auto top_dims = []() {
630 if constexpr(Transposed)
645 template <
typename T, index_t MNPerBlock, index_t KPerBlock>
648 const auto Dwordx4Bytes = 16;
649 const auto K2 = Dwordx4Bytes /
sizeof(T);
651 const auto K0 = KPerBlock / (K1 * K2);
670 template <
typename Problem>
674 Problem::BlockFmhaShape::kN0,
675 Problem::BlockFmhaShape::kQKHeaddim>();
677 template <
typename Problem>
681 Problem::BlockFmhaShape::kN0,
682 Problem::BlockFmhaShape::kVHeaddim>();
684 template <
typename Problem>
688 Problem::BlockFmhaShape::kM0,
689 Problem::BlockFmhaShape::kQKHeaddim>();
691 template <
typename Problem>
695 Problem::BlockFmhaShape::kM0,
696 Problem::BlockFmhaShape::kQKHeaddim>();
699 template <
typename Problem>
703 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
704 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
706 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<0>{});
707 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<1>{});
709 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
710 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
712 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
713 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
715 constexpr auto q_block_outer_dstr_encoding =
724 q_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
731 template <
typename Problem>
735 using WarpGemm =
typename BlockGemm::WarpGemm;
737 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<0>{});
738 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<1>{});
740 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
741 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
743 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
744 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
746 constexpr auto qt_block_outer_dstr_encoding =
755 qt_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
758 decltype(qt_block_dstr_encode),
759 typename Problem::QDataType>::TransposedDstrEncode{});
762 template <
typename Problem>
766 using WarpGemm =
typename BlockGemm::WarpGemm;
768 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<0>{});
769 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<1>{});
771 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
772 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
774 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
775 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
777 constexpr auto dst_block_outer_dstr_encoding =
786 dst_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
790 return dst_block_dstr;
793 template <
typename Problem>
796 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
798 constexpr index_t kMPack = 16 /
sizeof(LSEDType);
800 constexpr auto lsed_lds_block_desc =
806 return lsed_lds_block_desc;
809 template <
typename Problem>
813 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
815 constexpr index_t MWarp = config.template at<1>();
816 constexpr index_t NWarp = config.template at<2>();
818 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
820 constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
824 constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
826 constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
827 constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
828 constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
830 constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
841 template <
typename Problem>
845 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
846 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
848 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(
number<0>{});
849 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(
number<1>{});
851 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
852 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
854 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
855 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
857 constexpr auto do_block_outer_dstr_encoding =
866 do_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
870 return do_block_dstr;
873 template <
typename Problem>
877 using WarpGemm =
typename BlockGemm::WarpGemm;
879 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<0>{});
880 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<1>{});
882 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
884 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
886 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
887 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
889 constexpr auto dot_block_outer_dstr_encoding =
898 dot_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
904 decltype(dot_block_dstr_encode),
905 typename Problem::OGradDataType>::TransposedDstrEncode{});
908 template <
typename Problem>
912 using WarpGemm =
typename BlockGemm::WarpGemm;
914 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<0>{});
915 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<1>{});
917 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
918 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
920 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
921 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
923 constexpr auto pt_block_outer_dstr_encoding =
932 pt_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
936 return pt_block_dstr;
939 template <
typename Problem>
943 using WarpGemm =
typename BlockGemm::WarpGemm;
945 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<0>{});
946 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<1>{});
948 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
949 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4;
951 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
952 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
954 constexpr auto ds_block_outer_dstr_encoding =
963 ds_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
967 decltype(ds_block_dstr_encode),
968 typename Problem::GemmDataType>::TransposedDstrEncode{});
971 template <
typename Problem>
977 template <
typename BlockGemm>
980 using c_block_tensor_type =
decltype(BlockGemm{}.MakeCBlockTile());
981 return c_block_tensor_type::get_tile_distribution();
984 template <
typename Problem>
987 return sizeof(
typename Problem::QDataType) *
991 template <
typename Problem>
994 return sizeof(
typename Problem::KDataType) *
998 template <
typename Problem>
1003 sizeof(
typename Problem::LSEDataType) *
1007 template <
typename Problem>
1013 template <
typename Problem>
1016 return sizeof(
typename Problem::VDataType) *
1020 template <
typename Problem>
1023 return sizeof(
typename Problem::OGradDataType) *
1027 template <
typename Problem>
1030 return sizeof(
typename Problem::GemmDataType) *
1034 template <
typename Problem>
1038 return sizeof(
typename Problem::BiasDataType) *
1044 template <
typename Problem>
1056 constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v;
1057 constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 +
1058 smem_size_lse * 2 + smem_size_d * 2 +
1059 max(smem_size_bias, smem_size_ds);
1060 return max(smem_size_stage0, smem_size_stage1);
1063 template <
typename Problem>
1066 static constexpr index_t kBlockSize = Problem::kBlockSize;
1067 static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
1068 static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
1069 static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
1070 static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
1071 static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
1072 static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
1073 static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
1075 static constexpr index_t WarpGemmM =
1076 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{});
1077 static constexpr index_t WarpGemmN =
1078 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{});
1079 static constexpr index_t WarpGemmK =
1080 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<2>{});
1081 static constexpr index_t Gemm4MWarp =
1082 Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<0>{});
1083 static constexpr index_t Gemm4NWarp =
1084 Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<1>{});
1087 using GemmDataType =
typename Problem::GemmDataType;
1090 static constexpr index_t Gemm0MFMA =
1091 kM0 * kN0 * kK0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1092 static constexpr index_t Gemm1MFMA =
1093 kN0 * kVHeaddim * kM0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1094 static constexpr index_t Gemm2MFMA =
1095 kM0 * kN0 * kK2 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1096 static constexpr index_t Gemm3MFMA =
1097 kN0 * kQKHeaddim * kM0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1098 static constexpr index_t Gemm4MFMA =
1099 kM0 * kQKHeaddim * kN0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1102 static constexpr index_t Q_VMEM_READ =
1104 static constexpr index_t OGrad_VMEM_READ =
1106 static constexpr index_t LSE_VMEM_READ = 1;
1107 static constexpr index_t D_VMEM_READ = 1;
1109 static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize;
1112 static constexpr index_t OGradT_LDS_READ =
1114 static constexpr index_t QT_LDS_READ =
1116 static constexpr index_t SGradT_LDS_READ_P1 =
1118 static constexpr index_t SGradT_LDS_READ_P2 =
1121 static constexpr index_t Q_LDS_READ =
1123 static constexpr index_t LSE_LDS_READ = kM0 / (4 * 4);
1124 static constexpr index_t D_LDS_READ = LSE_LDS_READ;
1125 static constexpr index_t OGrad_LDS_READ =
1129 static constexpr index_t Q_LDS_WRITE =
1131 static constexpr index_t QT_LDS_WRITE =
1133 static constexpr index_t OGrad_LDS_WRITE =
1135 static constexpr index_t OGradT_LDS_WRITE =
1137 static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
1141 Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE;
1147 constexpr index_t VMEM_READ_INST =
1148 Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
1149 constexpr index_t MFMA_INST = Gemm0MFMA;
1150 constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ;
1152 constexpr index_t lcm_inst =
lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST);
1154 if constexpr(i % (lcm_inst / VMEM_READ_INST) == 0)
1155 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
1156 if constexpr(i % (lcm_inst / MFMA_INST) == 0)
1157 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1158 if constexpr(i % (lcm_inst / LDS_READ_INST) == 0)
1159 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
1167 constexpr index_t LDS_READ_INST = QT_LDS_READ;
1168 constexpr index_t MFMA_INST = Gemm1MFMA + Gemm2MFMA;
1170 constexpr index_t lcm_inst =
lcm(MFMA_INST, LDS_READ_INST);
1172 if constexpr(i % (lcm_inst / MFMA_INST) == 0)
1173 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1174 if constexpr(i % (lcm_inst / LDS_READ_INST) == 0)
1175 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
1183 constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
1184 constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ;
1185 constexpr index_t MFMA_INST = Gemm3MFMA;
1187 constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST;
1188 constexpr index_t lcm_inst =
lcm(MFMA_INST, lds_rw_inst);
1191 if constexpr(i % (lcm_inst / MFMA_INST) == 0)
1192 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1193 if constexpr(i % (lcm_inst / lds_rw_inst) == 0)
1195 if constexpr(i / (lcm_inst / lds_rw_inst) < LDS_WRITE_INST)
1196 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
1198 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
1207 constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ;
1208 constexpr index_t MFMA_INST = Gemm4MFMA;
1210 constexpr index_t lcm_inst =
lcm(MFMA_INST, LDS_READ_INST);
1212 if constexpr(i % (lcm_inst / MFMA_INST) == 0)
1213 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1214 if constexpr(i % (lcm_inst / LDS_READ_INST) == 0)
1215 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1065
static CK_TILE_DEVICE constexpr void SchedulerGemm12()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1163
static CK_TILE_DEVICE constexpr void SchedulerGemm3()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1179
static constexpr index_t TOTAL_VMEM_READ
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1140
static CK_TILE_DEVICE constexpr void SchedulerGemm4()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1203
static CK_TILE_DEVICE constexpr void SchedulerGemm0()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1143
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:371
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
Definition tile/core/numeric/math.hpp:314
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
static CK_TILE_HOST_DEVICE constexpr auto MakeVRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1071
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1798
static CK_TILE_HOST_DEVICE constexpr auto GetSGradTQTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:138
static CK_TILE_DEVICE constexpr auto GetPTOGradTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:66
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:614
static CK_TILE_HOST_DEVICE constexpr auto MakeKRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1013
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBiasTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1776
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:12
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1045
static CK_TILE_HOST_DEVICE constexpr auto MakePreOGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:417
static constexpr index_t WarpAlignmentBytes
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:227
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:560
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:356
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentBias()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:165
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:338
static CK_TILE_HOST_DEVICE constexpr auto MakeXDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:289
static CK_TILE_HOST_DEVICE constexpr auto TransformXDramTensorView(const TensorView &naive_view)
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:232
static CK_TILE_HOST_DEVICE constexpr auto MakePreODramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:406
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeBias()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1035
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:763
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:145
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentOGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:160
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:567
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasSTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:978
static CK_TILE_HOST_DEVICE constexpr auto MakePostQGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:453
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:155
static CK_TILE_HOST_DEVICE constexpr auto GetPTOGradTBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:49
static CK_TILE_HOST_DEVICE constexpr auto MakeKRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:478
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:539
static CK_TILE_HOST_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:329
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:671
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentOGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:195
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentBias()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:201
static CK_TILE_HOST_DEVICE constexpr auto MakePostQGradAccDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:428
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:678
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentX() noexcept
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:135
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:347
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentPostQGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:220
static CK_TILE_HOST_DEVICE constexpr auto GetSGradTQTBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:90
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeSGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1028
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:382
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentQ() noexcept
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:189
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentX() noexcept
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:184
static CK_TILE_HOST_DEVICE constexpr auto MakeKTRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:490
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentPostQGradAcc()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:213
static CK_TILE_HOST_DEVICE constexpr auto GetSGradKTBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:96
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBiasTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:972
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:810
static CK_TILE_HOST_DEVICE constexpr auto MakeQTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:732
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:140
static CK_TILE_HOST_DEVICE constexpr auto MakeXLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:646
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeLSE()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:999
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeK()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:992
static CK_TILE_DEVICE constexpr auto MakePTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:909
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeQ()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:985
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:546
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:700
static CK_TILE_DEVICE constexpr auto MakeOGradTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:874
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:794
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:320
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeV()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1014
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:553
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentVGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:177
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:940
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentKGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:171
static CK_TILE_HOST_DEVICE constexpr auto MakeVRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:484
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeOGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1021
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:692
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:150
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:573
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:842
static CK_TILE_HOST_DEVICE constexpr auto MakeXLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:524
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeD()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1008
static CK_TILE_HOST_DEVICE constexpr auto TransformXDramDescriptor(const tensor_descriptor< TD_TS... > &from_desc)
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:249
static CK_TILE_HOST_DEVICE constexpr auto GetOGradVBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:55
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:685
static CK_TILE_HOST_DEVICE constexpr auto MakePreXDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:388
Definition block_gemm_areg_breg_creg_v1_custom_policy.hpp:16
Definition block_gemm_areg_breg_creg_v1.hpp:18
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile/core/tensor/tensor_descriptor.hpp:34
CK_TILE_HOST_DEVICE constexpr auto get_length(number< IDim > idim) const
Definition tile/core/tensor/tensor_descriptor.hpp:86
Definition tensor_view.hpp:41
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192